Author: cito
Date: Wed Nov 25 12:48:37 2015
New Revision: 628

Log:
Accept non-ascii queries passed as unicode

So far, queries needed to be properly encoded. Now the query method
can do this automatically if it gets a query passed in as unicode.

Also, added missing code and tests for retrieving unicode values
with dictresult() under Python 3.

Modified:
   trunk/module/TEST_PyGreSQL_classic_connection.py
   trunk/module/pgmodule.c

Modified: trunk/module/TEST_PyGreSQL_classic_connection.py
==============================================================================
--- trunk/module/TEST_PyGreSQL_classic_connection.py    Wed Nov 25 10:09:41 
2015        (r627)
+++ trunk/module/TEST_PyGreSQL_classic_connection.py    Wed Nov 25 12:48:37 
2015        (r628)
@@ -541,6 +541,167 @@
         query("drop table test_table")
 
 
+class TestUnicodeQueries(unittest.TestCase):
+    """"Test unicode strings as queries via a basic pg connection."""
+
+    def setUp(self):
+        self.c = connect()
+        self.c.query('set client_encoding=utf8')
+
+    def tearDown(self):
+        self.c.close()
+
+    def testGetresulAscii(self):
+        result = u'Hello, world!'
+        q = u"select '%s'" % result
+        v = self.c.query(q).getresult()[0][0]
+        self.assertIsInstance(v, str)
+        self.assertEqual(v, result)
+
+    def testDictresulAscii(self):
+        result = u'Hello, world!'
+        q = u"select '%s' as greeting" % result
+        v = self.c.query(q).dictresult()[0]['greeting']
+        self.assertIsInstance(v, str)
+        self.assertEqual(v, result)
+
+    def testGetresultUtf8(self):
+        result = u'Hello, wörld & мир!'
+        q = u"select '%s'" % result
+        if not unicode_strings:
+            result = result.encode('utf8')
+        # pass the query as unicode
+        try:
+            v = self.c.query(q).getresult()[0][0]
+        except pg.ProgrammingError:
+            self.skipTest("database does not support utf8")
+        self.assertIsInstance(v, str)
+        self.assertEqual(v, result)
+        q = q.encode('utf8')
+        # pass the query as bytes
+        v = self.c.query(q).getresult()[0][0]
+        self.assertIsInstance(v, str)
+        self.assertEqual(v, result)
+
+    def testDictresultUtf8(self):
+        result = u'Hello, wörld & мир!'
+        q = u"select '%s' as greeting" % result
+        if not unicode_strings:
+            result = result.encode('utf8')
+        try:
+            v = self.c.query(q).dictresult()[0]['greeting']
+        except pg.ProgrammingError:
+            self.skipTest("database does not support utf8")
+        self.assertIsInstance(v, str)
+        self.assertEqual(v, result)
+        q = q.encode('utf8')
+        v = self.c.query(q).dictresult()[0]['greeting']
+        self.assertIsInstance(v, str)
+        self.assertEqual(v, result)
+
+    def testDictresultLatin1(self):
+        try:
+            self.c.query('set client_encoding=latin1')
+        except pg.ProgrammingError:
+            self.skipTest("database does not support latin1")
+        result = u'Hello, wörld!'
+        q = u"select '%s'" % result
+        if not unicode_strings:
+            result = result.encode('latin1')
+        v = self.c.query(q).getresult()[0][0]
+        self.assertIsInstance(v, str)
+        self.assertEqual(v, result)
+        q = q.encode('latin1')
+        v = self.c.query(q).getresult()[0][0]
+        self.assertIsInstance(v, str)
+        self.assertEqual(v, result)
+
+    def testDictresultLatin1(self):
+        try:
+            self.c.query('set client_encoding=latin1')
+        except pg.ProgrammingError:
+            self.skipTest("database does not support latin1")
+        result = u'Hello, wörld!'
+        q = u"select '%s' as greeting" % result
+        if not unicode_strings:
+            result = result.encode('latin1')
+        v = self.c.query(q).dictresult()[0]['greeting']
+        self.assertIsInstance(v, str)
+        self.assertEqual(v, result)
+        q = q.encode('latin1')
+        v = self.c.query(q).dictresult()[0]['greeting']
+        self.assertIsInstance(v, str)
+        self.assertEqual(v, result)
+
+    def testGetresultCyrillic(self):
+        try:
+            self.c.query('set client_encoding=iso_8859_5')
+        except pg.ProgrammingError:
+            self.skipTest("database does not support cyrillic")
+        result = u'Hello, мир!'
+        q = u"select '%s'" % result
+        if not unicode_strings:
+            result = result.encode('cyrillic')
+        v = self.c.query(q).getresult()[0][0]
+        self.assertIsInstance(v, str)
+        self.assertEqual(v, result)
+        q = q.encode('cyrillic')
+        v = self.c.query(q).getresult()[0][0]
+        self.assertIsInstance(v, str)
+        self.assertEqual(v, result)
+
+    def testDictresultCyrillic(self):
+        try:
+            self.c.query('set client_encoding=iso_8859_5')
+        except pg.ProgrammingError:
+            self.skipTest("database does not support cyrillic")
+        result = u'Hello, мир!'
+        q = u"select '%s' as greeting" % result
+        if not unicode_strings:
+            result = result.encode('cyrillic')
+        v = self.c.query(q).dictresult()[0]['greeting']
+        self.assertIsInstance(v, str)
+        self.assertEqual(v, result)
+        q = q.encode('cyrillic')
+        v = self.c.query(q).dictresult()[0]['greeting']
+        self.assertIsInstance(v, str)
+        self.assertEqual(v, result)
+
+    def testGetresultLatin9(self):
+        try:
+            self.c.query('set client_encoding=latin9')
+        except pg.ProgrammingError:
+            self.skipTest("database does not support latin9")
+        result = u'smœrebrœd with pražská šunka (pay in ¢, £, €, or ¥)'
+        q = u"select '%s'" % result
+        if not unicode_strings:
+            result = result.encode('latin9')
+        v = self.c.query(q).getresult()[0][0]
+        self.assertIsInstance(v, str)
+        self.assertEqual(v, result)
+        q = q.encode('latin9')
+        v = self.c.query(q).getresult()[0][0]
+        self.assertIsInstance(v, str)
+        self.assertEqual(v, result)
+
+    def testDictresultLatin9(self):
+        try:
+            self.c.query('set client_encoding=latin9')
+        except pg.ProgrammingError:
+            self.skipTest("database does not support latin9")
+        result = u'smœrebrœd with pražská šunka (pay in ¢, £, €, or ¥)'
+        q = u"select '%s' as menu" % result
+        if not unicode_strings:
+            result = result.encode('latin9')
+        v = self.c.query(q).dictresult()[0]['menu']
+        self.assertIsInstance(v, str)
+        self.assertEqual(v, result)
+        q = q.encode('latin9')
+        v = self.c.query(q).dictresult()[0]['menu']
+        self.assertIsInstance(v, str)
+        self.assertEqual(v, result)
+
+
 class TestParamQueries(unittest.TestCase):
     """"Test queries with parameters via a basic pg connection."""
 
@@ -694,15 +855,6 @@
         self.assertEqual(self.c.query("select $1::text AS garbage", (garbage,)
             ).dictresult(), [{'garbage': garbage}])
 
-    def testUnicodeQuery(self):
-        query = self.c.query
-        self.assertEqual(query(u"select 1+1").getresult(), [(2,)])
-        if unicode_strings:
-            self.assertEqual(query("select 'Hello, wörld!'").getresult(),
-                [('Hello, wörld!',)])
-        else:
-            self.assertRaises(TypeError, query, u"select 'Hello, wörld!'")
-
 
 class TestInserttable(unittest.TestCase):
     """"Test inserttable method."""

Modified: trunk/module/pgmodule.c
==============================================================================
--- trunk/module/pgmodule.c     Wed Nov 25 10:09:41 2015        (r627)
+++ trunk/module/pgmodule.c     Wed Nov 25 12:48:37 2015        (r628)
@@ -1073,8 +1073,9 @@
 static PyObject *
 connQuery(connObject *self, PyObject *args)
 {
-       char            *query;
+       PyObject        *query_obj;
        PyObject        *oargs = NULL;
+       char            *query = NULL;
        PGresult        *result;
        queryObject *npgobj;
        const char*     encoding_name=NULL;
@@ -1089,9 +1090,38 @@
        }
 
        /* get query args */
-       if (!PyArg_ParseTuple(args, "s|O", &query, &oargs))
+       if (!PyArg_ParseTuple(args, "O|O", &query_obj, &oargs))
+       {
+               return NULL;
+       }
+
+       encoding = PQclientEncoding(self->cnx);
+       if (encoding != pg_encoding_utf8 && encoding != pg_encoding_latin1
+                       && encoding != pg_encoding_ascii)
+               /* should be translated to Python here */
+               encoding_name = pg_encoding_to_char(encoding);
+
+       if (PyBytes_Check(query_obj))
+       {
+               query = PyBytes_AsString(query_obj);
+       }
+       else if (PyUnicode_Check(query_obj))
        {
-               PyErr_SetString(PyExc_TypeError, "query(sql, [args]), with sql 
(string).");
+               if (encoding == pg_encoding_utf8)
+                       query_obj = PyUnicode_AsUTF8String(query_obj);
+               else if (encoding == pg_encoding_latin1)
+                       query_obj = PyUnicode_AsLatin1String(query_obj);
+               else if (encoding == pg_encoding_ascii)
+                       query_obj = PyUnicode_AsASCIIString(query_obj);
+               else
+                       query_obj = PyUnicode_AsEncodedString(query_obj,
+                               encoding_name, "strict");
+               if (!query_obj) return NULL; /* pass the UnicodeEncodeError */
+               query = PyBytes_AsString(query_obj);
+       }
+       if (!query) {
+               PyErr_SetString(PyExc_TypeError,
+                       "query command must be a string.");
                return NULL;
        }
 
@@ -1111,12 +1141,6 @@
                nparms = (int)PySequence_Size(oargs);
        }
 
-       encoding = PQclientEncoding(self->cnx);
-       if (encoding != pg_encoding_utf8 && encoding != pg_encoding_latin1
-                       && encoding != pg_encoding_ascii)
-               /* should be translated to Python here */
-               encoding_name = pg_encoding_to_char(encoding);
-
        /* gets result */
        if (nparms)
        {
@@ -3285,6 +3309,7 @@
                                        case 5:  /* money */
                                                /* convert to decimal only if 
decimal point is set */
                                                if (!decimal_point) goto 
default_case;
+
                                                for (k = 0;
                                                         *s && k < 
sizeof(cashbuf) / sizeof(cashbuf[0]) - 1;
                                                         s++)
@@ -3376,6 +3401,10 @@
                                m,
                                n,
                           *typ;
+#if IS_PY3
+       int                     encoding;
+       const char *encoding_name=NULL;
+#endif
 
        /* checks args (args == NULL for an internal call) */
        if (args && !PyArg_ParseTuple(args, ""))
@@ -3385,6 +3414,14 @@
                return NULL;
        }
 
+#if IS_PY3
+       encoding = self->encoding;
+       if (encoding != pg_encoding_utf8 && encoding != pg_encoding_latin1
+                       && encoding != pg_encoding_ascii)
+               /* should be translated to Python here */
+               encoding_name = pg_encoding_to_char(encoding);
+#endif
+
        /* stores result in list */
        m = PQntuples(self->result);
        n = PQnfields(self->result);
@@ -3434,7 +3471,7 @@
                                                Py_DECREF(tmp_obj);
                                                break;
 
-                                       case 5:  /* pgmoney */
+                                       case 5:  /* money */
                                                /* convert to decimal only if 
decimal point is set */
                                                if (!decimal_point) goto 
default_case;
 
@@ -3473,7 +3510,21 @@
 
                                        default:
                                        default_case:
-                                               val = PyStr_FromString(s);
+#if IS_PY3
+                                               if (encoding == 
pg_encoding_utf8)
+                                                       val = 
PyUnicode_DecodeUTF8(s, strlen(s), "strict");
+                                               else if (encoding == 
pg_encoding_latin1)
+                                                       val = 
PyUnicode_DecodeLatin1(s, strlen(s), "strict");
+                                               else if (encoding == 
pg_encoding_ascii)
+                                                       val = 
PyUnicode_DecodeASCII(s, strlen(s), "strict");
+                                               else
+                                                       val = 
PyUnicode_Decode(s, strlen(s),
+                                                               encoding_name, 
"strict");
+                                               if (!val)
+                                                       val = 
PyBytes_FromString(s);
+#else
+                                               val = PyBytes_FromString(s);
+#endif
                                                break;
                                }
 
_______________________________________________
PyGreSQL mailing list
[email protected]
https://mail.vex.net/mailman/listinfo.cgi/pygresql

Reply via email to