Author: cito
Date: Thu Dec 10 06:36:49 2015
New Revision: 662

Log:
insert() and update() should convert bytes to bytea

As suggested by D'Arcy on the mailing list.

Modified:
   trunk/module/pg.py
   trunk/module/tests/test_classic_dbwrapper.py

Modified: trunk/module/pg.py
==============================================================================
--- trunk/module/pg.py  Sun Dec  6 08:59:14 2015        (r661)
+++ trunk/module/pg.py  Thu Dec 10 06:36:49 2015        (r662)
@@ -383,10 +383,21 @@
             d = str(d)
         return d
 
+    if bytes is str:  # Python < 3.0
+        """Quote bytes value."""
+
+        def _quote_bytea(self, d):
+            return "'%s'" % self.escape_bytea(d)
+
+    else:
+
+        def _quote_bytea(self, d):
+            return "'%s'" % self.escape_bytea(d).decode('ascii')
+
     _quote_funcs = dict(  # quote methods for each type
         text=_quote_text, bool=_quote_bool, date=_quote_date,
         int=_quote_num, num=_quote_num, float=_quote_num,
-        money=_quote_money)
+        money=_quote_money, bytea=_quote_bytea)
 
     def _quote(self, d, t):
         """Return quotes if needed."""
@@ -688,6 +699,8 @@
                     typ = 'num'
                 elif typ.startswith('money'):
                     typ = 'money'
+                elif typ.startswith('bytea'):
+                    typ = 'bytea'
                 else:
                     typ = 'text'
                 t[att] = typ

Modified: trunk/module/tests/test_classic_dbwrapper.py
==============================================================================
--- trunk/module/tests/test_classic_dbwrapper.py        Sun Dec  6 08:59:14 2015        (r661)
+++ trunk/module/tests/test_classic_dbwrapper.py        Thu Dec 10 06:36:49 2015        (r662)
@@ -1071,16 +1071,61 @@
     def testBytea(self):
         query = self.db.query
         query('drop table if exists bytea_test')
-        query('create table bytea_test ('
-            'data bytea)')
+        query('create table bytea_test (n smallint primary key, data bytea)')
         s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
         r = self.db.escape_bytea(s)
-        query('insert into bytea_test values($1)', (r,))
-        r = query('select * from bytea_test').getresult()
-        self.assertTrue(len(r) == 1)
+        query('insert into bytea_test values(3,$1)', (r,))
+        r = query('select * from bytea_test where n=3').getresult()
+        self.assertEqual(len(r), 1)
         r = r[0]
-        self.assertTrue(len(r) == 1)
+        self.assertEqual(len(r), 2)
+        self.assertEqual(r[0], 3)
+        r = r[1]
+        self.assertIsInstance(r, str)
+        r = self.db.unescape_bytea(r)
+        self.assertIsInstance(r, bytes)
+        self.assertEqual(r, s)
+        query('drop table bytea_test')
+
+    def testInsertUpdateBytea(self):
+        query = self.db.query
+        query('drop table if exists bytea_test')
+        query('create table bytea_test (n smallint primary key, data bytea)')
+        # insert as bytes
+        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
+        r = self.db.insert('bytea_test', n=5, data=s)
+        self.assertIsInstance(r, dict)
+        self.assertIn('n', r)
+        self.assertEqual(r['n'], 5)
+        self.assertIn('data', r)
+        r = r['data']
+        # the following two lines should be removed once insert()
+        # will be enhanced to adapt the types of return values
+        self.assertIsInstance(r, str)
+        r = self.db.unescape_bytea(r)
+        self.assertIsInstance(r, bytes)
+        self.assertEqual(r, s)
+        # update as bytes
+        s += b"and now even more \x00 nasty \t stuff!\f"
+        r = self.db.update('bytea_test', n=5, data=s)
+        self.assertIsInstance(r, dict)
+        self.assertIn('n', r)
+        self.assertEqual(r['n'], 5)
+        self.assertIn('data', r)
+        r = r['data']
+        # the following two lines should be removed once update()
+        # will be enhanced to adapt the types of return values
+        self.assertIsInstance(r, str)
+        r = self.db.unescape_bytea(r)
+        self.assertIsInstance(r, bytes)
+        self.assertEqual(r, s)
+        r = query('select * from bytea_test where n=5').getresult()
+        self.assertEqual(len(r), 1)
         r = r[0]
+        self.assertEqual(len(r), 2)
+        self.assertEqual(r[0], 5)
+        r = r[1]
+        self.assertIsInstance(r, str)
         r = self.db.unescape_bytea(r)
         self.assertIsInstance(r, bytes)
         self.assertEqual(r, s)
_______________________________________________
PyGreSQL mailing list
[email protected]
https://mail.vex.net/mailman/listinfo.cgi/pygresql

Reply via email to