Author: cito
Date: Wed Jan 13 07:45:34 2016
New Revision: 732

Log:
Better handling of quoted identifiers

Methods like get() and update() did not properly handle identifiers that
need quoting (i.e. identifiers containing spaces, mixed-case characters or
special characters). This has been fixed, and tests have been added to make
sure it works.

Modified:
   trunk/pg.py
   trunk/tests/test_classic_dbwrapper.py

Modified: trunk/pg.py
==============================================================================
--- trunk/pg.py Tue Jan 12 21:00:04 2016        (r731)
+++ trunk/pg.py Wed Jan 13 07:45:34 2016        (r732)
@@ -49,34 +49,8 @@
 
 # Auxiliary functions that are independent from a DB connection:
 
-def _quote_class_name(cl):
-    """Quote a class name.
-
-    Class names are always quoted unless they contain a dot.
-    In this ambiguous case quotes must be added manually.
-
-    """
-    if '.' not in cl:
-        cl = '"%s"' % cl
-    return cl
-
-
-def _quote_class_param(cl, param):
-    """Quote parameter representing a class name.
-
-    The parameter is automatically quoted unless the class name contains a dot.
-    In this ambiguous case quotes must be added manually.
-
-    """
-    if isinstance(param, int):
-        param = "$%d" % param
-    if '.' not in cl:
-        param = 'quote_ident(%s)' % (param,)
-    return param
-
-
 def _oid_key(cl):
-    """Build oid key from qualified class name."""
+    """Build oid key from a class name."""
     return 'oid(%s)' % cl
 
 
@@ -328,6 +302,19 @@
             else:
                 print(s)
 
+    def _escape_qualified_name(self, s):
+        """Escape a qualified name.
+
+        Escapes the name for use as an SQL identifier, unless the
+        name contains a dot, in which case the name is ambiguous
+        (could be a qualified name or just a name with a dot in it)
+        and must be quoted manually by the caller.
+
+        """
+        if '.' not in s:
+            s = self.escape_identifier(s)
+        return s
+
     @staticmethod
     def _make_bool(d):
         """Get boolean value corresponding to d."""
@@ -361,6 +348,7 @@
         return d
 
     def _prepare_bytea(self, d):
+        """Prepare a bytea parameter."""
         return self.escape_bytea(d)
 
     _prepare_funcs = dict(  # quote methods for each type
@@ -383,6 +371,22 @@
         params.append(value)
         return '$%d' % len(params)
 
+    @staticmethod
+    def _prepare_qualified_param(cl, param):
+        """Quote parameter representing a qualified name.
+
+        Escapes the name for use as an SQL parameter, unless the
+        name contains a dot, in which case the name is ambiguous
+        (could be a qualified name or just a name with a dot in it)
+        and must be quoted manually by the caller.
+
+        """
+        if isinstance(param, int):
+            param = "$%d" % param
+        if '.' not in cl:
+            param = 'quote_ident(%s)' % (param,)
+        return param
+
     # Public methods
 
     # escape_string and escape_bytea exist as methods,
@@ -507,7 +511,7 @@
                 " AND a.attnum = ANY(i.indkey)"
                 " AND NOT a.attisdropped"
                 " WHERE i.indrelid=%s::regclass"
-                " AND i.indisprimary" % _quote_class_param(cl, 1))
+                " AND i.indisprimary" % self._prepare_qualified_param(cl, 1))
             pkey = self.db.query(q, (cl,)).getresult()
             if not pkey:
                 raise KeyError('Class %s has no primary key' % cl)
@@ -572,7 +576,7 @@
                 " AND (a.attnum > 0 OR a.attname = 'oid')"
                 " AND NOT a.attisdropped") % (
                     '::regtype' if self._regtypes else '',
-                    _quote_class_param(cl, 1))
+                    self._prepare_qualified_param(cl, 1))
             names = self.db.query(q, (cl,)).getresult()
             if not names:
                 raise KeyError('Class %s does not exist' % cl)
@@ -601,7 +605,7 @@
             return self._privileges[(cl, privilege)]
         except KeyError:  # cache miss, ask the database
             q = "SELECT has_table_privilege(%s, $2)" % (
-                _quote_class_param(cl, 1),)
+                self._prepare_qualified_param(cl, 1),)
             q = self.db.query(q, (cl, privilege))
             ret = q.getresult()[0][0] == self._make_bool(True)
             self._privileges[(cl, privilege)] = ret  # cache it
@@ -636,6 +640,7 @@
         attnames = self.get_attnames(cl)
         params = []
         param = partial(self._prepare_param, params=params)
+        col = self.escape_identifier
         # We want the oid for later updates if that isn't the key
         if keyname == 'oid':
             if isinstance(arg, dict):
@@ -651,12 +656,12 @@
             if not isinstance(arg, dict):
                 if len(keyname) > 1:
                     raise _prg_error('Composite key needs dict as arg')
-                arg = dict([(k, arg) for k in keyname])
-            what = ', '.join(attnames)
+                arg = dict((k, arg) for k in keyname)
+            what = ', '.join(col(k) for k in attnames)
             where = ' AND '.join(['%s = %s'
-                % (k, param(arg[k], attnames[k])) for k in keyname])
+                % (col(k), param(arg[k], attnames[k])) for k in keyname])
         q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
-            what, _quote_class_name(cl), where)
+            what, self._escape_qualified_name(cl), where)
         self._do_debug(q, params)
         res = self.db.query(q, params).dictresult()
         if not res:
@@ -693,10 +698,11 @@
         attnames = self.get_attnames(cl)
         params = []
         param = partial(self._prepare_param, params=params)
+        col = self.escape_identifier
         names, values = [], []
         for n in attnames:
             if n != 'oid' and n in d:
-                names.append('"%s"' % n)
+                names.append(col(n))
                 values.append(param(d[n], attnames[n]))
         names, values = ', '.join(names), ', '.join(values)
         selectable = self.has_table_privilege(cl)
@@ -705,7 +711,7 @@
         else:
             ret = ''
         q = 'INSERT INTO %s (%s) VALUES (%s)%s' % (
-            _quote_class_name(cl), names, values, ret)
+            self._escape_qualified_name(cl), names, values, ret)
         self._do_debug(q, params)
         res = self.db.query(q, params)
         if ret:
@@ -753,6 +759,7 @@
         attnames = self.get_attnames(cl)
         params = []
         param = partial(self._prepare_param, params=params)
+        col = self.escape_identifier
         if qoid in d:
             where = 'oid = %s' % param(d[qoid], 'int')
             keyname = ()
@@ -765,13 +772,13 @@
                 keyname = (keyname,)
             try:
                 where = ' AND '.join(['%s = %s'
-                    % (k, param(d[k], attnames[k])) for k in keyname])
+                    % (col(k), param(d[k], attnames[k])) for k in keyname])
             except KeyError:
                 raise _prg_error('Update needs primary key or oid.')
         values = []
         for n in attnames:
             if n in d and n not in keyname:
-                values.append('%s = %s' % (n, param(d[n], attnames[n])))
+                values.append('%s = %s' % (col(n), param(d[n], attnames[n])))
         if not values:
             return d
         values = ', '.join(values)
@@ -781,7 +788,7 @@
         else:
             ret = ''
         q = 'UPDATE %s SET %s WHERE %s%s' % (
-            _quote_class_name(cl), values, where, ret)
+            self._escape_qualified_name(cl), values, where, ret)
         self._do_debug(q, params)
         res = self.db.query(q, params)
         if ret:
@@ -858,12 +865,14 @@
             if isinstance(keyname, basestring):
                 keyname = (keyname,)
             attnames = self.get_attnames(cl)
+            col = self.escape_identifier
             try:
                 where = ' AND '.join(['%s = %s'
-                    % (k, param(d[k], attnames[k])) for k in keyname])
+                    % (col(k), param(d[k], attnames[k])) for k in keyname])
             except KeyError:
                 raise _prg_error('Delete needs primary key or oid.')
-        q = 'DELETE FROM %s WHERE %s' % (_quote_class_name(cl), where)
+        q = 'DELETE FROM %s WHERE %s' % (
+            self._escape_qualified_name(cl), where)
         self._do_debug(q, params)
         return int(self.db.query(q, params))
 

Modified: trunk/tests/test_classic_dbwrapper.py
==============================================================================
--- trunk/tests/test_classic_dbwrapper.py       Tue Jan 12 21:00:04 2016        (r731)
+++ trunk/tests/test_classic_dbwrapper.py       Wed Jan 13 07:45:34 2016        (r732)
@@ -635,63 +635,63 @@
     def testGet(self):
         get = self.db.get
         query = self.db.query
-        for table in ('get_test_table', 'test table for get'):
-            query('drop table if exists "%s"' % table)
-            query('create table "%s" ('
-                "n integer, t text) with oids" % table)
-            for n, t in enumerate('xyz'):
-                query('insert into "%s" values('"%d, '%s')"
-                    % (table, n + 1, t))
-            self.assertRaises(pg.ProgrammingError, get, table, 2)
-            r = get(table, 2, 'n')
-            oid_table = 'oid(%s)' % table
-            self.assertIn(oid_table, r)
-            oid = r[oid_table]
-            self.assertIsInstance(oid, int)
-            result = {'t': 'y', 'n': 2, oid_table: oid}
-            self.assertEqual(r, result)
-            self.assertEqual(get(table + ' *', 2, 'n'), r)
-            self.assertEqual(get(table, oid, 'oid')['t'], 'y')
-            self.assertEqual(get(table, 1, 'n')['t'], 'x')
-            self.assertEqual(get(table, 3, 'n')['t'], 'z')
-            self.assertEqual(get(table, 2, 'n')['t'], 'y')
-            self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
-            r['n'] = 3
-            self.assertEqual(get(table, r, 'n')['t'], 'z')
-            self.assertEqual(get(table, 1, 'n')['t'], 'x')
-            query('alter table "%s" alter n set not null' % table)
-            query('alter table "%s" add primary key (n)' % table)
-            self.assertEqual(get(table, 3)['t'], 'z')
-            self.assertEqual(get(table, 1)['t'], 'x')
-            self.assertEqual(get(table, 2)['t'], 'y')
-            r['n'] = 1
-            self.assertEqual(get(table, r)['t'], 'x')
-            r['n'] = 3
-            self.assertEqual(get(table, r)['t'], 'z')
-            r['n'] = 2
-            self.assertEqual(get(table, r)['t'], 'y')
-            query('drop table "%s"' % table)
+        table = 'get_test_table'
+        query('drop table if exists "%s"' % table)
+        query('create table "%s" ('
+            "n integer, t text) with oids" % table)
+        for n, t in enumerate('xyz'):
+            query('insert into "%s" values('"%d, '%s')"
+                % (table, n + 1, t))
+        self.assertRaises(pg.ProgrammingError, get, table, 2)
+        r = get(table, 2, 'n')
+        oid_table = 'oid(%s)' % table
+        self.assertIn(oid_table, r)
+        oid = r[oid_table]
+        self.assertIsInstance(oid, int)
+        result = {'t': 'y', 'n': 2, oid_table: oid}
+        self.assertEqual(r, result)
+        self.assertEqual(get(table + ' *', 2, 'n'), r)
+        self.assertEqual(get(table, oid, 'oid')['t'], 'y')
+        self.assertEqual(get(table, 1, 'n')['t'], 'x')
+        self.assertEqual(get(table, 3, 'n')['t'], 'z')
+        self.assertEqual(get(table, 2, 'n')['t'], 'y')
+        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
+        r['n'] = 3
+        self.assertEqual(get(table, r, 'n')['t'], 'z')
+        self.assertEqual(get(table, 1, 'n')['t'], 'x')
+        query('alter table "%s" alter n set not null' % table)
+        query('alter table "%s" add primary key (n)' % table)
+        self.assertEqual(get(table, 3)['t'], 'z')
+        self.assertEqual(get(table, 1)['t'], 'x')
+        self.assertEqual(get(table, 2)['t'], 'y')
+        r['n'] = 1
+        self.assertEqual(get(table, r)['t'], 'x')
+        r['n'] = 3
+        self.assertEqual(get(table, r)['t'], 'z')
+        r['n'] = 2
+        self.assertEqual(get(table, r)['t'], 'y')
+        query('drop table "%s"' % table)
 
     def testGetWithCompositeKey(self):
         get = self.db.get
         query = self.db.query
         table = 'get_test_table_1'
-        query("drop table if exists %s" % table)
-        query("create table %s ("
+        query('drop table if exists "%s"' % table)
+        query('create table "%s" ('
             "n integer, t text, primary key (n))" % table)
         for n, t in enumerate('abc'):
-            query("insert into %s values("
+            query('insert into "%s" values('
                 "%d, '%s')" % (table, n + 1, t))
         self.assertEqual(get(table, 2)['t'], 'b')
-        query("drop table %s" % table)
+        query('drop table "%s"' % table)
         table = 'get_test_table_2'
-        query("drop table if exists %s" % table)
-        query("create table %s ("
+        query('drop table if exists "%s"' % table)
+        query('create table "%s" ('
             "n integer, m integer, t text, primary key (n, m))" % table)
         for n in range(3):
             for m in range(2):
                 t = chr(ord('a') + 2 * n + m)
-                query("insert into %s values("
+                query('insert into "%s" values('
                     "%d, %d, '%s')" % (table, n + 1, m + 1, t))
         self.assertRaises(pg.ProgrammingError, get, table, 2)
         self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
@@ -699,7 +699,24 @@
                              ('n', 'm'))['t'], 'b')
         self.assertEqual(get(table, dict(n=3, m=2),
                              frozenset(['n', 'm']))['t'], 'f')
-        query("drop table %s" % table)
+        query('drop table "%s"' % table)
+
+    def testGetWithQuotedNames(self):
+        get = self.db.get
+        query = self.db.query
+        table = 'test table for get()'
+        query('drop table if exists "%s"' % table)
+        query('create table "%s" ('
+            '"Prime!" smallint primary key,'
+            '"much space" integer, "Questions?" text)' % table)
+        query('insert into "%s"'
+              " values(17, 1001, 'No!')" % table)
+        r = get(table, 17)
+        self.assertIsInstance(r, dict)
+        self.assertEqual(r['Prime!'], 17)
+        self.assertEqual(r['much space'], 1001)
+        self.assertEqual(r['Questions?'], 'No!')
+        query('drop table "%s"' % table)
 
     def testGetFromView(self):
         self.db.query('delete from test where i4=14')
@@ -714,134 +731,156 @@
         query = self.db.query
         bool_on = pg.get_bool()
         decimal = pg.get_decimal()
-        for table in ('insert_test_table', 'test table for insert'):
-            query('drop table if exists "%s"' % table)
-            query('create table "%s" ('
-                "i2 smallint, i4 integer, i8 bigint,"
-                " d numeric, f4 real, f8 double precision, m money,"
-                " v4 varchar(4), c4 char(4), t text,"
-                " b boolean, ts timestamp) with oids" % table)
-            oid_table = 'oid(%s)' % table
-            tests = [dict(i2=None, i4=None, i8=None),
-                (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
-                (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
-                dict(i2=42, i4=123456, i8=9876543210),
-                dict(i2=2 ** 15 - 1,
-                    i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
-                dict(d=None), (dict(d=''), dict(d=None)),
-                dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
-                dict(f4=None, f8=None), dict(f4=0, f8=0),
-                (dict(f4='', f8=''), dict(f4=None, f8=None)),
-                (dict(d=1234.5, f4=1234.5, f8=1234.5),
-                      dict(d=Decimal('1234.5'))),
-                dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
-                dict(d=Decimal('123456789.9876543212345678987654321')),
-                dict(m=None), (dict(m=''), dict(m=None)),
-                dict(m=Decimal('-1234.56')),
-                (dict(m=('-1234.56')), dict(m=Decimal('-1234.56'))),
-                dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
-                (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
-                (dict(m=1234.5), dict(m=Decimal('1234.5'))),
-                (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
-                (dict(m=123456), dict(m=Decimal('123456'))),
-                (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
-                dict(b=None), (dict(b=''), dict(b=None)),
-                dict(b='f'), dict(b='t'),
-                (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
-                (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
-                (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
-                (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
-                (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
-                (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
-                dict(v4=None, c4=None, t=None),
-                (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
-                dict(v4='1234', c4='1234', t='1234' * 10),
-                dict(v4='abcd', c4='abcd', t='abcdefg'),
-                (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
-                dict(ts=None), (dict(ts=''), dict(ts=None)),
-                (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
-                dict(ts='2012-12-21 00:00:00'),
-                (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
-                dict(ts='2012-12-21 12:21:12'),
-                dict(ts='2013-01-05 12:13:14'),
-                dict(ts='current_timestamp')]
-            for test in tests:
-                if isinstance(test, dict):
-                    data = test
-                    change = {}
+        table = 'insert_test_table'
+        query('drop table if exists "%s"' % table)
+        query('create table "%s" ('
+            "i2 smallint, i4 integer, i8 bigint,"
+            " d numeric, f4 real, f8 double precision, m money,"
+            " v4 varchar(4), c4 char(4), t text,"
+            " b boolean, ts timestamp) with oids" % table)
+        oid_table = 'oid(%s)' % table
+        tests = [dict(i2=None, i4=None, i8=None),
+            (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
+            (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
+            dict(i2=42, i4=123456, i8=9876543210),
+            dict(i2=2 ** 15 - 1,
+                i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
+            dict(d=None), (dict(d=''), dict(d=None)),
+            dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
+            dict(f4=None, f8=None), dict(f4=0, f8=0),
+            (dict(f4='', f8=''), dict(f4=None, f8=None)),
+            (dict(d=1234.5, f4=1234.5, f8=1234.5),
+                  dict(d=Decimal('1234.5'))),
+            dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
+            dict(d=Decimal('123456789.9876543212345678987654321')),
+            dict(m=None), (dict(m=''), dict(m=None)),
+            dict(m=Decimal('-1234.56')),
+            (dict(m=('-1234.56')), dict(m=Decimal('-1234.56'))),
+            dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
+            (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
+            (dict(m=1234.5), dict(m=Decimal('1234.5'))),
+            (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
+            (dict(m=123456), dict(m=Decimal('123456'))),
+            (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
+            dict(b=None), (dict(b=''), dict(b=None)),
+            dict(b='f'), dict(b='t'),
+            (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
+            (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
+            (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
+            (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
+            (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
+            (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
+            dict(v4=None, c4=None, t=None),
+            (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
+            dict(v4='1234', c4='1234', t='1234' * 10),
+            dict(v4='abcd', c4='abcd', t='abcdefg'),
+            (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
+            dict(ts=None), (dict(ts=''), dict(ts=None)),
+            (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
+            dict(ts='2012-12-21 00:00:00'),
+            (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
+            dict(ts='2012-12-21 12:21:12'),
+            dict(ts='2013-01-05 12:13:14'),
+            dict(ts='current_timestamp')]
+        for test in tests:
+            if isinstance(test, dict):
+                data = test
+                change = {}
+            else:
+                data, change = test
+            expect = data.copy()
+            expect.update(change)
+            if bool_on:
+                b = expect.get('b')
+                if b is not None:
+                    expect['b'] = b == 't'
+            if decimal is not Decimal:
+                d = expect.get('d')
+                if d is not None:
+                    expect['d'] = decimal(d)
+                m = expect.get('m')
+                if m is not None:
+                    expect['m'] = decimal(m)
+            self.assertEqual(insert(table, data), data)
+            self.assertIn(oid_table, data)
+            oid = data[oid_table]
+            self.assertIsInstance(oid, int)
+            data = dict(item for item in data.items()
+                if item[0] in expect)
+            ts = expect.get('ts')
+            if ts == 'current_timestamp':
+                ts = expect['ts'] = data['ts']
+                if len(ts) > 19:
+                    self.assertEqual(ts[19], '.')
+                    ts = ts[:19]
                 else:
-                    data, change = test
-                expect = data.copy()
-                expect.update(change)
-                if bool_on:
-                    b = expect.get('b')
-                    if b is not None:
-                        expect['b'] = b == 't'
-                if decimal is not Decimal:
-                    d = expect.get('d')
-                    if d is not None:
-                        expect['d'] = decimal(d)
-                    m = expect.get('m')
-                    if m is not None:
-                        expect['m'] = decimal(m)
-                self.assertEqual(insert(table, data), data)
-                self.assertIn(oid_table, data)
-                oid = data[oid_table]
-                self.assertIsInstance(oid, int)
-                data = dict(item for item in data.items()
-                    if item[0] in expect)
-                ts = expect.get('ts')
-                if ts == 'current_timestamp':
-                    ts = expect['ts'] = data['ts']
-                    if len(ts) > 19:
-                        self.assertEqual(ts[19], '.')
-                        ts = ts[:19]
-                    else:
-                        self.assertEqual(len(ts), 19)
-                    self.assertTrue(ts[:4].isdigit())
-                    self.assertEqual(ts[4], '-')
-                    self.assertEqual(ts[10], ' ')
-                    self.assertTrue(ts[11:13].isdigit())
-                    self.assertEqual(ts[13], ':')
-                self.assertEqual(data, expect)
-                data = query(
-                    'select oid,* from "%s"' % table).dictresult()[0]
-                self.assertEqual(data['oid'], oid)
-                data = dict(item for item in data.items()
-                    if item[0] in expect)
-                self.assertEqual(data, expect)
-                query('delete from "%s"' % table)
-            query('drop table "%s"' % table)
+                    self.assertEqual(len(ts), 19)
+                self.assertTrue(ts[:4].isdigit())
+                self.assertEqual(ts[4], '-')
+                self.assertEqual(ts[10], ' ')
+                self.assertTrue(ts[11:13].isdigit())
+                self.assertEqual(ts[13], ':')
+            self.assertEqual(data, expect)
+            data = query(
+                'select oid,* from "%s"' % table).dictresult()[0]
+            self.assertEqual(data['oid'], oid)
+            data = dict(item for item in data.items()
+                if item[0] in expect)
+            self.assertEqual(data, expect)
+            query('delete from "%s"' % table)
+        query('drop table "%s"' % table)
+
+    def testInsertWithQuotedNames(self):
+        insert = self.db.insert
+        query = self.db.query
+        table = 'test table for insert()'
+        query('drop table if exists "%s"' % table)
+        query('create table "%s" ('
+            '"Prime!" smallint primary key,'
+            '"much space" integer, "Questions?" text)' % table)
+        r = {'Prime!': 11, 'much space': 2002, 'Questions?': 'What?'}
+        r = insert(table, r)
+        self.assertIsInstance(r, dict)
+        self.assertEqual(r['Prime!'], 11)
+        self.assertEqual(r['much space'], 2002)
+        self.assertEqual(r['Questions?'], 'What?')
+        r = query('select * from "%s" limit 2' % table).dictresult()
+        self.assertEqual(len(r), 1)
+        r = r[0]
+        self.assertEqual(r['Prime!'], 11)
+        self.assertEqual(r['much space'], 2002)
+        self.assertEqual(r['Questions?'], 'What?')
+        query('drop table "%s"' % table)
 
     def testUpdate(self):
         update = self.db.update
         query = self.db.query
-        for table in ('update_test_table', 'test table for update'):
-            query('drop table if exists "%s"' % table)
-            query('create table "%s" ('
-                "n integer, t text) with oids" % table)
-            for n, t in enumerate('xyz'):
-                query('insert into "%s" values('
-                    "%d, '%s')" % (table, n + 1, t))
-            self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
-            r = self.db.get(table, 2, 'n')
-            r['t'] = 'u'
-            s = update(table, r)
-            self.assertEqual(s, r)
-            r = query('select t from "%s" where n=2' % table
-                      ).getresult()[0][0]
-            self.assertEqual(r, 'u')
-            query('drop table "%s"' % table)
+        table = 'update_test_table'
+        query('drop table if exists "%s"' % table)
+        query('create table "%s" ('
+            "n integer, t text) with oids" % table)
+        for n, t in enumerate('xyz'):
+            query('insert into "%s" values('
+                "%d, '%s')" % (table, n + 1, t))
+        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
+        r = self.db.get(table, 2, 'n')
+        r['t'] = 'u'
+        s = update(table, r)
+        self.assertEqual(s, r)
+        r = query('select t from "%s" where n=2' % table
+                  ).getresult()[0][0]
+        self.assertEqual(r, 'u')
+        query('drop table "%s"' % table)
 
     def testUpdateWithCompositeKey(self):
         update = self.db.update
         query = self.db.query
         table = 'update_test_table_1'
-        query("drop table if exists %s" % table)
-        query("create table %s ("
+        query('drop table if exists "%s"' % table)
+        query('create table "%s" ('
             "n integer, t text, primary key (n))" % table)
         for n, t in enumerate('abc'):
-            query("insert into %s values("
+            query('insert into "%s" values('
                 "%d, '%s')" % (table, n + 1, t))
         self.assertRaises(pg.ProgrammingError, update,
                           table, dict(t='b'))
@@ -849,15 +888,15 @@
         r = query('select t from "%s" where n=2' % table
                   ).getresult()[0][0]
         self.assertEqual(r, 'd')
-        query("drop table %s" % table)
+        query('drop table "%s"' % table)
         table = 'update_test_table_2'
-        query("drop table if exists %s" % table)
-        query("create table %s ("
+        query('drop table if exists "%s"' % table)
+        query('create table "%s" ('
             "n integer, m integer, t text, primary key (n, m))" % table)
         for n in range(3):
             for m in range(2):
                 t = chr(ord('a') + 2 * n + m)
-                query("insert into %s values("
+                query('insert into "%s" values('
                     "%d, %d, '%s')" % (table, n + 1, m + 1, t))
         self.assertRaises(pg.ProgrammingError, update,
                           table, dict(n=2, t='b'))
@@ -866,66 +905,105 @@
         r = [r[0] for r in query('select t from "%s" where n=2'
             ' order by m' % table).getresult()]
         self.assertEqual(r, ['c', 'x'])
-        query("drop table %s" % table)
+        query('drop table "%s"' % table)
+
+    def testUpdateWithQuotedNames(self):
+        update = self.db.update
+        query = self.db.query
+        table = 'test table for update()'
+        query('drop table if exists "%s"' % table)
+        query('create table "%s" ('
+            '"Prime!" smallint primary key,'
+            '"much space" integer, "Questions?" text)' % table)
+        query('insert into "%s"'
+              " values(13, 3003, 'Why!')" % table)
+        r = {'Prime!': 13, 'much space': 7007, 'Questions?': 'When?'}
+        r = update(table, r)
+        self.assertIsInstance(r, dict)
+        self.assertEqual(r['Prime!'], 13)
+        self.assertEqual(r['much space'], 7007)
+        self.assertEqual(r['Questions?'], 'When?')
+        r = query('select * from "%s" limit 2' % table).dictresult()
+        self.assertEqual(len(r), 1)
+        r = r[0]
+        self.assertEqual(r['Prime!'], 13)
+        self.assertEqual(r['much space'], 7007)
+        self.assertEqual(r['Questions?'], 'When?')
+        query('drop table "%s"' % table)
 
     def testClear(self):
         clear = self.db.clear
         query = self.db.query
         f = False if pg.get_bool() else 'f'
-        for table in ('clear_test_table', 'test table for clear'):
-            query('drop table if exists "%s"' % table)
-            query('create table "%s" ('
-                "n integer, b boolean, d date, t text)" % table)
-            r = clear(table)
-            result = {'n': 0, 'b': f, 'd': '', 't': ''}
-            self.assertEqual(r, result)
-            r['a'] = r['n'] = 1
-            r['d'] = r['t'] = 'x'
-            r['b'] = 't'
-            r['oid'] = long(1)
-            r = clear(table, r)
-            result = {'a': 1, 'n': 0, 'b': f, 'd': '', 't': '',
-                'oid': long(1)}
-            self.assertEqual(r, result)
-            query('drop table "%s"' % table)
+        table = 'clear_test_table'
+        query('drop table if exists "%s"' % table)
+        query('create table "%s" ('
+            "n integer, b boolean, d date, t text)" % table)
+        r = clear(table)
+        result = {'n': 0, 'b': f, 'd': '', 't': ''}
+        self.assertEqual(r, result)
+        r['a'] = r['n'] = 1
+        r['d'] = r['t'] = 'x'
+        r['b'] = 't'
+        r['oid'] = long(1)
+        r = clear(table, r)
+        result = {'a': 1, 'n': 0, 'b': f, 'd': '', 't': '',
+            'oid': long(1)}
+        self.assertEqual(r, result)
+        query('drop table "%s"' % table)
+
+    def testClearWithQuotedNames(self):
+        clear = self.db.clear
+        query = self.db.query
+        table = 'test table for clear()'
+        query('drop table if exists "%s"' % table)
+        query('create table "%s" ('
+            '"Prime!" smallint primary key,'
+            '"much space" integer, "Questions?" text)' % table)
+        r = clear(table)
+        self.assertIsInstance(r, dict)
+        self.assertEqual(r['Prime!'], 0)
+        self.assertEqual(r['much space'], 0)
+        self.assertEqual(r['Questions?'], '')
+        query('drop table "%s"' % table)
 
     def testDelete(self):
         delete = self.db.delete
         query = self.db.query
-        for table in ('delete_test_table', 'test table for delete'):
-            query('drop table if exists "%s"' % table)
-            query('create table "%s" ('
-                "n integer, t text) with oids" % table)
-            for n, t in enumerate('xyz'):
-                query('insert into "%s" values('
-                    "%d, '%s')" % (table, n + 1, t))
-            self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
-            r = self.db.get(table, 1, 'n')
-            s = delete(table, r)
-            self.assertEqual(s, 1)
-            r = self.db.get(table, 3, 'n')
-            s = delete(table, r)
-            self.assertEqual(s, 1)
-            s = delete(table, r)
-            self.assertEqual(s, 0)
-            r = query('select * from "%s"' % table).dictresult()
-            self.assertEqual(len(r), 1)
-            r = r[0]
-            result = {'n': 2, 't': 'y'}
-            self.assertEqual(r, result)
-            r = self.db.get(table, 2, 'n')
-            s = delete(table, r)
-            self.assertEqual(s, 1)
-            s = delete(table, r)
-            self.assertEqual(s, 0)
-            self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
-            query('drop table "%s"' % table)
+        table = 'delete_test_table'
+        query('drop table if exists "%s"' % table)
+        query('create table "%s" ('
+            "n integer, t text) with oids" % table)
+        for n, t in enumerate('xyz'):
+            query('insert into "%s" values('
+                "%d, '%s')" % (table, n + 1, t))
+        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
+        r = self.db.get(table, 1, 'n')
+        s = delete(table, r)
+        self.assertEqual(s, 1)
+        r = self.db.get(table, 3, 'n')
+        s = delete(table, r)
+        self.assertEqual(s, 1)
+        s = delete(table, r)
+        self.assertEqual(s, 0)
+        r = query('select * from "%s"' % table).dictresult()
+        self.assertEqual(len(r), 1)
+        r = r[0]
+        result = {'n': 2, 't': 'y'}
+        self.assertEqual(r, result)
+        r = self.db.get(table, 2, 'n')
+        s = delete(table, r)
+        self.assertEqual(s, 1)
+        s = delete(table, r)
+        self.assertEqual(s, 0)
+        self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
+        query('drop table "%s"' % table)
 
     def testDeleteWithCompositeKey(self):
         query = self.db.query
         table = 'delete_test_table_1'
-        query("drop table if exists %s" % table)
-        query("create table %s ("
+        query('drop table if exists "%s"' % table)
+        query('create table "%s" ('
             "n integer, t text, primary key (n))" % table)
         for n, t in enumerate('abc'):
             query("insert into %s values("
@@ -940,15 +1018,15 @@
         r = query('select t from "%s" where n=3' % table
                   ).getresult()[0][0]
         self.assertEqual(r, 'c')
-        query("drop table %s" % table)
+        query('drop table "%s"' % table)
         table = 'delete_test_table_2'
-        query("drop table if exists %s" % table)
-        query("create table %s ("
+        query('drop table if exists "%s"' % table)
+        query('create table "%s" ('
             "n integer, m integer, t text, primary key (n, m))" % table)
         for n in range(3):
             for m in range(2):
                 t = chr(ord('a') + 2 * n + m)
-                query("insert into %s values("
+                query('insert into "%s" values('
                     "%d, %d, '%s')" % (table, n + 1, m + 1, t))
         self.assertRaises(pg.ProgrammingError, self.db.delete,
             table, dict(n=2, t='b'))
@@ -964,7 +1042,29 @@
         r = [r[0] for r in query('select t from "%s" where n=3'
             ' order by m' % table).getresult()]
         self.assertEqual(r, ['f'])
-        query("drop table %s" % table)
+        query('drop table "%s"' % table)
+
+    def testDeleteWithQuotedNames(self):
+        delete = self.db.delete
+        query = self.db.query
+        table = 'test table for delete()'
+        query('drop table if exists "%s"' % table)
+        query('create table "%s" ('
+            '"Prime!" smallint primary key,'
+            '"much space" integer, "Questions?" text)' % table)
+        query('insert into "%s"'
+              " values(19, 5005, 'Yes!')" % table)
+        r = {'Prime!': 17}
+        r = delete(table, r)
+        self.assertEqual(r, 0)
+        r = query('select count(*) from "%s"' % table).getresult()
+        self.assertEqual(r[0][0], 1)
+        r = {'Prime!': 19}
+        r = delete(table, r)
+        self.assertEqual(r, 1)
+        r = query('select count(*) from "%s"' % table).getresult()
+        self.assertEqual(r[0][0], 0)
+        query('drop table "%s"' % table)
 
     def testTransaction(self):
         query = self.db.query
_______________________________________________
PyGreSQL mailing list
[email protected]
https://mail.vex.net/mailman/listinfo.cgi/pygresql

Reply via email to