Author: cito
Date: Tue Jan 12 20:58:54 2016
New Revision: 730

Log:
Use query parameters instead of inline values

The single row methods of the DB wrapper class created queries with inline
values instead of passing them separately as parameters, even though our
query method does have this capability. Using query parameters also spares
us a lot of the quoting and escaping that is necessary when passing values
inline.

Modified:
   trunk/docs/contents/changelog.rst
   trunk/pg.py
   trunk/tests/test_classic.py
   trunk/tests/test_classic_connection.py
   trunk/tests/test_classic_dbwrapper.py

Modified: trunk/docs/contents/changelog.rst
==============================================================================
--- trunk/docs/contents/changelog.rst   Tue Jan 12 16:29:07 2016        (r729)
+++ trunk/docs/contents/changelog.rst   Tue Jan 12 20:58:54 2016        (r730)
@@ -34,10 +34,11 @@
   you call the methods using it and you are using tables with OIDs.
   Note that OIDs are considered deprecated anyway, and they are not created
   by default any more in PostgreSQL 8.1 and later.
-- Simplified the internal caching and mechanisms for automatic quoting
-  of class names in the classic interface, these things should now both
-  perform better and use less memory.
-
+- The internal caching and automatic quoting of class names in the classic
+  interface have been simplified and improved; both should now perform better
+  and use less memory. Also, the overhead for quoting and escaping values in
+  the DB wrapper methods has been reduced, and security has been improved by
+  passing the values to libpq separately as parameters instead of inline.
 
 Version 4.2
 -----------

Modified: trunk/pg.py
==============================================================================
--- trunk/pg.py Tue Jan 12 16:29:07 2016        (r729)
+++ trunk/pg.py Tue Jan 12 20:58:54 2016        (r730)
@@ -37,7 +37,7 @@
 
 from decimal import Decimal
 from collections import namedtuple
-from itertools import groupby
+from functools import partial
 
 try:
     basestring
@@ -315,9 +315,10 @@
 
     # Auxiliary methods
 
-    def _do_debug(self, s):
+    def _do_debug(self, *args):
         """Print a debug message."""
         if self.debug:
+            s = '\n'.join(str(arg) for arg in args)  # params may be a list
             if isinstance(self.debug, basestring):
                 print(self.debug % s)
             elif hasattr(self.debug, 'write'):
@@ -332,72 +333,55 @@
         """Get boolean value corresponding to d."""
         return bool(d) if get_bool() else ('t' if d else 'f')
 
-    def _quote_text(self, d):
-        """Quote text value."""
-        if not isinstance(d, basestring):
-            d = str(d)
-        return "'%s'" % self.escape_string(d)
+    _bool_true_values = frozenset('t true 1 y yes on'.split())
 
-    _bool_true = frozenset('t true 1 y yes on'.split())
-
-    def _quote_bool(self, d):
-        """Quote boolean value."""
+    def _prepare_bool(self, d):
+        """Prepare a boolean parameter."""
         if isinstance(d, basestring):
             if not d:
-                return 'NULL'
-            d = d.lower() in self._bool_true
-        return "'t'" if d else "'f'"
+                return None
+            d = d.lower() in self._bool_true_values
+        return 't' if d else 'f'
 
     _date_literals = frozenset('current_date current_time'
         ' current_timestamp localtime localtimestamp'.split())
 
-    def _quote_date(self, d):
-        """Quote date value."""
+    def _prepare_date(self, d):
+        """Prepare a date parameter."""
         if not d:
-            return 'NULL'
+            return None
         if isinstance(d, basestring) and d.lower() in self._date_literals:
-            return d
-        return self._quote_text(d)
+            raise ValueError
+        return d
 
-    def _quote_num(self, d):
-        """Quote numeric value."""
+    def _prepare_num(self, d):
+        """Prepare a numeric parameter."""
         if not d and d != 0:
-            return 'NULL'
-        return str(d)
-
-    def _quote_money(self, d):
-        """Quote money value."""
-        if d is None or d == '':
-            return 'NULL'
-        if not isinstance(d, basestring):
-            d = str(d)
+            return None
         return d
 
-    if bytes is str:  # Python < 3.0
-        """Quote bytes value."""
-
-        def _quote_bytea(self, d):
-            return "'%s'" % self.escape_bytea(d)
-
-    else:
-
-        def _quote_bytea(self, d):
-            return "'%s'" % self.escape_bytea(d).decode('ascii')
+    def _prepare_bytea(self, d):
+        return self.escape_bytea(d)
 
-    _quote_funcs = dict(  # quote methods for each type
-        text=_quote_text, bool=_quote_bool, date=_quote_date,
-        int=_quote_num, num=_quote_num, float=_quote_num,
-        money=_quote_money, bytea=_quote_bytea)
-
-    def _quote(self, d, t):
-        """Return quotes if needed."""
-        if d is None:
-            return 'NULL'
-        try:
-            quote_func = self._quote_funcs[t]
-        except KeyError:
-            quote_func = self._quote_funcs['text']
-        return quote_func(self, d)
+    _prepare_funcs = dict(  # parameter preparation methods for each type
+        bool=_prepare_bool, date=_prepare_date,
+        int=_prepare_num, num=_prepare_num, float=_prepare_num,
+        money=_prepare_num, bytea=_prepare_bytea)
+
+    def _prepare_param(self, value, typ, params):
+        """Prepare and add a parameter to the list."""
+        if value is not None and typ != 'text':
+            try:
+                prepare = self._prepare_funcs[typ]
+            except KeyError:
+                pass
+            else:
+                try:
+                    value = prepare(self, value)
+                except ValueError:
+                    return value
+        params.append(value)
+        return '$%d' % len(params)
 
     # Public methods
 
@@ -578,7 +562,6 @@
         if flush:
             attnames.clear()
             self._do_debug('pkey cache has been flushed')
-
         try:  # cache lookup
             names = attnames[cl]
         except KeyError:  # cache miss, check the database
@@ -651,6 +634,8 @@
             except KeyError:
                 raise _prg_error('Class %s has no primary key' % cl)
         attnames = self.get_attnames(cl)
+        params = []
+        param = partial(self._prepare_param, params=params)
         # We want the oid for later updates if that isn't the key
         if keyname == 'oid':
             if isinstance(arg, dict):
@@ -659,7 +644,7 @@
             else:
                 arg = {qoid: arg}
             what = '*'
-            where = 'oid = %s' % arg[qoid]
+            where = 'oid = %s' % param(arg[qoid], 'int')
         else:
             if isinstance(keyname, basestring):
                 keyname = (keyname,)
@@ -669,11 +654,13 @@
                 arg = dict([(k, arg) for k in keyname])
             what = ', '.join(attnames)
             where = ' AND '.join(['%s = %s'
-                % (k, self._quote(arg[k], attnames[k])) for k in keyname])
+                % (k, param(arg[k], attnames[k])) for k in keyname])
         q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
             what, _quote_class_name(cl), where)
-        self._do_debug(q)
-        res = self.db.query(q).dictresult()
+        self._do_debug(q, params)
+        res = self.db.query(q, params).dictresult()
         if not res:
-            raise _db_error('No such record in %s where %s' % (cl, where))
+            # where now only holds $n placeholders, so also show the values
+            raise _db_error('No such record in %s where %s with %s' % (
+                cl, where, params))
         for n, value in res[0].items():
@@ -706,11 +691,13 @@
             d = {}
         d.update(kw)
         attnames = self.get_attnames(cl)
+        params = []
+        param = partial(self._prepare_param, params=params)
         names, values = [], []
         for n in attnames:
             if n != 'oid' and n in d:
                 names.append('"%s"' % n)
-                values.append(self._quote(d[n], attnames[n]))
+                values.append(param(d[n], attnames[n]))
         names, values = ', '.join(names), ', '.join(values)
         selectable = self.has_table_privilege(cl)
         if selectable:
@@ -719,14 +706,14 @@
             ret = ''
         q = 'INSERT INTO %s (%s) VALUES (%s)%s' % (
             _quote_class_name(cl), names, values, ret)
-        self._do_debug(q)
-        res = self.db.query(q)
+        self._do_debug(q, params)
+        res = self.db.query(q, params)
         if ret:
             res = res.dictresult()[0]
             for n, value in res.items():
                 if n == 'oid':
                     n = qoid
-                elif attnames.get(n) == 'bytea':
+                elif attnames.get(n) == 'bytea' and value is not None:
                     value = self.unescape_bytea(value)
                 d[n] = value
         elif isinstance(res, int):
@@ -764,8 +751,10 @@
             d = {}
         d.update(kw)
         attnames = self.get_attnames(cl)
+        params = []
+        param = partial(self._prepare_param, params=params)
         if qoid in d:
-            where = 'oid = %s' % d[qoid]
+            where = 'oid = %s' % param(d[qoid], 'int')
             keyname = ()
         else:
             try:
@@ -776,13 +765,13 @@
                 keyname = (keyname,)
             try:
                 where = ' AND '.join(['%s = %s'
-                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
+                    % (k, param(d[k], attnames[k])) for k in keyname])
             except KeyError:
                 raise _prg_error('Update needs primary key or oid.')
         values = []
         for n in attnames:
             if n in d and n not in keyname:
-                values.append('%s = %s' % (n, self._quote(d[n], attnames[n])))
+                values.append('%s = %s' % (n, param(d[n], attnames[n])))
         if not values:
             return d
         values = ', '.join(values)
@@ -793,14 +782,14 @@
             ret = ''
         q = 'UPDATE %s SET %s WHERE %s%s' % (
             _quote_class_name(cl), values, where, ret)
-        self._do_debug(q)
-        res = self.db.query(q)
+        self._do_debug(q, params)
+        res = self.db.query(q, params)
         if ret:
             res = res.dictresult()[0]
             for n, value in res.items():
                 if n == 'oid':
                     n = qoid
-                elif attnames.get(n) == 'bytea':
+                elif attnames.get(n) == 'bytea' and value is not None:
                     value = self.unescape_bytea(value)
                 d[n] = value
         else:
@@ -857,8 +846,10 @@
         if d is None:
             d = {}
         d.update(kw)
+        params = []
+        param = partial(self._prepare_param, params=params)
         if qoid in d:
-            where = 'oid = %s' % d[qoid]
+            where = 'oid = %s' % param(d[qoid], 'int')
         else:
             try:
                 keyname = self.pkey(cl)
@@ -869,12 +860,12 @@
             attnames = self.get_attnames(cl)
             try:
                 where = ' AND '.join(['%s = %s'
-                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
+                    % (k, param(d[k], attnames[k])) for k in keyname])
             except KeyError:
                 raise _prg_error('Delete needs primary key or oid.')
         q = 'DELETE FROM %s WHERE %s' % (_quote_class_name(cl), where)
-        self._do_debug(q)
-        return int(self.db.query(q))
+        self._do_debug(q, params)
+        return int(self.db.query(q, params))
 
     def notification_handler(self, event, callback, arg_dict={}, timeout=None):
         """Get notification handler that will run the given callback."""

Modified: trunk/tests/test_classic.py
==============================================================================
--- trunk/tests/test_classic.py Tue Jan 12 16:29:07 2016        (r729)
+++ trunk/tests/test_classic.py Tue Jan 12 20:58:54 2016        (r730)
@@ -192,6 +192,7 @@
         self.assertEqual(r['dvar'], 123)
 
         r = db.get('_test_schema', 1234)
+        self.assertIn('dvar', r)
         db.update('_test_schema', _test=1234, dvar=456)
         r = db.get('_test_schema', 1234)
         self.assertEqual(r['dvar'], 456)
@@ -201,52 +202,6 @@
         r = db.get('_test_schema', 1234)
         self.assertEqual(r['dvar'], 456)
 
-    def test_quote(self):
-        db = opendb()
-        q = db._quote
-        self.assertEqual(q(0, 'int'), "0")
-        self.assertEqual(q(0, 'num'), "0")
-        self.assertEqual(q('0', 'int'), "0")
-        self.assertEqual(q('0', 'num'), "0")
-        self.assertEqual(q(1, 'int'), "1")
-        self.assertEqual(q(1, 'text'), "'1'")
-        self.assertEqual(q(1, 'num'), "1")
-        self.assertEqual(q('1', 'int'), "1")
-        self.assertEqual(q('1', 'text'), "'1'")
-        self.assertEqual(q('1', 'num'), "1")
-        self.assertEqual(q(None, 'int'), "NULL")
-        self.assertEqual(q(1, 'money'), "1")
-        self.assertEqual(q('1', 'money'), "1")
-        self.assertEqual(q(1.234, 'money'), "1.234")
-        self.assertEqual(q('1.234', 'money'), "1.234")
-        self.assertEqual(q(0, 'money'), "0")
-        self.assertEqual(q(0.00, 'money'), "0.0")
-        self.assertEqual(q(Decimal('0.00'), 'money'), "0.00")
-        self.assertEqual(q(None, 'money'), "NULL")
-        self.assertEqual(q('', 'money'), "NULL")
-        self.assertEqual(q(0, 'bool'), "'f'")
-        self.assertEqual(q('', 'bool'), "NULL")
-        self.assertEqual(q('f', 'bool'), "'f'")
-        self.assertEqual(q('off', 'bool'), "'f'")
-        self.assertEqual(q('no', 'bool'), "'f'")
-        self.assertEqual(q(1, 'bool'), "'t'")
-        self.assertEqual(q(9999, 'bool'), "'t'")
-        self.assertEqual(q(-9999, 'bool'), "'t'")
-        self.assertEqual(q('1', 'bool'), "'t'")
-        self.assertEqual(q('t', 'bool'), "'t'")
-        self.assertEqual(q('on', 'bool'), "'t'")
-        self.assertEqual(q('yes', 'bool'), "'t'")
-        self.assertEqual(q('true', 'bool'), "'t'")
-        self.assertEqual(q('y', 'bool'), "'t'")
-        self.assertEqual(q('', 'date'), "NULL")
-        self.assertEqual(q(False, 'date'), "NULL")
-        self.assertEqual(q(0, 'date'), "NULL")
-        self.assertEqual(q('some_date', 'date'), "'some_date'")
-        self.assertEqual(q('current_timestamp', 'date'), "current_timestamp")
-        self.assertEqual(q('', 'text'), "''")
-        self.assertEqual(q("'", 'text'), "''''")
-        self.assertEqual(q("\\", 'text'), "'\\\\'")
-
     def notify_callback(self, arg_dict):
         if arg_dict:
             arg_dict['called'] = True

Modified: trunk/tests/test_classic_connection.py
==============================================================================
--- trunk/tests/test_classic_connection.py      Tue Jan 12 16:29:07 2016        (r729)
+++ trunk/tests/test_classic_connection.py      Tue Jan 12 20:58:54 2016        (r730)
@@ -1347,7 +1347,7 @@
         en_locales = 'en', 'en_US', 'en_US.utf8', 'en_US.UTF-8'
         en_money = '$34.25', '$ 34.25', '34.25$', '34.25 $', '34.25 Dollar'
         de_locales = 'de', 'de_DE', 'de_DE.utf8', 'de_DE.UTF-8'
-        de_money = ('34,25€', '34,25 €', '€34,25' '€ 34,25',
+        de_money = ('34,25€', '34,25 €', '€34,25', '€ 34,25',
             'EUR34,25', 'EUR 34,25', '34,25 EUR', '34,25 Euro', '34,25 DM')
         # first try with English localization (using the point)
         for lc in en_locales:

Modified: trunk/tests/test_classic_dbwrapper.py
==============================================================================
--- trunk/tests/test_classic_dbwrapper.py       Tue Jan 12 16:29:07 2016        (r729)
+++ trunk/tests/test_classic_dbwrapper.py       Tue Jan 12 20:58:54 2016        (r730)
@@ -405,80 +405,6 @@
             b'\\x746861742773206be47365')
         self.assertEqual(f(r'\\x4f007073ff21'), b'\\x4f007073ff21')
 
-    def testQuote(self):
-        f = self.db._quote
-        self.assertEqual(f(None, None), 'NULL')
-        self.assertEqual(f(None, 'int'), 'NULL')
-        self.assertEqual(f(None, 'float'), 'NULL')
-        self.assertEqual(f(None, 'num'), 'NULL')
-        self.assertEqual(f(None, 'money'), 'NULL')
-        self.assertEqual(f(None, 'bool'), 'NULL')
-        self.assertEqual(f(None, 'date'), 'NULL')
-        self.assertEqual(f('', 'int'), 'NULL')
-        self.assertEqual(f('', 'float'), 'NULL')
-        self.assertEqual(f('', 'num'), 'NULL')
-        self.assertEqual(f('', 'money'), 'NULL')
-        self.assertEqual(f('', 'bool'), 'NULL')
-        self.assertEqual(f('', 'date'), 'NULL')
-        self.assertEqual(f('', 'text'), "''")
-        self.assertEqual(f(0, 'int'), '0')
-        self.assertEqual(f(0, 'num'), '0')
-        self.assertEqual(f(1, 'int'), '1')
-        self.assertEqual(f(1, 'num'), '1')
-        self.assertEqual(f(-1, 'int'), '-1')
-        self.assertEqual(f(-1, 'num'), '-1')
-        self.assertEqual(f(123456789, 'int'), '123456789')
-        self.assertEqual(f(123456987, 'num'), '123456987')
-        self.assertEqual(f(1.23654789, 'num'), '1.23654789')
-        self.assertEqual(f(12365478.9, 'num'), '12365478.9')
-        self.assertEqual(f('123456789', 'num'), '123456789')
-        self.assertEqual(f('1.23456789', 'num'), '1.23456789')
-        self.assertEqual(f('12345678.9', 'num'), '12345678.9')
-        self.assertEqual(f(123, 'money'), '123')
-        self.assertEqual(f('123', 'money'), '123')
-        self.assertEqual(f(123.45, 'money'), '123.45')
-        self.assertEqual(f('123.45', 'money'), '123.45')
-        self.assertEqual(f(123.454, 'money'), '123.454')
-        self.assertEqual(f('123.454', 'money'), '123.454')
-        self.assertEqual(f(123.456, 'money'), '123.456')
-        self.assertEqual(f('123.456', 'money'), '123.456')
-        self.assertEqual(f('f', 'bool'), "'f'")
-        self.assertEqual(f('F', 'bool'), "'f'")
-        self.assertEqual(f('false', 'bool'), "'f'")
-        self.assertEqual(f('False', 'bool'), "'f'")
-        self.assertEqual(f('FALSE', 'bool'), "'f'")
-        self.assertEqual(f(0, 'bool'), "'f'")
-        self.assertEqual(f('0', 'bool'), "'f'")
-        self.assertEqual(f('-', 'bool'), "'f'")
-        self.assertEqual(f('n', 'bool'), "'f'")
-        self.assertEqual(f('N', 'bool'), "'f'")
-        self.assertEqual(f('no', 'bool'), "'f'")
-        self.assertEqual(f('off', 'bool'), "'f'")
-        self.assertEqual(f('t', 'bool'), "'t'")
-        self.assertEqual(f('T', 'bool'), "'t'")
-        self.assertEqual(f('true', 'bool'), "'t'")
-        self.assertEqual(f('True', 'bool'), "'t'")
-        self.assertEqual(f('TRUE', 'bool'), "'t'")
-        self.assertEqual(f(1, 'bool'), "'t'")
-        self.assertEqual(f(2, 'bool'), "'t'")
-        self.assertEqual(f(-1, 'bool'), "'t'")
-        self.assertEqual(f(0.5, 'bool'), "'t'")
-        self.assertEqual(f('1', 'bool'), "'t'")
-        self.assertEqual(f('y', 'bool'), "'t'")
-        self.assertEqual(f('Y', 'bool'), "'t'")
-        self.assertEqual(f('yes', 'bool'), "'t'")
-        self.assertEqual(f('on', 'bool'), "'t'")
-        self.assertEqual(f('01.01.2000', 'date'), "'01.01.2000'")
-        self.assertEqual(f(123, 'text'), "'123'")
-        self.assertEqual(f(1.23, 'text'), "'1.23'")
-        self.assertEqual(f('abc', 'text'), "'abc'")
-        self.assertEqual(f("ab'c", 'text'), "'ab''c'")
-        self.assertEqual(f('ab\\c', 'text'), "'ab\\c'")
-        self.assertEqual(f("a\\b'c", 'text'), "'a\\b''c'")
-        self.db.query('set standard_conforming_strings=off')
-        self.assertEqual(f('ab\\c', 'text'), "'ab\\\\c'")
-        self.assertEqual(f("a\\b'c", 'text'), "'a\\\\b''c'")
-
     def testQuery(self):
         query = self.db.query
         query("drop table if exists test_table")
@@ -786,7 +712,6 @@
     def testInsert(self):
         insert = self.db.insert
         query = self.db.query
-        server_version = self.db.server_version
         bool_on = pg.get_bool()
         decimal = pg.get_decimal()
         for table in ('insert_test_table', 'test table for insert'):
@@ -859,9 +784,6 @@
                     m = expect.get('m')
                     if m is not None:
                         expect['m'] = decimal(m)
-                if data.get('m') and server_version < 910000:
-                    # PostgreSQL < 9.1 cannot directly convert numbers to money
-                    data['m'] = "'%s'::money" % data['m']
                 self.assertEqual(insert(table, data), data)
                 self.assertIn(oid_table, data)
                 oid = data[oid_table]
@@ -1129,6 +1051,24 @@
         query = self.db.query
         query('drop table if exists bytea_test')
         query('create table bytea_test (n smallint primary key, data bytea)')
+        # insert null value
+        r = self.db.insert('bytea_test', n=0, data=None)
+        self.assertIsInstance(r, dict)
+        self.assertIn('n', r)
+        self.assertEqual(r['n'], 0)
+        self.assertIn('data', r)
+        self.assertIsNone(r['data'])
+        s = b'None'
+        r = self.db.update('bytea_test', n=0, data=s)
+        self.assertIsInstance(r, dict)
+        self.assertIn('n', r)
+        self.assertEqual(r['n'], 0)
+        self.assertIn('data', r)
+        r = r['data']
+        self.assertIsInstance(r, bytes)
+        self.assertEqual(r, s)
+        r = self.db.update('bytea_test', n=0, data=None)
+        self.assertIsNone(r['data'])
         # insert as bytes
         s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
         r = self.db.insert('bytea_test', n=5, data=s)
_______________________________________________
PyGreSQL mailing list
[email protected]
https://mail.vex.net/mailman/listinfo.cgi/pygresql

Reply via email to