Author: cito
Date: Thu Dec 10 07:59:20 2015
New Revision: 664

Log:
get() should convert bytea column values to bytes.

insert() and update() should do the same for the column values they return.

Modified:
   trunk/module/pg.py
   trunk/module/tests/test_classic_dbwrapper.py

Modified: trunk/module/pg.py
==============================================================================
--- trunk/module/pg.py  Thu Dec 10 07:40:11 2015        (r663)
+++ trunk/module/pg.py  Thu Dec 10 07:59:20 2015        (r664)
@@ -758,6 +758,7 @@
                 keyname = self.pkey(qcl)
             except KeyError:
                 raise _prg_error('Class %s has no primary key' % qcl)
+        attnames = self.get_attnames(qcl)
         # We want the oid for later updates if that isn't the key
         if keyname == 'oid':
             if isinstance(arg, dict):
@@ -765,26 +766,29 @@
                     raise _db_error('%s not in arg' % qoid)
             else:
                 arg = {qoid: arg}
+            what = '*'
             where = 'oid = %s' % arg[qoid]
-            attnames = '*'
         else:
-            attnames = self.get_attnames(qcl)
             if isinstance(keyname, basestring):
                 keyname = (keyname,)
             if not isinstance(arg, dict):
                 if len(keyname) > 1:
                     raise _prg_error('Composite key needs dict as arg')
                 arg = dict([(k, arg) for k in keyname])
+            what = ', '.join(attnames)
             where = ' AND '.join(['%s = %s'
                 % (k, self._quote(arg[k], attnames[k])) for k in keyname])
-            attnames = ', '.join(attnames)
-        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (attnames, qcl, where)
+        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (what, qcl, where)
         self._do_debug(q)
         res = self.db.query(q).dictresult()
         if not res:
             raise _db_error('No such record in %s where %s' % (qcl, where))
-        for att, value in res[0].items():
-            arg[qoid if att == 'oid' else att] = value
+        for n, value in res[0].items():
+            if n == 'oid':
+                n = qoid
+            elif attnames.get(n) == 'bytea':
+                value = self.unescape_bytea(value)
+            arg[n] = value
         return arg
 
     def insert(self, cl, d=None, **kw):
@@ -822,9 +826,13 @@
         self._do_debug(q)
         res = self.db.query(q)
         if ret:
-            res = res.dictresult()
-            for att, value in res[0].items():
-                d[qoid if att == 'oid' else att] = value
+            res = res.dictresult()[0]
+            for n, value in res.items():
+                if n == 'oid':
+                    n = qoid
+                elif attnames.get(n) == 'bytea':
+                    value = self.unescape_bytea(value)
+                d[n] = value
         elif isinstance(res, int):
             d[qoid] = res
             if selectable:
@@ -893,8 +901,12 @@
         res = self.db.query(q)
         if ret:
             res = res.dictresult()[0]
-            for att, value in res.items():
-                d[qoid if att == 'oid' else att] = value
+            for n, value in res.items():
+                if n == 'oid':
+                    n = qoid
+                elif attnames.get(n) == 'bytea':
+                    value = self.unescape_bytea(value)
+                d[n] = value
         else:
             if selectable:
                 if qoid in d:

Modified: trunk/module/tests/test_classic_dbwrapper.py
==============================================================================
--- trunk/module/tests/test_classic_dbwrapper.py        Thu Dec 10 07:40:11 2015        (r663)
+++ trunk/module/tests/test_classic_dbwrapper.py        Thu Dec 10 07:59:20 2015        (r664)
@@ -1087,7 +1087,7 @@
         self.assertEqual(r, s)
         query('drop table bytea_test')
 
-    def testInsertUpdateBytea(self):
+    def testInsertUpdateGetBytea(self):
         query = self.db.query
         query('drop table if exists bytea_test')
         query('create table bytea_test (n smallint primary key, data bytea)')
@@ -1099,10 +1099,6 @@
         self.assertEqual(r['n'], 5)
         self.assertIn('data', r)
         r = r['data']
-        # the following two lines should be removed once insert()
-        # will be enhanced to adapt the types of return values
-        self.assertIsInstance(r, str)
-        r = self.db.unescape_bytea(r)
         self.assertIsInstance(r, bytes)
         self.assertEqual(r, s)
         # update as bytes
@@ -1113,10 +1109,6 @@
         self.assertEqual(r['n'], 5)
         self.assertIn('data', r)
         r = r['data']
-        # the following two lines should be removed once update()
-        # will be enhanced to adapt the types of return values
-        self.assertIsInstance(r, str)
-        r = self.db.unescape_bytea(r)
         self.assertIsInstance(r, bytes)
         self.assertEqual(r, s)
         r = query('select * from bytea_test where n=5').getresult()
@@ -1129,6 +1121,14 @@
         r = self.db.unescape_bytea(r)
         self.assertIsInstance(r, bytes)
         self.assertEqual(r, s)
+        r = self.db.get('bytea_test', dict(n=5))
+        self.assertIsInstance(r, dict)
+        self.assertIn('n', r)
+        self.assertEqual(r['n'], 5)
+        self.assertIn('data', r)
+        r = r['data']
+        self.assertIsInstance(r, bytes)
+        self.assertEqual(r, s)
         query('drop table bytea_test')
 
     def testDebugWithCallable(self):
_______________________________________________
PyGreSQL mailing list
[email protected]
https://mail.vex.net/mailman/listinfo.cgi/pygresql

Reply via email to