Author: cito
Date: Sun Aug 14 15:21:31 2011
New Revision: 432

Log:
Support PQescapeLiteral() and PQescapeIdentifier() (ticket #41).

Modified:
   trunk/module/pgmodule.c
   trunk/module/test_pg.py

Modified: trunk/module/pgmodule.c
==============================================================================
--- trunk/module/pgmodule.c     Sun Aug 14 12:25:24 2011        (r431)
+++ trunk/module/pgmodule.c     Sun Aug 14 15:21:31 2011        (r432)
@@ -2787,6 +2787,50 @@
        return Py_None;
 }
 
+/* escape literal; raises ValueError if libpq fails to escape it */
+static char pg_escape_literal__doc__[] =
+"pg_escape_literal(str) -- escape a literal constant for use within SQL.";
+
+static PyObject *
+pg_escape_literal(pgobject *self, PyObject *args) {
+       char *str; /* our string argument */
+       int str_length; /* length of string */
+       char *esc; /* the escaped and quoted version of the string */
+       PyObject *ret; /* string object to return */
+
+       if (!PyArg_ParseTuple(args, "s#", &str, &str_length))
+               return NULL;
+       esc = PQescapeLiteral(self->cnx, str, (size_t)str_length);
+       if (!esc) /* raise libpq's error rather than return None */
+               return PyErr_Format(PyExc_ValueError, "%s",
+                       PQerrorMessage(self->cnx));
+       ret = Py_BuildValue("s", esc);
+       PQfreemem(esc);
+       return ret; /* NULL if Py_BuildValue failed: pass on exception */
+}
+
+/* escape identifier; raises ValueError if libpq fails to escape it */
+static char pg_escape_identifier__doc__[] =
+"pg_escape_identifier(str) -- escape an identifier for use within SQL.";
+
+static PyObject *
+pg_escape_identifier(pgobject *self, PyObject *args) {
+       char *str; /* our string argument */
+       int str_length; /* length of string */
+       char *esc; /* the escaped and quoted version of the string */
+       PyObject *ret; /* string object to return */
+
+       if (!PyArg_ParseTuple(args, "s#", &str, &str_length))
+               return NULL;
+       esc = PQescapeIdentifier(self->cnx, str, (size_t)str_length);
+       if (!esc) /* raise libpq's error rather than return None */
+               return PyErr_Format(PyExc_ValueError, "%s",
+                       PQerrorMessage(self->cnx));
+       ret = Py_BuildValue("s", esc);
+       PQfreemem(esc);
+       return ret; /* NULL if Py_BuildValue failed: pass on exception */
+}
+
 /* escape string */
 static char pg_escape_string__doc__[] =
 "pg_escape_string(str) -- escape a string for use within SQL.";
@@ -2953,6 +2997,10 @@
                        pg_transaction__doc__},
        {"parameter", (PyCFunction) pg_parameter, METH_VARARGS,
                        pg_parameter__doc__},
+       {"escape_literal", (PyCFunction) pg_escape_literal, METH_VARARGS,
+                       pg_escape_literal__doc__},
+       {"escape_identifier", (PyCFunction) pg_escape_identifier, METH_VARARGS,
+                       pg_escape_identifier__doc__},
        {"escape_string", (PyCFunction) pg_escape_string, METH_VARARGS,
                        pg_escape_string__doc__},
        {"escape_bytea", (PyCFunction) pg_escape_bytea, METH_VARARGS,

Modified: trunk/module/test_pg.py
==============================================================================
--- trunk/module/test_pg.py     Sun Aug 14 12:25:24 2011        (r431)
+++ trunk/module/test_pg.py     Sun Aug 14 15:21:31 2011        (r432)
@@ -363,9 +363,10 @@
         self.assertEqual(attributes, connection_attributes)
 
     def testAllConnectMethods(self):
-        methods = '''cancel close endcopy escape_bytea escape_string fileno
-            getline getlo getnotify inserttable locreate loimport parameter
-            putline query reset source transaction'''.split()
+        methods = '''cancel close endcopy
+            escape_bytea escape_identifier escape_literal escape_string
+            fileno getline getlo getnotify inserttable locreate loimport
+            parameter putline query reset source transaction'''.split()  # dir() order
         connection_methods = [a for a in dir(self.connection)
             if callable(eval("self.connection." + a))]
         self.assertEqual(methods, connection_methods)
@@ -397,7 +398,7 @@
     def testAttributeServerVersion(self):
         server_version = self.connection.server_version
         self.assert_(isinstance(server_version, int))
-        self.assert_(70400 <= server_version < 90000)
+        self.assert_(70400 <= server_version < 100000)  # 7.4 to 9.x
 
     def testAttributeStatus(self):
         status_ok = 1
@@ -598,11 +599,11 @@
         r = filter(bool, open(t, 'r').read().splitlines())
         os.remove(t)
         self.assertEqual(r,
-            ['a|h    |world',
-            '-+-----+-----',
-            '1|hello|w    ',
-            '2|xyz  |uvw  ',
-            '(2 rows)'])
+            ['a|  h  |world',
+             '-+-----+-----',
+             '1|hello|w    ',
+             '2|xyz  |uvw  ',
+             '(2 rows)'])
 
     def testGetNotify(self):
         self.assert_(self.c.getnotify() is None)
@@ -693,12 +694,12 @@
 
     def testAllDBAttributes(self):
         attributes = '''cancel clear close db dbname debug delete endcopy
-            error escape_bytea escape_string fileno  get get_attnames
-            get_databases get_relations get_tables getline getlo getnotify
-            has_table_privilege host insert inserttable locreate loimport
-            options parameter pkey port protocol_version putline query
-            reopen reset server_version source status transaction tty
-            unescape_bytea update user'''.split()
+            error escape_bytea escape_identifier escape_literal escape_string
+            fileno get get_attnames get_databases get_relations get_tables
+            getline getlo getnotify has_table_privilege host insert inserttable
+            locreate loimport options parameter pkey port protocol_version
+            putline query reopen reset server_version source status transaction
+            tty unescape_bytea update user'''.split()  # dir() order
         db_attributes = [a for a in dir(self.db)
             if not a.startswith('_')]
         self.assertEqual(attributes, db_attributes)
@@ -741,7 +742,7 @@
     def testAttributeServerVersion(self):
         server_version = self.db.server_version
         self.assert_(isinstance(server_version, int))
-        self.assert_(70400 <= server_version < 90000)
+        self.assert_(70400 <= server_version < 100000)  # 7.4 to 9.x
         self.assertEqual(server_version, self.db.db.server_version)
 
     def testAttributeStatus(self):
@@ -763,6 +764,28 @@
         self.assertNotEqual(user, no_user)
         self.assertEqual(user, self.db.db.user)
 
+    def testMethodEscapeLiteral(self):
+        self.assertEqual(self.db.escape_literal("plain"), "'plain'")
+        self.assertEqual(self.db.escape_literal(
+            "that's k\xe4se"), "'that''s k\xe4se'")
+        self.assertEqual(self.db.escape_literal(  # \ yields ' E' prefix
+            r"It's fine to have a \ inside."),
+            r" E'It''s fine to have a \\ inside.'")
+        self.assertEqual(self.db.escape_literal(
+            'No "quotes" must be escaped.'),
+            "'No \"quotes\" must be escaped.'")
+
+    def testMethodEscapeIdentifier(self):
+        self.assertEqual(self.db.escape_identifier("plain"), '"plain"')
+        self.assertEqual(self.db.escape_identifier(
+            "that's k\xe4se"), '"that\'s k\xe4se"')
+        self.assertEqual(self.db.escape_identifier(  # \ is not doubled
+            r"It's fine to have a \ inside."),
+            '"It\'s fine to have a \\ inside."')
+        self.assertEqual(self.db.escape_identifier(
+            'All "quotes" must be escaped.'),
+            '"All ""quotes"" must be escaped."')
+
     def testMethodEscapeString(self):
         self.assertEqual(self.db.escape_string("plain"), "plain")
         self.assertEqual(self.db.escape_string(
@@ -772,18 +795,42 @@
             r"It''s fine to have a \\ inside.")
 
     def testMethodEscapeBytea(self):
-        self.assertEqual(self.db.escape_bytea("plain"), "plain")
-        self.assertEqual(self.db.escape_bytea(
-            "that's k\xe4se"), "that''s k\\\\344se")
-        self.assertEqual(self.db.escape_bytea(
-            'O\x00ps\xff!'), r'O\\000ps\\377!')
+        output = self.db.query("show bytea_output").getresult()[0][0]
+        self.assert_(output in ('escape', 'hex'))
+        if output == 'escape':  # TODO confirm: libpq may ignore bytea_output
+            self.assertEqual(self.db.escape_bytea("plain"), "plain")
+            self.assertEqual(self.db.escape_bytea(
+                "that's k\xe4se"), "that''s k\\\\344se")
+            self.assertEqual(self.db.escape_bytea(
+                'O\x00ps\xff!'), r'O\\000ps\\377!')
+        else:
+            self.assertEqual(self.db.escape_bytea("plain"), r"\\x706c61696e")
+            self.assertEqual(self.db.escape_bytea(
+                "that's k\xe4se"), r"\\x746861742773206be47365")
+            self.assertEqual(self.db.escape_bytea(
+                'O\x00ps\xff!'), r"\\x4f007073ff21")
 
     def testMethodUnescapeBytea(self):
+        # unescape_bytea needs no connection (cf. pg.unescape_bytea), so
+        # results cannot depend on standard_conforming_strings; hex format
+        # input starts with exactly one backslash followed by "x".
         self.assertEqual(self.db.unescape_bytea("plain"), "plain")
         self.assertEqual(self.db.unescape_bytea(
             "that's k\\344se"), "that's k\xe4se")
         self.assertEqual(pg.unescape_bytea(
             r'O\000ps\377!'), 'O\x00ps\xff!')
+        self.assertEqual(self.db.unescape_bytea(r"\x706c61696e"), "plain")
+        self.assertEqual(self.db.unescape_bytea(
+            r"\x746861742773206be47365"), "that's k\xe4se")
+        self.assertEqual(pg.unescape_bytea(
+            r"\x4f007073ff21"), 'O\x00ps\xff!')
+        # a doubled backslash means escape format: "\\" is one backslash
+        self.assertEqual(self.db.unescape_bytea(
+            r"\\x706c61696e"), r"\x706c61696e")
+        self.assertEqual(self.db.unescape_bytea(
+            r"\\x746861742773206be47365"), r"\x746861742773206be47365")
+        self.assertEqual(pg.unescape_bytea(
+            r"\\x4f007073ff21"), r"\x4f007073ff21")
 
     def testMethodQuery(self):
         self.db.query("select 1+1")
@@ -851,18 +898,42 @@
             r"It''s fine to have a \\ inside.")
 
     def testEscapeBytea(self):
-        self.assertEqual(self.db.escape_bytea("plain"), "plain")
-        self.assertEqual(self.db.escape_bytea(
-            "that's k\xe4se"), "that''s k\\\\344se")
-        self.assertEqual(self.db.escape_bytea(
-            'O\x00ps\xff!'), r'O\\000ps\\377!')
+        output = self.db.query("show bytea_output").getresult()[0][0]
+        self.assert_(output in ('escape', 'hex'))
+        if output == 'escape':  # TODO confirm: libpq may ignore bytea_output
+            self.assertEqual(self.db.escape_bytea("plain"), "plain")
+            self.assertEqual(self.db.escape_bytea(
+                "that's k\xe4se"), "that''s k\\\\344se")
+            self.assertEqual(self.db.escape_bytea(
+                'O\x00ps\xff!'), r'O\\000ps\\377!')
+        else:
+            self.assertEqual(self.db.escape_bytea("plain"), r"\\x706c61696e")
+            self.assertEqual(self.db.escape_bytea(
+                "that's k\xe4se"), r"\\x746861742773206be47365")
+            self.assertEqual(self.db.escape_bytea(
+                'O\x00ps\xff!'), r"\\x4f007073ff21")
 
     def testUnescapeBytea(self):
+        # unescape_bytea needs no connection (cf. pg.unescape_bytea), so
+        # results cannot depend on standard_conforming_strings; hex format
+        # input starts with exactly one backslash followed by "x".
         self.assertEqual(self.db.unescape_bytea("plain"), "plain")
         self.assertEqual(self.db.unescape_bytea(
             "that's k\\344se"), "that's k\xe4se")
         self.assertEqual(pg.unescape_bytea(
             r'O\000ps\377!'), 'O\x00ps\xff!')
+        self.assertEqual(self.db.unescape_bytea(r"\x706c61696e"), "plain")
+        self.assertEqual(self.db.unescape_bytea(
+            r"\x746861742773206be47365"), "that's k\xe4se")
+        self.assertEqual(pg.unescape_bytea(
+            r"\x4f007073ff21"), 'O\x00ps\xff!')
+        # a doubled backslash means escape format: "\\" is one backslash
+        self.assertEqual(self.db.unescape_bytea(
+            r"\\x706c61696e"), r"\x706c61696e")
+        self.assertEqual(self.db.unescape_bytea(
+            r"\\x746861742773206be47365"), r"\x746861742773206be47365")
+        self.assertEqual(pg.unescape_bytea(
+            r"\\x4f007073ff21"), r"\x4f007073ff21")
 
     def testQuote(self):
         f = self.db._quote
_______________________________________________
PyGreSQL mailing list
[email protected]
http://mailman.vex.net/mailman/listinfo/pygresql

Reply via email to