Author: cito
Date: Thu Jan 28 15:07:41 2016
New Revision: 793

Log:
Improve quoting and typecasting in the pg module

Larger refactoring of the code for adapting and typecasting in the pg module.
Things are now a lot cleaner and clearer.

The _Adapt class is responsible for all adapting of Python objects to their
PostgreSQL equivalents when sending data to the database. The typecasting
from PostgreSQL on output happens in the C module, except for the typecasting
of records which is new and provided by the _CastRecord class.

The classic module also did not work properly when regular type names were
switched on with use_regtypes(True), since the adapting of types relied on
the PyGreSQL type names. This has been solved by adding a new _PgType class
that is essentially the old type name, but augmented with all the
information necessary to adapt types, particularly record types.

All tests in test_classic_dbwrapper now run twice, using opposite values
for the various configuration settings like use_bool() or use_regtypes(),
in order to make sure that no internal functions rely on default settings.

Modified:
   trunk/pg.py
   trunk/tests/test_classic_dbwrapper.py
   trunk/tests/test_classic_functions.py

Modified: trunk/pg.py
==============================================================================
--- trunk/pg.py Thu Jan 28 14:54:34 2016        (r792)
+++ trunk/pg.py Thu Jan 28 15:07:41 2016        (r793)
@@ -166,6 +166,311 @@
 _simpletype = _SimpleType()
 
 
+class _Adapt:
+    """Mixin providing methods for adapting records and record elements.
+
+    This is used when passing values from one of the higher level DB
+    methods as parameters for a query.
+
+    This class must be mixed in to a connection class, because it needs
+    connection specific methods such as escape_bytea().
+    """
+
+    _bool_true_values = frozenset('t true 1 y yes on'.split())
+
+    _date_literals = frozenset('current_date current_time'
+        ' current_timestamp localtime localtimestamp'.split())
+
+    _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
+    _re_record_quote = regex(r'[(,"\\]')
+    _re_array_escape = _re_record_escape = regex(r'(["\\])')
+
+    @classmethod
+    def _adapt_bool(cls, v):
+        """Adapt a boolean parameter."""
+        if isinstance(v, basestring):
+            if not v:
+                return None
+            v = v.lower() in cls._bool_true_values
+        return 't' if v else 'f'
+
+    @classmethod
+    def _adapt_date(cls, v):
+        """Adapt a date parameter."""
+        if not v:
+            return None
+        if isinstance(v, basestring) and v.lower() in cls._date_literals:
+            return _Literal(v)
+        return v
+
+    @staticmethod
+    def _adapt_num(v):
+        """Adapt a numeric parameter."""
+        if not v and v != 0:
+            return None
+        return v
+
+    _adapt_int = _adapt_float = _adapt_money = _adapt_num
+
+    def _adapt_bytea(self, v):
+        """Adapt a bytea parameter."""
+        return self.escape_bytea(v)
+
+    def _adapt_json(self, v):
+        """Adapt a json parameter."""
+        if not v:
+            return None
+        if isinstance(v, basestring):
+            return v
+        return self.encode_json(v)
+
+    @classmethod
+    def _adapt_text_array(cls, v):
+        """Adapt a text type array parameter."""
+        if isinstance(v, list):
+            adapt = cls._adapt_text_array
+            return '{%s}' % ','.join(adapt(v) for v in v)
+        if v is None:
+            return 'null'
+        if not v:
+            return '""'
+        v = str(v)
+        if cls._re_array_quote.search(v):
+            v = '"%s"' % cls._re_array_escape.sub(r'\\\1', v)
+        return v
+
+    _adapt_date_array = _adapt_text_array
+
+    @classmethod
+    def _adapt_bool_array(cls, v):
+        """Adapt a boolean array parameter."""
+        if isinstance(v, list):
+            adapt = cls._adapt_bool_array
+            return '{%s}' % ','.join(adapt(v) for v in v)
+        if v is None:
+            return 'null'
+        if isinstance(v, basestring):
+            if not v:
+                return 'null'
+            v = v.lower() in cls._bool_true_values
+        return 't' if v else 'f'
+
+    @classmethod
+    def _adapt_num_array(cls, v):
+        """Adapt a numeric array parameter."""
+        if isinstance(v, list):
+            adapt = cls._adapt_num_array
+            return '{%s}' % ','.join(adapt(v) for v in v)
+        if not v and v != 0:
+            return 'null'
+        return str(v)
+
+    _adapt_int_array = _adapt_float_array = _adapt_money_array = \
+            _adapt_num_array
+
+    def _adapt_bytea_array(self, v):
+        """Adapt a bytea array parameter."""
+        if isinstance(v, list):
+            return b'{' + b','.join(
+                self._adapt_bytea_array(v) for v in v) + b'}'
+        if v is None:
+            return b'null'
+        return self.escape_bytea(v).replace(b'\\', b'\\\\')
+
+    def _adapt_json_array(self, v):
+        """Adapt a json array parameter."""
+        if isinstance(v, list):
+            adapt = self._adapt_json_array
+            return '{%s}' % ','.join(adapt(v) for v in v)
+        if not v:
+            return 'null'
+        if not isinstance(v, basestring):
+            v = self.encode_json(v)
+        if self._re_array_quote.search(v):
+            v = '"%s"' % self._re_array_escape.sub(r'\\\1', v)
+        return v
+
+    def _adapt_record(self, v, typ):
+        """Adapt a record parameter with given type."""
+        typ = typ.attnames.values()
+        if len(typ) != len(v):
+            raise TypeError('Record parameter %s has wrong size' % v)
+        return '(%s)' % ','.join(getattr(self,
+            '_adapt_record_%s' % t.simple)(v) for v, t in zip(v, typ))
+
+    @classmethod
+    def _adapt_record_text(cls, v):
+        """Adapt a text type record component."""
+        if v is None:
+            return ''
+        if not v:
+            return '""'
+        v = str(v)
+        if cls._re_record_quote.search(v):
+            v = '"%s"' % cls._re_record_escape.sub(r'\\\1', v)
+        return v
+
+    _adapt_record_date = _adapt_record_text
+
+    @classmethod
+    def _adapt_record_bool(cls, v):
+        """Adapt a boolean record component."""
+        if v is None:
+            return ''
+        if isinstance(v, basestring):
+            if not v:
+                return ''
+            v = v.lower() in cls._bool_true_values
+        return 't' if v else 'f'
+
+    @staticmethod
+    def _adapt_record_num(v):
+        """Adapt a numeric record component."""
+        if not v and v != 0:
+            return ''
+        return str(v)
+
+    _adapt_record_int = _adapt_record_float = _adapt_record_money = \
+        _adapt_record_num
+
+    def _adapt_record_bytea(self, v):
+        if v is None:
+            return ''
+        v = self.escape_bytea(v)
+        if bytes is not str and isinstance(v, bytes):
+            v = v.decode('ascii')
+        return v.replace('\\', '\\\\')
+
+    def _adapt_record_json(self, v):
+        """Adapt a bytea record component."""
+        if not v:
+            return ''
+        if not isinstance(v, basestring):
+            v = self.encode_json(v)
+        if self._re_array_quote.search(v):
+            v = '"%s"' % self._re_array_escape.sub(r'\\\1', v)
+        return v
+
+    def _adapt_param(self, value, typ, params):
+        """Adapt and add a parameter to the list."""
+        if isinstance(value, _Literal):
+            return value
+        if value is not None:
+            simple = typ.simple
+            if simple == 'text':
+                pass
+            elif simple == 'record':
+                if isinstance(value, tuple):
+                    value = self._adapt_record(value, typ)
+            elif simple.endswith('[]'):
+                if isinstance(value, list):
+                    adapt = getattr(self, '_adapt_%s_array' % simple[:-2])
+                    value = adapt(value)
+            else:
+                adapt = getattr(self, '_adapt_%s' % simple)
+                value = adapt(value)
+                if isinstance(value, _Literal):
+                    return value
+        params.append(value)
+        return '$%d' % len(params)
+
+
+class _CastRecord:
+    """Class providing methods for casting records and record elements.
+
+    This is needed when getting result values from one of the higher level DB
+    methods, since the lower level query method only casts the other types.
+    """
+
+    @staticmethod
+    def cast_bool(v):
+        if not get_bool():
+            return v
+        return v[0] == 't'
+
+    @staticmethod
+    def cast_bytea(v):
+        return unescape_bytea(v)
+
+    @staticmethod
+    def cast_float(v):
+        return float(v)
+
+    @staticmethod
+    def cast_int(v):
+        return int(v)
+
+    @staticmethod
+    def cast_json(v):
+        cast = get_jsondecode()
+        if not cast:
+            return v
+        return cast(v)
+
+    @staticmethod
+    def cast_num(v):
+        return (get_decimal() or float)(v)
+
+    @staticmethod
+    def cast_money(v):
+        point = get_decimal_point()
+        if not point:
+            return v
+        if point != '.':
+            v = v.replace(point, '.')
+        v = v.replace('(', '-')
+        v = ''.join(c for c in v if c.isdigit() or c in '.-')
+        return (get_decimal() or float)(v)
+
+    @classmethod
+    def cast(cls, v, typ):
+        types = typ.attnames.values()
+        cast = [getattr(cls, 'cast_%s' % t.simple, None) for t in types]
+        v = cast_record(v, cast)
+        return typ.namedtuple(*v)
+
+
+class _PgType(str):
+    """Class augmenting the simple type name with additional info."""
+
+    _num_types = frozenset('int float num money'
+        ' int2 int4 int8 float4 float8 numeric money'.split())
+
+    @classmethod
+    def create(cls, db, pgtype, regtype, typrelid):
+        """Create a PostgreSQL type name with additional info."""
+        simple = 'record' if typrelid else _simpletype[pgtype]
+        self = cls(regtype if db._regtypes else simple)
+        self.db = db
+        self.simple = simple
+        self.pgtype = pgtype
+        self.regtype = regtype
+        self.typrelid = typrelid
+        self._attnames = self._namedtuple = None
+        return self
+
+    @property
+    def attnames(self):
+        """Get names and types of the fields of a composite type."""
+        if not self.typrelid:
+            return None
+        if not self._attnames:
+            self._attnames = self.db.get_attnames(self.typrelid)
+        return self._attnames
+
+    @property
+    def namedtuple(self):
+        """Return named tuple class representing a composite type."""
+        if not self._namedtuple:
+            self._namedtuple = namedtuple(self, self.attnames)
+        return self._namedtuple
+
+    def cast(self, value):
+        if value is not None and self.typrelid:
+            value = _CastRecord.cast(value, self)
+        return value
+
+
 class _Literal(str):
     """Wrapper class for literal SQL."""
 
@@ -349,7 +654,7 @@
 
 # The actual PostGreSQL database connection interface:
 
-class DB(object):
+class DB(_Adapt):
     """Wrapper class for the _pg connection type."""
 
     def __init__(self, *args, **kw):
@@ -451,146 +756,12 @@
         """Get boolean value corresponding to d."""
         return bool(d) if get_bool() else ('t' if d else 'f')
 
-    _bool_true_values = frozenset('t true 1 y yes on'.split())
-
-    def _prepare_bool(self, d):
-        """Prepare a boolean parameter."""
-        if isinstance(d, basestring):
-            if not d:
-                return None
-            d = d.lower() in self._bool_true_values
-        return 't' if d else 'f'
-
-    _date_literals = frozenset('current_date current_time'
-        ' current_timestamp localtime localtimestamp'.split())
-
-    def _prepare_date(self, d):
-        """Prepare a date parameter."""
-        if not d:
-            return None
-        if isinstance(d, basestring) and d.lower() in self._date_literals:
-            return _Literal(d)
-        return d
-
-    _num_types = frozenset('int float num money'
-        ' int2 int4 int8 float4 float8 numeric money'.split())
-
-    @staticmethod
-    def _prepare_num(d):
-        """Prepare a numeric parameter."""
-        if not d and d != 0:
-            return None
-        return d
-
-    _prepare_int = _prepare_float = _prepare_money = _prepare_num
-
-    def _prepare_bytea(self, d):
-        """Prepare a bytea parameter."""
-        return self.escape_bytea(d)
-
-    def _prepare_json(self, d):
-        """Prepare a json parameter."""
-        if not d:
-            return None
-        if isinstance(d, basestring):
-            return d
-        return self.encode_json(d)
-
-    _re_array_escape = regex(r'(["\\])')
-    _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
-
-    def _prepare_bool_array(self, d):
-        """Prepare a bool array parameter."""
-        if isinstance(d, list):
-            return '{%s}' % ','.join(self._prepare_bool_array(v) for v in d)
-        if d is None:
-            return 'null'
-        if isinstance(d, basestring):
-            if not d:
-                return 'null'
-            d = d.lower() in self._bool_true_values
-        return 't' if d else 'f'
-
-    def _prepare_num_array(self, d):
-        """Prepare a numeric array parameter."""
-        if isinstance(d, list):
-            return '{%s}' % ','.join(self._prepare_num_array(v) for v in d)
-        if not d and d != 0:
-            return 'null'
-        return str(d)
-
-    _prepare_int_array = _prepare_float_array = _prepare_money_array = \
-            _prepare_num_array
-
-    def _prepare_text_array(self, d):
-        """Prepare a text array parameter."""
-        if isinstance(d, list):
-            return '{%s}' % ','.join(self._prepare_text_array(v) for v in d)
-        if d is None:
-            return 'null'
-        if not d:
-            return '""'
-        d = str(d)
-        if self._re_array_quote.search(d):
-            d = '"%s"' % self._re_array_escape.sub(r'\\\1', d)
-        return d
-
-    def _prepare_bytea_array(self, d):
-        """Prepare a bytea array parameter."""
-        if isinstance(d, list):
-            return b'{' + b','.join(
-                self._prepare_bytea_array(v) for v in d) + b'}'
-        if d is None:
-            return b'null'
-        return self.escape_bytea(d).replace(b'\\', b'\\\\')
-
-    def _prepare_json_array(self, d):
-        """Prepare a json array parameter."""
-        if isinstance(d, list):
-            return '{%s}' % ','.join(self._prepare_json_array(v) for v in d)
-        if not d:
-            return 'null'
-        if not isinstance(d, basestring):
-            d = self.encode_json(d)
-        if self._re_array_quote.search(d):
-            d = '"%s"' % self._re_array_escape.sub(r'\\\1', d)
-        return d
-
-    def _prepare_param(self, value, typ, params):
-        """Prepare and add a parameter to the list."""
-        if isinstance(value, _Literal):
-            return value
-        if value is not None and typ != 'text':
-            if typ.endswith('[]'):
-                if isinstance(value, list):
-                    prepare = getattr(self, '_prepare_%s_array' % typ[:-2])
-                    value = prepare(value)
-                elif isinstance(value, basestring):
-                    value = value.strip()
-                    if not value.startswith('{') or not value.endswith('}'):
-                        if value[:5].lower() == 'array':
-                            value = value[5:].lstrip()
-                        if value.startswith('[') and value.endswith(']'):
-                            value = _Literal('ARRAY%s' % value)
-                        else:
-                            raise ValueError(
-                                'Invalid array expression: %s' % value)
-                else:
-                    raise ValueError('Invalid array parameter: %s' % value)
-            else:
-                prepare = getattr(self, '_prepare_%s' % typ)
-                value = prepare(value)
-            if isinstance(value, _Literal):
-                return value
-        params.append(value)
-        return '$%d' % len(params)
-
     def _list_params(self, params):
         """Create a human readable parameter list."""
         return ', '.join('$%d=%r' % (n, v) for n, v in enumerate(params, 1))
 
     @staticmethod
-    def _prepare_qualified_param(name, param):
+    def _adapt_qualified_param(name, param):
         """Quote parameter representing a qualified name.
 
         Escapes the name for use as an SQL parameter, unless the
@@ -601,7 +772,7 @@
         """
         if isinstance(param, int):
             param = "$%d" % param
-        if '.' not in name:
+        if isinstance(name, basestring) and '.' not in name:
             param = 'quote_ident(%s)' % (param,)
         return param
 
@@ -869,7 +1040,7 @@
                 " AND NOT a.attisdropped"
                 " WHERE i.indrelid=%s::regclass"
                 " AND i.indisprimary ORDER BY a.attnum") % (
-                    self._prepare_qualified_param(table, 1),)
+                    self._adapt_qualified_param(table, 1),)
             pkey = self.db.query(q, (table,)).getresult()
             if not pkey:
                 raise KeyError('Table %s has no primary key' % table)
@@ -933,17 +1104,16 @@
         try:  # cache lookup
             names = attnames[table]
         except KeyError:  # cache miss, check the database
-            q = ("SELECT a.attname, t.typname%s"
+            q = ("SELECT a.attname, t.typname, t.typname::regtype, t.typrelid"
                 " FROM pg_attribute a"
                 " JOIN pg_type t ON t.oid = a.atttypid"
                 " WHERE a.attrelid = %s::regclass"
                 " AND (a.attnum > 0 OR a.attname = 'oid')"
                 " AND NOT a.attisdropped ORDER BY a.attnum") % (
-                    '::regtype' if self._regtypes else '',
-                    self._prepare_qualified_param(table, 1))
+                    self._adapt_qualified_param(table, 1))
             names = self.db.query(q, (table,)).getresult()
-            if not self._regtypes:
-                names = ((name, _simpletype[typ]) for name, typ in names)
+            names = ((name, _PgType.create(self, pgtype, regtype, typrelid))
+                for name, pgtype, regtype, typrelid in names)
             names = AttrDict(names)
             attnames[table] = names  # cache it
         return names
@@ -966,7 +1136,7 @@
             return self._privileges[(table, privilege)]
         except KeyError:  # cache miss, ask the database
             q = "SELECT has_table_privilege(%s, $2)" % (
-                self._prepare_qualified_param(table, 1),)
+                self._adapt_qualified_param(table, 1),)
             q = self.db.query(q, (table, privilege))
             ret = q.getresult()[0][0] == self._make_bool(True)
             self._privileges[(table, privilege)] = ret  # cache it
@@ -1024,7 +1194,7 @@
                     'Differing number of items in keyname and row')
             row = dict(zip(keyname, row))
         params = []
-        param = partial(self._prepare_param, params=params)
+        param = partial(self._adapt_param, params=params)
         col = self.escape_identifier
         what = 'oid, *' if qoid else '*'
         where = ' AND '.join('%s = %s' % (
@@ -1044,6 +1214,8 @@
         for n, value in res[0].items():
             if qoid and n == 'oid':
                 n = qoid
+            else:
+                value = attnames[n].cast(value)
             row[n] = value
         return row
 
@@ -1070,7 +1242,7 @@
         attnames = self.get_attnames(table)
         qoid = _oid_key(table) if 'oid' in attnames else None
         params = []
-        param = partial(self._prepare_param, params=params)
+        param = partial(self._adapt_param, params=params)
         col = self.escape_identifier
         names, values = [], []
         for n in attnames:
@@ -1090,6 +1262,8 @@
             for n, value in res[0].items():
                 if qoid and n == 'oid':
                     n = qoid
+                else:
+                    value = attnames[n].cast(value)
                 row[n] = value
         return row
 
@@ -1131,7 +1305,7 @@
                 else:
                     raise KeyError('Missing primary key in row')
         params = []
-        param = partial(self._prepare_param, params=params)
+        param = partial(self._adapt_param, params=params)
         col = self.escape_identifier
         where = ' AND '.join('%s = %s' % (
             col(k), param(row[k], attnames[k])) for k in keyname)
@@ -1157,6 +1331,8 @@
             for n, value in res[0].items():
                 if qoid and n == 'oid':
                     n = qoid
+                else:
+                    value = attnames[n].cast(value)
                 row[n] = value
         return row
 
@@ -1214,7 +1390,7 @@
         attnames = self.get_attnames(table)
         qoid = _oid_key(table) if 'oid' in attnames else None
         params = []
-        param = partial(self._prepare_param,params=params)
+        param = partial(self._adapt_param,params=params)
         col = self.escape_identifier
         names, values, updates = [], [], []
         for n in attnames:
@@ -1258,6 +1434,8 @@
             for n, value in res[0].items():
                 if qoid and n == 'oid':
                     n = qoid
+                else:
+                    value = attnames[n].cast(value)
                 row[n] = value
         else:
             self.get(table, row)
@@ -1278,7 +1456,8 @@
         for n, t in attnames.items():
             if n == 'oid':
                 continue
-            if t in self._num_types:
+            t = t.simple
+            if t in _PgType._num_types:
                 row[n] = 0
             elif t == 'bool':
                 row[n] = self._make_bool(False)
@@ -1327,7 +1506,7 @@
                 else:
                     raise KeyError('Missing primary key in row')
         params = []
-        param = partial(self._prepare_param, params=params)
+        param = partial(self._adapt_param, params=params)
         col = self.escape_identifier
         where = ' AND '.join('%s = %s' % (
             col(k), param(row[k], attnames[k])) for k in keyname)

Modified: trunk/tests/test_classic_dbwrapper.py
==============================================================================
--- trunk/tests/test_classic_dbwrapper.py       Thu Jan 28 14:54:34 2016        
(r792)
+++ trunk/tests/test_classic_dbwrapper.py       Thu Jan 28 15:07:41 2016        
(r793)
@@ -379,6 +379,8 @@
 
     cls_set_up = False
 
+    regtypes = None
+
     @classmethod
     def setUpClass(cls):
         db = DB()
@@ -401,6 +403,10 @@
     def setUp(self):
         self.assertTrue(self.cls_set_up)
         self.db = DB()
+        if self.regtypes is None:
+            self.regtypes = self.db.use_regtypes()
+        else:
+            self.db.use_regtypes(self.regtypes)
         query = self.db.query
         query('set client_encoding=utf8')
         query('set standard_conforming_strings=on')
@@ -985,18 +991,29 @@
             self.db.get_attnames, 'has.too.many.dots')
         r = get_attnames('test')
         self.assertIsInstance(r, dict)
-        self.assertEqual(r, dict(
-            i2='int', i4='int', i8='int', d='num',
-            f4='float', f8='float', m='money',
-            v4='text', c4='text', t='text'))
+        if self.regtypes:
+            self.assertEqual(r, dict(
+                i2='smallint', i4='integer', i8='bigint', d='numeric',
+                f4='real', f8='double precision', m='money',
+                v4='character varying', c4='character', t='text'))
+        else:
+            self.assertEqual(r, dict(
+                i2='int', i4='int', i8='int', d='num',
+                f4='float', f8='float', m='money',
+                v4='text', c4='text', t='text'))
         self.createTable('test_table',
             'n int, alpha smallint, beta bool,'
             ' gamma char(5), tau text, v varchar(3)')
         r = get_attnames('test_table')
         self.assertIsInstance(r, dict)
-        self.assertEqual(r, dict(
-            n='int', alpha='int', beta='bool',
-            gamma='text', tau='text', v='text'))
+        if self.regtypes:
+            self.assertEqual(r, dict(
+                n='integer', alpha='smallint', beta='boolean',
+                gamma='character', tau='text', v='character varying'))
+        else:
+            self.assertEqual(r, dict(
+                n='int', alpha='int', beta='bool',
+                gamma='text', tau='text', v='text'))
 
     def testGetAttnamesWithQuotes(self):
         get_attnames = self.db.get_attnames
@@ -1005,23 +1022,37 @@
             '"Prime!" smallint, "much space" integer, "Questions?" text')
         r = get_attnames(table)
         self.assertIsInstance(r, dict)
-        self.assertEqual(r, {
-            'Prime!': 'int', 'much space': 'int', 'Questions?': 'text'})
+        if self.regtypes:
+            self.assertEqual(r, {
+                'Prime!': 'smallint', 'much space': 'integer',
+                'Questions?': 'text'})
+        else:
+            self.assertEqual(r, {
+                'Prime!': 'int', 'much space': 'int', 'Questions?': 'text'})
         table = 'yet another test table for get_attnames()'
         self.createTable(table,
             'a smallint, b integer, c bigint,'
-            ' e numeric, f float, f2 double precision, m money,'
+            ' e numeric, f real, f2 double precision, m money,'
             ' x smallint, y smallint, z smallint,'
             ' Normal_NaMe smallint, "Special Name" smallint,'
             ' t text, u char(2), v varchar(2),'
             ' primary key (y, u)', oids=True)
         r = get_attnames(table)
         self.assertIsInstance(r, dict)
-        self.assertEqual(r, {'a': 'int', 'c': 'int', 'b': 'int',
-            'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
-            'normal_name': 'int', 'Special Name': 'int',
-            'u': 'text', 't': 'text', 'v': 'text',
-            'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'})
+        if self.regtypes:
+            self.assertEqual(r, {
+                'a': 'smallint', 'b': 'integer', 'c': 'bigint',
+                'e': 'numeric', 'f': 'real', 'f2': 'double precision',
+                'm': 'money', 'normal_name': 'smallint',
+                'Special Name': 'smallint', 'u': 'character',
+                't': 'text', 'v': 'character varying', 'y': 'smallint',
+                'x': 'smallint', 'z': 'smallint', 'oid': 'oid'})
+        else:
+            self.assertEqual(r, {'a': 'int', 'b': 'int', 'c': 'int',
+                'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
+                'normal_name': 'int', 'Special Name': 'int',
+                'u': 'text', 't': 'text', 'v': 'text',
+                'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'})
 
     def testGetAttnamesWithRegtypes(self):
         get_attnames = self.db.get_attnames
@@ -1030,7 +1061,7 @@
             ' gamma char(5), tau text, v varchar(3)')
         use_regtypes = self.db.use_regtypes
         regtypes = use_regtypes()
-        self.assertFalse(regtypes)
+        self.assertEqual(regtypes, self.regtypes)
         use_regtypes(True)
         try:
             r = get_attnames("test_table")
@@ -1041,27 +1072,47 @@
             n='integer', alpha='smallint', beta='boolean',
             gamma='character', tau='text', v='character varying'))
 
+    def testGetAttnamesWithoutRegtypes(self):
+        get_attnames = self.db.get_attnames
+        self.createTable('test_table',
+            ' n int, alpha smallint, beta bool,'
+            ' gamma char(5), tau text, v varchar(3)')
+        use_regtypes = self.db.use_regtypes
+        regtypes = use_regtypes()
+        self.assertEqual(regtypes, self.regtypes)
+        use_regtypes(False)
+        try:
+            r = get_attnames("test_table")
+            self.assertIsInstance(r, dict)
+        finally:
+            use_regtypes(regtypes)
+        self.assertEqual(r, dict(
+            n='int', alpha='int', beta='bool',
+            gamma='text', tau='text', v='text'))
+
     def testGetAttnamesIsCached(self):
         get_attnames = self.db.get_attnames
+        int_type = 'integer' if self.regtypes else 'int'
+        text_type = 'text'
         query = self.db.query
         self.createTable('test_table', 'col int')
         r = get_attnames("test_table")
         self.assertIsInstance(r, dict)
-        self.assertEqual(r, dict(col='int'))
+        self.assertEqual(r, dict(col=int_type))
         query("alter table test_table alter column col type text")
         query("alter table test_table add column col2 int")
         r = get_attnames("test_table")
-        self.assertEqual(r, dict(col='int'))
+        self.assertEqual(r, dict(col=int_type))
         r = get_attnames("test_table", flush=True)
-        self.assertEqual(r, dict(col='text', col2='int'))
+        self.assertEqual(r, dict(col=text_type, col2=int_type))
         query("alter table test_table drop column col2")
         r = get_attnames("test_table")
-        self.assertEqual(r, dict(col='text', col2='int'))
+        self.assertEqual(r, dict(col=text_type, col2=int_type))
         r = get_attnames("test_table", flush=True)
-        self.assertEqual(r, dict(col='text'))
+        self.assertEqual(r, dict(col=text_type))
         query("alter table test_table drop column col")
         r = get_attnames("test_table")
-        self.assertEqual(r, dict(col='text'))
+        self.assertEqual(r, dict(col=text_type))
         r = get_attnames("test_table", flush=True)
         self.assertEqual(r, dict())
 
@@ -1069,10 +1120,17 @@
         get_attnames = self.db.get_attnames
         r = get_attnames('test', flush=True)
         self.assertIsInstance(r, OrderedDict)
-        self.assertEqual(r, OrderedDict([
-            ('i2', 'int'), ('i4', 'int'), ('i8', 'int'),
-            ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'),
-            ('v4', 'text'), ('c4', 'text'), ('t', 'text')]))
+        if self.regtypes:
+            self.assertEqual(r, OrderedDict([
+                ('i2', 'smallint'), ('i4', 'integer'), ('i8', 'bigint'),
+                ('d', 'numeric'), ('f4', 'real'), ('f8', 'double precision'),
+                ('m', 'money'), ('v4', 'character varying'),
+                ('c4', 'character'), ('t', 'text')]))
+        else:
+            self.assertEqual(r, OrderedDict([
+                ('i2', 'int'), ('i4', 'int'), ('i8', 'int'),
+                ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'),
+                ('v4', 'text'), ('c4', 'text'), ('t', 'text')]))
         if OrderedDict is not dict:
             r = ' '.join(list(r.keys()))
             self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t')
@@ -1082,9 +1140,15 @@
             ' gamma char(5), tau text, beta bool')
         r = get_attnames(table)
         self.assertIsInstance(r, OrderedDict)
-        self.assertEqual(r, OrderedDict([
-            ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
-            ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
+        if self.regtypes:
+            self.assertEqual(r, OrderedDict([
+                ('n', 'integer'), ('alpha', 'smallint'),
+                ('v', 'character varying'), ('gamma', 'character'),
+                ('tau', 'text'), ('beta', 'boolean')]))
+        else:
+            self.assertEqual(r, OrderedDict([
+                ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
+                ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
         if OrderedDict is not dict:
             r = ' '.join(list(r.keys()))
             self.assertEqual(r, 'n alpha v gamma tau beta')
@@ -1096,10 +1160,17 @@
         get_attnames = self.db.get_attnames
         r = get_attnames('test', flush=True)
         self.assertIsInstance(r, AttrDict)
-        self.assertEqual(r, AttrDict([
-            ('i2', 'int'), ('i4', 'int'), ('i8', 'int'),
-            ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'),
-            ('v4', 'text'), ('c4', 'text'), ('t', 'text')]))
+        if self.regtypes:
+            self.assertEqual(r, AttrDict([
+                ('i2', 'smallint'), ('i4', 'integer'), ('i8', 'bigint'),
+                ('d', 'numeric'), ('f4', 'real'), ('f8', 'double precision'),
+                ('m', 'money'), ('v4', 'character varying'),
+                ('c4', 'character'), ('t', 'text')]))
+        else:
+            self.assertEqual(r, AttrDict([
+                ('i2', 'int'), ('i4', 'int'), ('i8', 'int'),
+                ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'),
+                ('v4', 'text'), ('c4', 'text'), ('t', 'text')]))
         r = ' '.join(list(r.keys()))
         self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t')
         table = 'test table for get_attnames'
@@ -1108,9 +1179,15 @@
             ' gamma char(5), tau text, beta bool')
         r = get_attnames(table)
         self.assertIsInstance(r, AttrDict)
-        self.assertEqual(r, AttrDict([
-            ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
-            ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
+        if self.regtypes:
+            self.assertEqual(r, AttrDict([
+                ('n', 'integer'), ('alpha', 'smallint'),
+                ('v', 'character varying'), ('gamma', 'character'),
+                ('tau', 'text'), ('beta', 'boolean')]))
+        else:
+            self.assertEqual(r, AttrDict([
+                ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
+                ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
         r = ' '.join(list(r.keys()))
         self.assertEqual(r, 'n alpha v gamma tau beta')
 
@@ -2969,9 +3046,17 @@
             ' d numeric[], f4 real[], f8 double precision[], m money[],'
             ' b bool[], v4 varchar(4)[], c4 char(4)[], t text[]')
         r = self.db.get_attnames('arraytest')
-        self.assertEqual(r, dict(id='int', i2='int[]', i4='int[]', i8='int[]',
-            d='num[]', f4='float[]', f8='float[]', m='money[]',
-            b='bool[]', v4='text[]', c4='text[]', t='text[]'))
+        if self.regtypes:
+            self.assertEqual(r, dict(
+                id='smallint', i2='smallint[]', i4='integer[]', i8='bigint[]',
+                d='numeric[]', f4='real[]', f8='double precision[]',
+                m='money[]', b='boolean[]',
+                v4='character varying[]', c4='character[]', t='text[]'))
+        else:
+            self.assertEqual(r, dict(
+                id='int', i2='int[]', i4='int[]', i8='int[]',
+                d='num[]', f4='float[]', f8='float[]', m='money[]',
+                b='bool[]', v4='text[]', c4='text[]', t='text[]'))
         decimal = pg.get_decimal()
         if decimal is Decimal:
             long_decimal = decimal('123456789.123456789')
@@ -3004,7 +3089,7 @@
         r = self.db.query('select * from arraytest limit 1').dictresult()[0]
         self.assertEqual(r, data)
 
-    def testArrayInput(self):
+    def testArrayLiteral(self):
         insert = self.db.insert
         self.createTable('arraytest', 'i int[], t text[]', oids=True)
         r = dict(i=[1, 2, 3], t=['a', 'b', 'c'])
@@ -3015,25 +3100,23 @@
         self.db.insert('arraytest', r)
         self.assertEqual(r['i'], [1, 2, 3])
         self.assertEqual(r['t'], ['a', 'b', 'c'])
-        r = dict(i="[1, 2, 3]", t="['a', 'b', 'c']")
-        self.db.insert('arraytest', r)
-        self.assertEqual(r['i'], [1, 2, 3])
-        self.assertEqual(r['t'], ['a', 'b', 'c'])
-        r = dict(i="array[1, 2, 3]", t="array['a', 'b', 'c']")
-        self.db.insert('arraytest', r)
-        self.assertEqual(r['i'], [1, 2, 3])
-        self.assertEqual(r['t'], ['a', 'b', 'c'])
-        r = dict(i="ARRAY[1, 2, 3]", t="ARRAY['a', 'b', 'c']")
+        L = pg._Literal
+        r = dict(i=L("ARRAY[1, 2, 3]"), t=L("ARRAY['a', 'b', 'c']"))
         self.db.insert('arraytest', r)
         self.assertEqual(r['i'], [1, 2, 3])
         self.assertEqual(r['t'], ['a', 'b', 'c'])
         r = dict(i="1, 2, 3", t="'a', 'b', 'c'")
-        self.assertRaises(ValueError, self.db.insert, 'arraytest', r)
+        self.assertRaises(pg.ProgrammingError, self.db.insert, 'arraytest', r)
 
     def testArrayOfIds(self):
         self.createTable('arraytest', 'c cid[], o oid[], x xid[]', oids=True)
         r = self.db.get_attnames('arraytest')
-        self.assertEqual(r, dict(oid='int', c='int[]', o='int[]', x='int[]'))
+        if self.regtypes:
+            self.assertEqual(r, dict(
+                oid='oid', c='cid[]', o='oid[]', x='xid[]'))
+        else:
+            self.assertEqual(r, dict(
+                oid='int', c='int[]', o='int[]', x='int[]'))
         data = dict(c=[11, 12, 13], o=[21, 22, 23], x=[31, 32, 33])
         r = data.copy()
         self.db.insert('arraytest', r)
@@ -3125,7 +3208,7 @@
                 self.skipTest('database does not support jsonb')
             self.fail(str(error))
         r = self.db.get_attnames('arraytest')
-        self.assertEqual(r['data'], 'json[]')
+        self.assertEqual(r['data'], 'jsonb[]' if self.regtypes else 'json[]')
         data = [dict(id=815, name='John Doe'), dict(id=816, name='Jane Roe')]
         jsondecode = pg.get_jsondecode()
         r = dict(data=data)
@@ -3166,6 +3249,170 @@
         self.db.get('arraytest', r)
         self.assertEqual(r['data'], data)
 
+    def testInsertUpdateGetRecord(self):
+        query = self.db.query
+        query('create type test_person_type as'
+            ' (name varchar, age smallint, married bool,'
+              ' weight real, salary money)')
+        self.addCleanup(query, 'drop type test_person_type')
+        self.createTable('test_person', 'person test_person_type',
+            temporary=False, oids=True)
+        attnames = self.db.get_attnames('test_person')
+        self.assertEqual(len(attnames), 2)
+        self.assertIn('oid', attnames)
+        self.assertIn('person', attnames)
+        person_typ = attnames['person']
+        if self.regtypes:
+            self.assertEqual(person_typ, 'test_person_type')
+        else:
+            self.assertEqual(person_typ, 'record')
+        if self.regtypes:
+            self.assertEqual(person_typ.attnames,
+                dict(name='character varying', age='smallint',
+                    married='boolean', weight='real', salary='money'))
+        else:
+            self.assertEqual(person_typ.attnames,
+                dict(name='text', age='int', married='bool',
+                    weight='float', salary='money'))
+        decimal = pg.get_decimal()
+        if pg.get_bool():
+            bool_class = bool
+            t, f = True, False
+        else:
+            bool_class = str
+            t, f = 't', 'f'
+        person = ('John Doe', 61, t, 99.5, decimal('93456.75'))
+        r = self.db.insert('test_person', None, person=person)
+        p = r['person']
+        self.assertIsInstance(p, tuple)
+        self.assertEqual(p, person)
+        self.assertEqual(p.name, 'John Doe')
+        self.assertIsInstance(p.name, str)
+        self.assertIsInstance(p.age, int)
+        self.assertIsInstance(p.married, bool_class)
+        self.assertIsInstance(p.weight, float)
+        self.assertIsInstance(p.salary, decimal)
+        person = ('Jane Roe', 59, f, 64.5, decimal('96543.25'))
+        r['person'] = person
+        self.db.update('test_person', r)
+        p = r['person']
+        self.assertIsInstance(p, tuple)
+        self.assertEqual(p, person)
+        self.assertEqual(p.name, 'Jane Roe')
+        self.assertIsInstance(p.name, str)
+        self.assertIsInstance(p.age, int)
+        self.assertIsInstance(p.married, bool_class)
+        self.assertIsInstance(p.weight, float)
+        self.assertIsInstance(p.salary, decimal)
+        r['person'] = None
+        self.db.get('test_person', r)
+        p = r['person']
+        self.assertIsInstance(p, tuple)
+        self.assertEqual(p, person)
+        self.assertEqual(p.name, 'Jane Roe')
+        self.assertIsInstance(p.name, str)
+        self.assertIsInstance(p.age, int)
+        self.assertIsInstance(p.married, bool_class)
+        self.assertIsInstance(p.weight, float)
+        self.assertIsInstance(p.salary, decimal)
+        person = (None,) * 5
+        r = self.db.insert('test_person', None, person=person)
+        p = r['person']
+        self.assertIsInstance(p, tuple)
+        self.assertIsNone(p.name)
+        self.assertIsNone(p.age)
+        self.assertIsNone(p.married)
+        self.assertIsNone(p.weight)
+        self.assertIsNone(p.salary)
+        r['person'] = None
+        self.db.get('test_person', r)
+        p = r['person']
+        self.assertIsInstance(p, tuple)
+        self.assertIsNone(p.name)
+        self.assertIsNone(p.age)
+        self.assertIsNone(p.married)
+        self.assertIsNone(p.weight)
+        self.assertIsNone(p.salary)
+        r = self.db.insert('test_person', None, person=None)
+        self.assertIsNone(r['person'])
+        r['person'] = None
+        self.db.get('test_person', r)
+        self.assertIsNone(r['person'])
+
+    def testRecordInsertBytea(self):
+        query = self.db.query
+        query('create type test_person_type as'
+            ' (name text, picture bytea)')
+        self.addCleanup(query, 'drop type test_person_type')
+        self.createTable('test_person', 'person test_person_type',
+            temporary=False, oids=True)
+        person_typ = self.db.get_attnames('test_person')['person']
+        self.assertEqual(person_typ.attnames,
+            dict(name='text', picture='bytea'))
+        person = ('John Doe', b'O\x00ps\xff!')
+        r = self.db.insert('test_person', None, person=person)
+        p = r['person']
+        self.assertIsInstance(p, tuple)
+        self.assertEqual(p, person)
+        self.assertEqual(p.name, 'John Doe')
+        self.assertIsInstance(p.name, str)
+        self.assertEqual(p.picture, person[1])
+        self.assertIsInstance(p.picture, bytes)
+
+    def testRecordInsertJson(self):
+        query = self.db.query
+        try:
+            query('create type test_person_type as'
+                ' (name text, data json)')
+        except pg.ProgrammingError as error:
+            if self.db.server_version < 90200:
+                self.skipTest('database does not support json')
+            self.fail(str(error))
+        self.addCleanup(query, 'drop type test_person_type')
+        self.createTable('test_person', 'person test_person_type',
+            temporary=False, oids=True)
+        person_typ = self.db.get_attnames('test_person')['person']
+        self.assertEqual(person_typ.attnames,
+            dict(name='text', data='json'))
+        person = ('John Doe', dict(age=61, married=True, weight=99.5))
+        r = self.db.insert('test_person', None, person=person)
+        p = r['person']
+        self.assertIsInstance(p, tuple)
+        if pg.get_jsondecode() is None:
+            p = p._replace(data=json.loads(p.data))
+        self.assertEqual(p, person)
+        self.assertEqual(p.name, 'John Doe')
+        self.assertIsInstance(p.name, str)
+        self.assertEqual(p.data, person[1])
+        self.assertIsInstance(p.data, dict)
+
+    def testRecordLiteral(self):
+        query = self.db.query
+        query('create type test_person_type as'
+            ' (name varchar, age smallint)')
+        self.addCleanup(query, 'drop type test_person_type')
+        self.createTable('test_person', 'person test_person_type',
+            temporary=False, oids=True)
+        person_typ = self.db.get_attnames('test_person')['person']
+        if self.regtypes:
+            self.assertEqual(person_typ, 'test_person_type')
+        else:
+            self.assertEqual(person_typ, 'record')
+        if self.regtypes:
+            self.assertEqual(person_typ.attnames,
+                dict(name='character varying', age='smallint'))
+        else:
+            self.assertEqual(person_typ.attnames,
+                dict(name='text', age='int'))
+        person = pg._Literal("('John Doe', 61)")
+        r = self.db.insert('test_person', None, person=person)
+        p = r['person']
+        self.assertIsInstance(p, tuple)
+        self.assertEqual(p.name, 'John Doe')
+        self.assertIsInstance(p.name, str)
+        self.assertEqual(p.age, 61)
+        self.assertIsInstance(p.age, int)
+
     def testNotificationHandler(self):
         # the notification handler itself is tested separately
         f = self.db.notification_handler
@@ -3255,6 +3502,7 @@
         cls.set_option('bool', not_bool)
         cls.set_option('namedresult', None)
         cls.set_option('jsondecode', None)
+        cls.regtypes = not DB().use_regtypes()
         super(TestDBClassNonStdOpts, cls).setUpClass()
 
     @classmethod

Modified: trunk/tests/test_classic_functions.py
==============================================================================
--- trunk/tests/test_classic_functions.py       Thu Jan 28 14:54:34 2016        (r792)
+++ trunk/tests/test_classic_functions.py       Thu Jan 28 15:07:41 2016        (r793)
@@ -216,14 +216,14 @@
         self.assertRaises(TypeError, f)
         self.assertRaises(TypeError, f, None)
         self.assertRaises(TypeError, f, '{}', 1)
-        self.assertRaises(TypeError, f, '{}', ',',)
+        self.assertRaises(TypeError, f, '{}', b',',)
         self.assertRaises(TypeError, f, '{}', None, None)
         self.assertRaises(TypeError, f, '{}', None, 1)
-        self.assertRaises(TypeError, f, '{}', None, '')
-        self.assertRaises(ValueError, f, '{}', None, '\\')
-        self.assertRaises(ValueError, f, '{}', None, '{')
-        self.assertRaises(ValueError, f, '{}', None, '}')
-        self.assertRaises(TypeError, f, '{}', None, ',;')
+        self.assertRaises(TypeError, f, '{}', None, b'')
+        self.assertRaises(ValueError, f, '{}', None, b'\\')
+        self.assertRaises(ValueError, f, '{}', None, b'{')
+        self.assertRaises(ValueError, f, '{}', None, b'}')
+        self.assertRaises(TypeError, f, '{}', None, b',;')
         self.assertEqual(f('{}'), [])
         self.assertEqual(f('{}', None), [])
         self.assertEqual(f('{}', None, b';'), [])
@@ -488,14 +488,14 @@
         self.assertRaises(TypeError, f)
         self.assertRaises(TypeError, f, None)
         self.assertRaises(TypeError, f, '()', 1)
-        self.assertRaises(TypeError, f, '()', ',',)
+        self.assertRaises(TypeError, f, '()', b',',)
         self.assertRaises(TypeError, f, '()', None, None)
         self.assertRaises(TypeError, f, '()', None, 1)
-        self.assertRaises(TypeError, f, '()', None, '')
-        self.assertRaises(ValueError, f, '()', None, '\\')
-        self.assertRaises(ValueError, f, '()', None, '(')
-        self.assertRaises(ValueError, f, '()', None, ')')
-        self.assertRaises(TypeError, f, '{}', None, ',;')
+        self.assertRaises(TypeError, f, '()', None, b'')
+        self.assertRaises(ValueError, f, '()', None, b'\\')
+        self.assertRaises(ValueError, f, '()', None, b'(')
+        self.assertRaises(ValueError, f, '()', None, b')')
+        self.assertRaises(TypeError, f, '{}', None, b',;')
         self.assertEqual(f('()'), (None,))
         self.assertEqual(f('()', None), (None,))
         self.assertEqual(f('()', None, b';'), (None,))
_______________________________________________
PyGreSQL mailing list
[email protected]
https://mail.vex.net/mailman/listinfo.cgi/pygresql

Reply via email to