Author: cito
Date: Thu Dec 10 07:59:20 2015
New Revision: 664
Log:
get() should convert bytea values to bytes.
insert() and update() should do the same with their return values.
Modified:
trunk/module/pg.py
trunk/module/tests/test_classic_dbwrapper.py
Modified: trunk/module/pg.py
==============================================================================
--- trunk/module/pg.py Thu Dec 10 07:40:11 2015 (r663)
+++ trunk/module/pg.py Thu Dec 10 07:59:20 2015 (r664)
@@ -758,6 +758,7 @@
keyname = self.pkey(qcl)
except KeyError:
raise _prg_error('Class %s has no primary key' % qcl)
+ attnames = self.get_attnames(qcl)
# We want the oid for later updates if that isn't the key
if keyname == 'oid':
if isinstance(arg, dict):
@@ -765,26 +766,29 @@
raise _db_error('%s not in arg' % qoid)
else:
arg = {qoid: arg}
+ what = '*'
where = 'oid = %s' % arg[qoid]
- attnames = '*'
else:
- attnames = self.get_attnames(qcl)
if isinstance(keyname, basestring):
keyname = (keyname,)
if not isinstance(arg, dict):
if len(keyname) > 1:
raise _prg_error('Composite key needs dict as arg')
arg = dict([(k, arg) for k in keyname])
+ what = ', '.join(attnames)
where = ' AND '.join(['%s = %s'
% (k, self._quote(arg[k], attnames[k])) for k in keyname])
- attnames = ', '.join(attnames)
- q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (attnames, qcl, where)
+ q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (what, qcl, where)
self._do_debug(q)
res = self.db.query(q).dictresult()
if not res:
raise _db_error('No such record in %s where %s' % (qcl, where))
- for att, value in res[0].items():
- arg[qoid if att == 'oid' else att] = value
+ for n, value in res[0].items():
+ if n == 'oid':
+ n = qoid
+ elif attnames.get(n) == 'bytea':
+ value = self.unescape_bytea(value)
+ arg[n] = value
return arg
def insert(self, cl, d=None, **kw):
@@ -822,9 +826,13 @@
self._do_debug(q)
res = self.db.query(q)
if ret:
- res = res.dictresult()
- for att, value in res[0].items():
- d[qoid if att == 'oid' else att] = value
+ res = res.dictresult()[0]
+ for n, value in res.items():
+ if n == 'oid':
+ n = qoid
+ elif attnames.get(n) == 'bytea':
+ value = self.unescape_bytea(value)
+ d[n] = value
elif isinstance(res, int):
d[qoid] = res
if selectable:
@@ -893,8 +901,12 @@
res = self.db.query(q)
if ret:
res = res.dictresult()[0]
- for att, value in res.items():
- d[qoid if att == 'oid' else att] = value
+ for n, value in res.items():
+ if n == 'oid':
+ n = qoid
+ elif attnames.get(n) == 'bytea':
+ value = self.unescape_bytea(value)
+ d[n] = value
else:
if selectable:
if qoid in d:
Modified: trunk/module/tests/test_classic_dbwrapper.py
==============================================================================
--- trunk/module/tests/test_classic_dbwrapper.py Thu Dec 10 07:40:11 2015 (r663)
+++ trunk/module/tests/test_classic_dbwrapper.py Thu Dec 10 07:59:20 2015 (r664)
@@ -1087,7 +1087,7 @@
self.assertEqual(r, s)
query('drop table bytea_test')
- def testInsertUpdateBytea(self):
+ def testInsertUpdateGetBytea(self):
query = self.db.query
query('drop table if exists bytea_test')
query('create table bytea_test (n smallint primary key, data bytea)')
@@ -1099,10 +1099,6 @@
self.assertEqual(r['n'], 5)
self.assertIn('data', r)
r = r['data']
- # the following two lines should be removed once insert()
- # will be enhanced to adapt the types of return values
- self.assertIsInstance(r, str)
- r = self.db.unescape_bytea(r)
self.assertIsInstance(r, bytes)
self.assertEqual(r, s)
# update as bytes
@@ -1113,10 +1109,6 @@
self.assertEqual(r['n'], 5)
self.assertIn('data', r)
r = r['data']
- # the following two lines should be removed once update()
- # will be enhanced to adapt the types of return values
- self.assertIsInstance(r, str)
- r = self.db.unescape_bytea(r)
self.assertIsInstance(r, bytes)
self.assertEqual(r, s)
r = query('select * from bytea_test where n=5').getresult()
@@ -1129,6 +1121,14 @@
r = self.db.unescape_bytea(r)
self.assertIsInstance(r, bytes)
self.assertEqual(r, s)
+ r = self.db.get('bytea_test', dict(n=5))
+ self.assertIsInstance(r, dict)
+ self.assertIn('n', r)
+ self.assertEqual(r['n'], 5)
+ self.assertIn('data', r)
+ r = r['data']
+ self.assertIsInstance(r, bytes)
+ self.assertEqual(r, s)
query('drop table bytea_test')
def testDebugWithCallable(self):
_______________________________________________
PyGreSQL mailing list
[email protected]
https://mail.vex.net/mailman/listinfo.cgi/pygresql