Author: cito
Date: Sun Aug 14 15:21:31 2011
New Revision: 432
Log:
Support PQescapeLiteral() and PQescapeIdentifier() (ticket #41).
Modified:
trunk/module/pgmodule.c
trunk/module/test_pg.py
Modified: trunk/module/pgmodule.c
==============================================================================
--- trunk/module/pgmodule.c Sun Aug 14 12:25:24 2011 (r431)
+++ trunk/module/pgmodule.c Sun Aug 14 15:21:31 2011 (r432)
@@ -2787,6 +2787,50 @@
return Py_None;
}
+/* escape literal */
+static char pg_escape_literal__doc__[] =
+"pg_escape_literal(str) -- escape a literal constant for use within SQL.";
+
+/*
+ * Quote and escape a string as an SQL literal using libpq's
+ * PQescapeLiteral() on this connection.
+ *
+ * Returns a new string object on success.  Returns NULL with a Python
+ * exception set if argument parsing or Py_BuildValue() fails, or raises
+ * ValueError carrying libpq's error message if PQescapeLiteral() fails
+ * (previously that case silently returned None).
+ */
+static PyObject *
+pg_escape_literal(pgobject *self, PyObject *args) {
+    char *str;          /* our string argument */
+    int str_length;     /* length of string */
+    char *esc;          /* the escaped version of the string */
+    PyObject *ret;      /* string object to return */
+
+    if (!PyArg_ParseTuple(args, "s#", &str, &str_length))
+        return NULL;
+    esc = PQescapeLiteral(self->cnx, str, (size_t)str_length);
+    if (!esc) {
+        /* libpq signals failure with NULL and leaves the reason in the
+           connection; surface it instead of returning None */
+        PyErr_SetString(PyExc_ValueError, PQerrorMessage(self->cnx));
+        return NULL;
+    }
+    ret = Py_BuildValue("s", esc);
+    PQfreemem(esc);     /* libpq-allocated result must go back to libpq */
+    return ret;         /* NULL with exception set if Py_BuildValue failed */
+}
+
+/* escape identifier */
+static char pg_escape_identifier__doc__[] =
+"pg_escape_identifier(str) -- escape an identifier for use within SQL.";
+
+/*
+ * Quote and escape a string as an SQL identifier using libpq's
+ * PQescapeIdentifier() on this connection.
+ *
+ * Returns a new string object on success.  Returns NULL with a Python
+ * exception set if argument parsing or Py_BuildValue() fails, or raises
+ * ValueError carrying libpq's error message if PQescapeIdentifier() fails
+ * (previously that case silently returned None).
+ */
+static PyObject *
+pg_escape_identifier(pgobject *self, PyObject *args) {
+    char *str;          /* our string argument */
+    int str_length;     /* length of string */
+    char *esc;          /* the escaped version of the string */
+    PyObject *ret;      /* string object to return */
+
+    if (!PyArg_ParseTuple(args, "s#", &str, &str_length))
+        return NULL;
+    esc = PQescapeIdentifier(self->cnx, str, (size_t)str_length);
+    if (!esc) {
+        /* libpq signals failure with NULL and leaves the reason in the
+           connection; surface it instead of returning None */
+        PyErr_SetString(PyExc_ValueError, PQerrorMessage(self->cnx));
+        return NULL;
+    }
+    ret = Py_BuildValue("s", esc);
+    PQfreemem(esc);     /* libpq-allocated result must go back to libpq */
+    return ret;         /* NULL with exception set if Py_BuildValue failed */
+}
+
/* escape string */
static char pg_escape_string__doc__[] =
"pg_escape_string(str) -- escape a string for use within SQL.";
@@ -2953,6 +2997,10 @@
pg_transaction__doc__},
{"parameter", (PyCFunction) pg_parameter, METH_VARARGS,
pg_parameter__doc__},
+ {"escape_literal", (PyCFunction) pg_escape_literal, METH_VARARGS,
+ pg_escape_literal__doc__},
+ {"escape_identifier", (PyCFunction) pg_escape_identifier, METH_VARARGS,
+ pg_escape_identifier__doc__},
{"escape_string", (PyCFunction) pg_escape_string, METH_VARARGS,
pg_escape_string__doc__},
{"escape_bytea", (PyCFunction) pg_escape_bytea, METH_VARARGS,
Modified: trunk/module/test_pg.py
==============================================================================
--- trunk/module/test_pg.py Sun Aug 14 12:25:24 2011 (r431)
+++ trunk/module/test_pg.py Sun Aug 14 15:21:31 2011 (r432)
@@ -363,9 +363,10 @@
self.assertEqual(attributes, connection_attributes)
def testAllConnectMethods(self):
- methods = '''cancel close endcopy escape_bytea escape_string fileno
- getline getlo getnotify inserttable locreate loimport parameter
- putline query reset source transaction'''.split()
+ methods = '''cancel close endcopy
+ escape_bytea escape_identifier escape_literal escape_string
+ fileno getline getlo getnotify inserttable locreate loimport
+ parameter putline query reset source transaction'''.split()
connection_methods = [a for a in dir(self.connection)
if callable(eval("self.connection." + a))]
self.assertEqual(methods, connection_methods)
@@ -397,7 +398,7 @@
def testAttributeServerVersion(self):
server_version = self.connection.server_version
self.assert_(isinstance(server_version, int))
- self.assert_(70400 <= server_version < 90000)
+ self.assert_(70400 <= server_version < 100000)
def testAttributeStatus(self):
status_ok = 1
@@ -598,11 +599,11 @@
r = filter(bool, open(t, 'r').read().splitlines())
os.remove(t)
self.assertEqual(r,
- ['a|h |world',
- '-+-----+-----',
- '1|hello|w ',
- '2|xyz |uvw ',
- '(2 rows)'])
+ ['a| h |world',
+ '-+-----+-----',
+ '1|hello|w ',
+ '2|xyz |uvw ',
+ '(2 rows)'])
def testGetNotify(self):
self.assert_(self.c.getnotify() is None)
@@ -693,12 +694,12 @@
def testAllDBAttributes(self):
attributes = '''cancel clear close db dbname debug delete endcopy
- error escape_bytea escape_string fileno get get_attnames
- get_databases get_relations get_tables getline getlo getnotify
- has_table_privilege host insert inserttable locreate loimport
- options parameter pkey port protocol_version putline query
- reopen reset server_version source status transaction tty
- unescape_bytea update user'''.split()
+ error escape_bytea escape_identifier escape_literal escape_string
+ fileno get get_attnames get_databases get_relations get_tables
+ getline getlo getnotify has_table_privilege host insert inserttable
+ locreate loimport options parameter pkey port protocol_version
+ putline query reopen reset server_version source status transaction
+ tty unescape_bytea update user'''.split()
db_attributes = [a for a in dir(self.db)
if not a.startswith('_')]
self.assertEqual(attributes, db_attributes)
@@ -741,7 +742,7 @@
def testAttributeServerVersion(self):
server_version = self.db.server_version
self.assert_(isinstance(server_version, int))
- self.assert_(70400 <= server_version < 90000)
+ self.assert_(70400 <= server_version < 100000)
self.assertEqual(server_version, self.db.db.server_version)
def testAttributeStatus(self):
@@ -763,6 +764,28 @@
self.assertNotEqual(user, no_user)
self.assertEqual(user, self.db.db.user)
+ def testMethodEscapeLiteral(self):
+ self.assertEqual(self.db.escape_literal("plain"), "'plain'")
+ self.assertEqual(self.db.escape_literal(
+ "that's k\xe4se"), "'that''s k\xe4se'")
+ self.assertEqual(self.db.escape_literal(
+ r"It's fine to have a \ inside."),
+ r" E'It''s fine to have a \\ inside.'")
+ self.assertEqual(self.db.escape_literal(
+ 'No "quotes" must be escaped.'),
+ "'No \"quotes\" must be escaped.'")
+
+ def testMethodEscapeIdentifier(self):
+ self.assertEqual(self.db.escape_identifier("plain"), '"plain"')
+ self.assertEqual(self.db.escape_identifier(
+ "that's k\xe4se"), '"that\'s k\xe4se"')
+ self.assertEqual(self.db.escape_identifier(
+ r"It's fine to have a \ inside."),
+ '"It\'s fine to have a \\ inside."')
+ self.assertEqual(self.db.escape_identifier(
+ 'All "quotes" must be escaped.'),
+ '"All ""quotes"" must be escaped."')
+
def testMethodEscapeString(self):
self.assertEqual(self.db.escape_string("plain"), "plain")
self.assertEqual(self.db.escape_string(
@@ -772,18 +795,42 @@
r"It''s fine to have a \\ inside.")
def testMethodEscapeBytea(self):
- self.assertEqual(self.db.escape_bytea("plain"), "plain")
- self.assertEqual(self.db.escape_bytea(
- "that's k\xe4se"), "that''s k\\\\344se")
- self.assertEqual(self.db.escape_bytea(
- 'O\x00ps\xff!'), r'O\\000ps\\377!')
+ output = self.db.query("show bytea_output").getresult()[0][0]
+ self.assert_(output in ('escape', 'hex'))
+ if output == 'escape':
+ self.assertEqual(self.db.escape_bytea("plain"), "plain")
+ self.assertEqual(self.db.escape_bytea(
+ "that's k\xe4se"), "that''s k\\\\344se")
+ self.assertEqual(self.db.escape_bytea(
+ 'O\x00ps\xff!'), r'O\\000ps\\377!')
+ else:
+ self.assertEqual(self.db.escape_bytea("plain"), r"\\x706c61696e")
+ self.assertEqual(self.db.escape_bytea(
+ "that's k\xe4se"), r"\\x746861742773206be47365")
+ self.assertEqual(self.db.escape_bytea(
+ 'O\x00ps\xff!'), r"\\x4f007073ff21")
def testMethodUnescapeBytea(self):
+ standard_conforming = self.db.query(
+ "show standard_conforming_strings").getresult()[0][0]
+ self.assert_(standard_conforming in ('on', 'off'))
self.assertEqual(self.db.unescape_bytea("plain"), "plain")
self.assertEqual(self.db.unescape_bytea(
"that's k\\344se"), "that's k\xe4se")
self.assertEqual(pg.unescape_bytea(
r'O\000ps\377!'), 'O\x00ps\xff!')
+ if standard_conforming == 'on':
+ self.assertEqual(self.db.unescape_bytea(r"\\x706c61696e"), "plain")
+ self.assertEqual(self.db.unescape_bytea(
+ r"\\x746861742773206be47365"), "that's k\xe4se")
+ self.assertEqual(pg.unescape_bytea(
+ r"\\x4f007073ff21"), 'O\x00ps\xff!')
+ else:
+ self.assertEqual(self.db.unescape_bytea(r"\x706c61696e"), "plain")
+ self.assertEqual(self.db.unescape_bytea(
+ r"\x746861742773206be47365"), "that's k\xe4se")
+ self.assertEqual(pg.unescape_bytea(
+ r"\x4f007073ff21"), 'O\x00ps\xff!')
def testMethodQuery(self):
self.db.query("select 1+1")
@@ -851,18 +898,42 @@
r"It''s fine to have a \\ inside.")
def testEscapeBytea(self):
- self.assertEqual(self.db.escape_bytea("plain"), "plain")
- self.assertEqual(self.db.escape_bytea(
- "that's k\xe4se"), "that''s k\\\\344se")
- self.assertEqual(self.db.escape_bytea(
- 'O\x00ps\xff!'), r'O\\000ps\\377!')
+ output = self.db.query("show bytea_output").getresult()[0][0]
+ self.assert_(output in ('escape', 'hex'))
+ if output == 'escape':
+ self.assertEqual(self.db.escape_bytea("plain"), "plain")
+ self.assertEqual(self.db.escape_bytea(
+ "that's k\xe4se"), "that''s k\\\\344se")
+ self.assertEqual(self.db.escape_bytea(
+ 'O\x00ps\xff!'), r'O\\000ps\\377!')
+ else:
+ self.assertEqual(self.db.escape_bytea("plain"), r"\\x706c61696e")
+ self.assertEqual(self.db.escape_bytea(
+ "that's k\xe4se"), r"\\x746861742773206be47365")
+ self.assertEqual(self.db.escape_bytea(
+ 'O\x00ps\xff!'), r"\\x4f007073ff21")
def testUnescapeBytea(self):
+ standard_conforming = self.db.query(
+ "show standard_conforming_strings").getresult()[0][0]
+ self.assert_(standard_conforming in ('on', 'off'))
self.assertEqual(self.db.unescape_bytea("plain"), "plain")
self.assertEqual(self.db.unescape_bytea(
"that's k\\344se"), "that's k\xe4se")
self.assertEqual(pg.unescape_bytea(
r'O\000ps\377!'), 'O\x00ps\xff!')
+ if standard_conforming == 'on':
+ self.assertEqual(self.db.unescape_bytea(r"\\x706c61696e"), "plain")
+ self.assertEqual(self.db.unescape_bytea(
+ r"\\x746861742773206be47365"), "that's k\xe4se")
+ self.assertEqual(pg.unescape_bytea(
+ r"\\x4f007073ff21"), 'O\x00ps\xff!')
+ else:
+ self.assertEqual(self.db.unescape_bytea(r"\x706c61696e"), "plain")
+ self.assertEqual(self.db.unescape_bytea(
+ r"\x746861742773206be47365"), "that's k\xe4se")
+ self.assertEqual(pg.unescape_bytea(
+ r"\x4f007073ff21"), 'O\x00ps\xff!')
def testQuote(self):
f = self.db._quote
_______________________________________________
PyGreSQL mailing list
[email protected]
http://mailman.vex.net/mailman/listinfo/pygresql