Author: cito
Date: Thu Jan 28 15:07:41 2016
New Revision: 793
Log:
Improve quoting and typecasting in the pg module
Larger refactoring of the code for adapting and typecasting in the pg module.
Things are now a lot cleaner and clearer.
The _Adapt class is responsible for all adapting of Python objects to their
PostgreSQL equivalents when sending data to the database. The typecasting
from PostgreSQL on output happens in the C module, except for the typecasting
of records which is new and provided by the _CastRecord class.
The classic module also did not work properly when regular type names were
switched on with use_regtypes(True), since the adapting of types relied on
the PyGreSQL type names. This has been solved by adding a new _PgType class
that is essentially the old type name, but augmented with all the
information necessary to adapt types, particularly record types.
All tests in test_classic_dbwrapper now run twice, using opposite settings
for the various configuration settings like use_bool() or use_regtypes(),
in order to make sure that no internal functions rely on default settings.
Modified:
trunk/pg.py
trunk/tests/test_classic_dbwrapper.py
trunk/tests/test_classic_functions.py
Modified: trunk/pg.py
==============================================================================
--- trunk/pg.py Thu Jan 28 14:54:34 2016 (r792)
+++ trunk/pg.py Thu Jan 28 15:07:41 2016 (r793)
@@ -166,6 +166,311 @@
_simpletype = _SimpleType()
+class _Adapt:
+ """Mixin providing methods for adapting records and record elements.
+
+ This is used when passing values from one of the higher level DB
+ methods as parameters for a query.
+
+ This class must be mixed in to a connection class, because it needs
+ connection specific methods such as escape_bytea().
+ """
+
+ _bool_true_values = frozenset('t true 1 y yes on'.split())
+
+ _date_literals = frozenset('current_date current_time'
+ ' current_timestamp localtime localtimestamp'.split())
+
+ # The *_quote patterns decide whether an element must be wrapped in
+ # double quotes; the *_escape pattern matches the characters that
+ # then get a backslash prepended.
+ _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
+ _re_record_quote = regex(r'[(,"\\]')
+ _re_array_escape = _re_record_escape = regex(r'(["\\])')
+
+ @classmethod
+ def _adapt_bool(cls, v):
+ """Adapt a boolean parameter."""
+ # an empty string gives None, any other string counts as true
+ # only if its lower case form is in _bool_true_values
+ if isinstance(v, basestring):
+ if not v:
+ return None
+ v = v.lower() in cls._bool_true_values
+ return 't' if v else 'f'
+
+ @classmethod
+ def _adapt_date(cls, v):
+ """Adapt a date parameter."""
+ if not v:
+ return None
+ # names listed in _date_literals are wrapped as literal SQL
+ if isinstance(v, basestring) and v.lower() in cls._date_literals:
+ return _Literal(v)
+ return v
+
+ @staticmethod
+ def _adapt_num(v):
+ """Adapt a numeric parameter."""
+ # falsy values other than zero give None
+ if not v and v != 0:
+ return None
+ return v
+
+ _adapt_int = _adapt_float = _adapt_money = _adapt_num
+
+ def _adapt_bytea(self, v):
+ """Adapt a bytea parameter."""
+ return self.escape_bytea(v)
+
+ def _adapt_json(self, v):
+ """Adapt a json parameter."""
+ if not v:
+ return None
+ # strings are passed on unchanged, other values are encoded
+ if isinstance(v, basestring):
+ return v
+ return self.encode_json(v)
+
+ @classmethod
+ def _adapt_text_array(cls, v):
+ """Adapt a text type array parameter."""
+ # lists are adapted recursively, element by element
+ if isinstance(v, list):
+ adapt = cls._adapt_text_array
+ return '{%s}' % ','.join(adapt(v) for v in v)
+ if v is None:
+ return 'null'
+ if not v:
+ return '""'
+ v = str(v)
+ if cls._re_array_quote.search(v):
+ v = '"%s"' % cls._re_array_escape.sub(r'\\\1', v)
+ return v
+
+ _adapt_date_array = _adapt_text_array
+
+ @classmethod
+ def _adapt_bool_array(cls, v):
+ """Adapt a boolean array parameter."""
+ if isinstance(v, list):
+ adapt = cls._adapt_bool_array
+ return '{%s}' % ','.join(adapt(v) for v in v)
+ if v is None:
+ return 'null'
+ if isinstance(v, basestring):
+ if not v:
+ return 'null'
+ v = v.lower() in cls._bool_true_values
+ return 't' if v else 'f'
+
+ @classmethod
+ def _adapt_num_array(cls, v):
+ """Adapt a numeric array parameter."""
+ if isinstance(v, list):
+ adapt = cls._adapt_num_array
+ return '{%s}' % ','.join(adapt(v) for v in v)
+ if not v and v != 0:
+ return 'null'
+ return str(v)
+
+ _adapt_int_array = _adapt_float_array = _adapt_money_array = \
+ _adapt_num_array
+
+ def _adapt_bytea_array(self, v):
+ """Adapt a bytea array parameter."""
+ if isinstance(v, list):
+ return b'{' + b','.join(
+ self._adapt_bytea_array(v) for v in v) + b'}'
+ if v is None:
+ return b'null'
+ # double the backslashes in the escaped bytea value
+ return self.escape_bytea(v).replace(b'\\', b'\\\\')
+
+ def _adapt_json_array(self, v):
+ """Adapt a json array parameter."""
+ if isinstance(v, list):
+ adapt = self._adapt_json_array
+ return '{%s}' % ','.join(adapt(v) for v in v)
+ if not v:
+ return 'null'
+ if not isinstance(v, basestring):
+ v = self.encode_json(v)
+ if self._re_array_quote.search(v):
+ v = '"%s"' % self._re_array_escape.sub(r'\\\1', v)
+ return v
+
+ def _adapt_record(self, v, typ):
+ """Adapt a record parameter with given type."""
+ # types of the record components, taken from the attnames of typ
+ typ = typ.attnames.values()
+ if len(typ) != len(v):
+ raise TypeError('Record parameter %s has wrong size' % v)
+ return '(%s)' % ','.join(getattr(self,
+ '_adapt_record_%s' % t.simple)(v) for v, t in zip(v, typ))
+
+ @classmethod
+ def _adapt_record_text(cls, v):
+ """Adapt a text type record component."""
+ if v is None:
+ return ''
+ if not v:
+ return '""'
+ v = str(v)
+ if cls._re_record_quote.search(v):
+ v = '"%s"' % cls._re_record_escape.sub(r'\\\1', v)
+ return v
+
+ _adapt_record_date = _adapt_record_text
+
+ @classmethod
+ def _adapt_record_bool(cls, v):
+ """Adapt a boolean record component."""
+ if v is None:
+ return ''
+ if isinstance(v, basestring):
+ if not v:
+ return ''
+ v = v.lower() in cls._bool_true_values
+ return 't' if v else 'f'
+
+ @staticmethod
+ def _adapt_record_num(v):
+ """Adapt a numeric record component."""
+ if not v and v != 0:
+ return ''
+ return str(v)
+
+ _adapt_record_int = _adapt_record_float = _adapt_record_money = \
+ _adapt_record_num
+
+ def _adapt_record_bytea(self, v):
+ """Adapt a bytea record component."""
+ if v is None:
+ return ''
+ v = self.escape_bytea(v)
+ # where bytes and str are different types, decode the escaped
+ # value so that the str replace below can be applied to it
+ if bytes is not str and isinstance(v, bytes):
+ v = v.decode('ascii')
+ return v.replace('\\', '\\\\')
+
+ def _adapt_record_json(self, v):
+ """Adapt a json record component."""
+ if not v:
+ return ''
+ if not isinstance(v, basestring):
+ v = self.encode_json(v)
+ if self._re_array_quote.search(v):
+ v = '"%s"' % self._re_array_escape.sub(r'\\\1', v)
+ return v
+
+ def _adapt_param(self, value, typ, params):
+ """Adapt and add a parameter to the list."""
+ # literal SQL is returned as is and never added to params
+ if isinstance(value, _Literal):
+ return value
+ if value is not None:
+ simple = typ.simple
+ if simple == 'text':
+ pass
+ elif simple == 'record':
+ if isinstance(value, tuple):
+ value = self._adapt_record(value, typ)
+ elif simple.endswith('[]'):
+ if isinstance(value, list):
+ adapt = getattr(self, '_adapt_%s_array' % simple[:-2])
+ value = adapt(value)
+ else:
+ adapt = getattr(self, '_adapt_%s' % simple)
+ value = adapt(value)
+ # an adapter such as _adapt_date() may have returned literal SQL
+ if isinstance(value, _Literal):
+ return value
+ params.append(value)
+ # refer to the value by its 1-based position in params
+ return '$%d' % len(params)
+
+class _CastRecord:
+ """Class providing methods for casting records and record elements.
+
+ This is needed when getting result values from one of the higher level DB
+ methods, since the lower level query method only casts the other types.
+ """
+
+ @staticmethod
+ def cast_bool(v):
+ """Cast a boolean record component.
+
+ Returns v unchanged if get_bool() is false, otherwise whether
+ v starts with the letter 't'.
+ """
+ if not get_bool():
+ return v
+ return v[0] == 't'
+
+ @staticmethod
+ def cast_bytea(v):
+ """Cast a bytea record component using unescape_bytea()."""
+ return unescape_bytea(v)
+
+ @staticmethod
+ def cast_float(v):
+ """Cast a float record component."""
+ return float(v)
+
+ @staticmethod
+ def cast_int(v):
+ """Cast an integer record component."""
+ return int(v)
+
+ @staticmethod
+ def cast_json(v):
+ """Cast a json record component.
+
+ Uses the function returned by get_jsondecode(); if there is none,
+ v is returned unchanged.
+ """
+ cast = get_jsondecode()
+ if not cast:
+ return v
+ return cast(v)
+
+ @staticmethod
+ def cast_num(v):
+ """Cast a numeric record component.
+
+ Uses the type returned by get_decimal(), or float if that is not set.
+ """
+ return (get_decimal() or float)(v)
+
+ @staticmethod
+ def cast_money(v):
+ """Cast a money record component.
+
+ Returns v unchanged if get_decimal_point() gives nothing; otherwise
+ the value is cleaned up and converted like in cast_num().
+ """
+ point = get_decimal_point()
+ if not point:
+ return v
+ # normalize the decimal point to '.'
+ if point != '.':
+ v = v.replace(point, '.')
+ # an opening parenthesis is turned into a minus sign
+ v = v.replace('(', '-')
+ # keep only digits, the decimal point and the minus sign
+ v = ''.join(c for c in v if c.isdigit() or c in '.-')
+ return (get_decimal() or float)(v)
+
+ @classmethod
+ def cast(cls, v, typ):
+ """Cast the record v to the named tuple belonging to typ."""
+ types = typ.attnames.values()
+ # one cast function per component; None if this class has no
+ # cast_<simple> method for the component type (presumably
+ # cast_record() leaves such components uncast -- to be confirmed)
+ cast = [getattr(cls, 'cast_%s' % t.simple, None) for t in types]
+ v = cast_record(v, cast)
+ return typ.namedtuple(*v)
+
+
+class _PgType(str):
+ """Class augmenting the simple type name with additional info."""
+
+ _num_types = frozenset('int float num money'
+ ' int2 int4 int8 float4 float8 numeric money'.split())
+
+ @classmethod
+ def create(cls, db, pgtype, regtype, typrelid):
+ """Create a PostgreSQL type name with additional info."""
+ # types with a typrelid get the simple name 'record', all others
+ # are looked up in _simpletype
+ simple = 'record' if typrelid else _simpletype[pgtype]
+ # the string value itself is the regtype name or the simple name,
+ # depending on the _regtypes setting of the connection
+ self = cls(regtype if db._regtypes else simple)
+ self.db = db
+ self.simple = simple
+ self.pgtype = pgtype
+ self.regtype = regtype
+ self.typrelid = typrelid
+ # filled on first access by the attnames and namedtuple properties
+ self._attnames = self._namedtuple = None
+ return self
+
+ @property
+ def attnames(self):
+ """Get names and types of the fields of a composite type."""
+ if not self.typrelid:
+ return None
+ # ask the connection on first use, then reuse the stored result
+ if not self._attnames:
+ self._attnames = self.db.get_attnames(self.typrelid)
+ return self._attnames
+
+ @property
+ def namedtuple(self):
+ """Return named tuple class representing a composite type."""
+ if not self._namedtuple:
+ self._namedtuple = namedtuple(self, self.attnames)
+ return self._namedtuple
+
+ def cast(self, value):
+ """Cast a result value of this type.
+
+ Only values that are not None and belong to a type with a typrelid
+ are cast (as records); everything else is returned unchanged.
+ """
+ if value is not None and self.typrelid:
+ value = _CastRecord.cast(value, self)
+ return value
+
+
class _Literal(str):
"""Wrapper class for literal SQL."""
@@ -349,7 +654,7 @@
# The actual PostGreSQL database connection interface:
-class DB(object):
+class DB(_Adapt):
"""Wrapper class for the _pg connection type."""
def __init__(self, *args, **kw):
@@ -451,146 +756,12 @@
"""Get boolean value corresponding to d."""
return bool(d) if get_bool() else ('t' if d else 'f')
- _bool_true_values = frozenset('t true 1 y yes on'.split())
-
- def _prepare_bool(self, d):
- """Prepare a boolean parameter."""
- if isinstance(d, basestring):
- if not d:
- return None
- d = d.lower() in self._bool_true_values
- return 't' if d else 'f'
-
- _date_literals = frozenset('current_date current_time'
- ' current_timestamp localtime localtimestamp'.split())
-
- def _prepare_date(self, d):
- """Prepare a date parameter."""
- if not d:
- return None
- if isinstance(d, basestring) and d.lower() in self._date_literals:
- return _Literal(d)
- return d
-
- _num_types = frozenset('int float num money'
- ' int2 int4 int8 float4 float8 numeric money'.split())
-
- @staticmethod
- def _prepare_num(d):
- """Prepare a numeric parameter."""
- if not d and d != 0:
- return None
- return d
-
- _prepare_int = _prepare_float = _prepare_money = _prepare_num
-
- def _prepare_bytea(self, d):
- """Prepare a bytea parameter."""
- return self.escape_bytea(d)
-
- def _prepare_json(self, d):
- """Prepare a json parameter."""
- if not d:
- return None
- if isinstance(d, basestring):
- return d
- return self.encode_json(d)
-
- _re_array_escape = regex(r'(["\\])')
- _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
-
- def _prepare_bool_array(self, d):
- """Prepare a bool array parameter."""
- if isinstance(d, list):
- return '{%s}' % ','.join(self._prepare_bool_array(v) for v in d)
- if d is None:
- return 'null'
- if isinstance(d, basestring):
- if not d:
- return 'null'
- d = d.lower() in self._bool_true_values
- return 't' if d else 'f'
-
- def _prepare_num_array(self, d):
- """Prepare a numeric array parameter."""
- if isinstance(d, list):
- return '{%s}' % ','.join(self._prepare_num_array(v) for v in d)
- if not d and d != 0:
- return 'null'
- return str(d)
-
- _prepare_int_array = _prepare_float_array = _prepare_money_array = \
- _prepare_num_array
-
- def _prepare_text_array(self, d):
- """Prepare a text array parameter."""
- if isinstance(d, list):
- return '{%s}' % ','.join(self._prepare_text_array(v) for v in d)
- if d is None:
- return 'null'
- if not d:
- return '""'
- d = str(d)
- if self._re_array_quote.search(d):
- d = '"%s"' % self._re_array_escape.sub(r'\\\1', d)
- return d
-
- def _prepare_bytea_array(self, d):
- """Prepare a bytea array parameter."""
- if isinstance(d, list):
- return b'{' + b','.join(
- self._prepare_bytea_array(v) for v in d) + b'}'
- if d is None:
- return b'null'
- return self.escape_bytea(d).replace(b'\\', b'\\\\')
-
- def _prepare_json_array(self, d):
- """Prepare a json array parameter."""
- if isinstance(d, list):
- return '{%s}' % ','.join(self._prepare_json_array(v) for v in d)
- if not d:
- return 'null'
- if not isinstance(d, basestring):
- d = self.encode_json(d)
- if self._re_array_quote.search(d):
- d = '"%s"' % self._re_array_escape.sub(r'\\\1', d)
- return d
-
- def _prepare_param(self, value, typ, params):
- """Prepare and add a parameter to the list."""
- if isinstance(value, _Literal):
- return value
- if value is not None and typ != 'text':
- if typ.endswith('[]'):
- if isinstance(value, list):
- prepare = getattr(self, '_prepare_%s_array' % typ[:-2])
- value = prepare(value)
- elif isinstance(value, basestring):
- value = value.strip()
- if not value.startswith('{') or not value.endswith('}'):
- if value[:5].lower() == 'array':
- value = value[5:].lstrip()
- if value.startswith('[') and value.endswith(']'):
- value = _Literal('ARRAY%s' % value)
- else:
- raise ValueError(
- 'Invalid array expression: %s' % value)
- else:
- raise ValueError('Invalid array parameter: %s' % value)
- else:
- prepare = getattr(self, '_prepare_%s' % typ)
- value = prepare(value)
- if isinstance(value, _Literal):
- return value
- params.append(value)
- return '$%d' % len(params)
-
def _list_params(self, params):
"""Create a human readable parameter list."""
return ', '.join('$%d=%r' % (n, v) for n, v in enumerate(params, 1))
@staticmethod
- def _prepare_qualified_param(name, param):
+ def _adapt_qualified_param(name, param):
"""Quote parameter representing a qualified name.
Escapes the name for use as an SQL parameter, unless the
@@ -601,7 +772,7 @@
"""
if isinstance(param, int):
param = "$%d" % param
- if '.' not in name:
+ if isinstance(name, basestring) and '.' not in name:
param = 'quote_ident(%s)' % (param,)
return param
@@ -869,7 +1040,7 @@
" AND NOT a.attisdropped"
" WHERE i.indrelid=%s::regclass"
" AND i.indisprimary ORDER BY a.attnum") % (
- self._prepare_qualified_param(table, 1),)
+ self._adapt_qualified_param(table, 1),)
pkey = self.db.query(q, (table,)).getresult()
if not pkey:
raise KeyError('Table %s has no primary key' % table)
@@ -933,17 +1104,16 @@
try: # cache lookup
names = attnames[table]
except KeyError: # cache miss, check the database
- q = ("SELECT a.attname, t.typname%s"
+ q = ("SELECT a.attname, t.typname, t.typname::regtype, t.typrelid"
" FROM pg_attribute a"
" JOIN pg_type t ON t.oid = a.atttypid"
" WHERE a.attrelid = %s::regclass"
" AND (a.attnum > 0 OR a.attname = 'oid')"
" AND NOT a.attisdropped ORDER BY a.attnum") % (
- '::regtype' if self._regtypes else '',
- self._prepare_qualified_param(table, 1))
+ self._adapt_qualified_param(table, 1))
names = self.db.query(q, (table,)).getresult()
- if not self._regtypes:
- names = ((name, _simpletype[typ]) for name, typ in names)
+ names = ((name, _PgType.create(self, pgtype, regtype, typrelid))
+ for name, pgtype, regtype, typrelid in names)
names = AttrDict(names)
attnames[table] = names # cache it
return names
@@ -966,7 +1136,7 @@
return self._privileges[(table, privilege)]
except KeyError: # cache miss, ask the database
q = "SELECT has_table_privilege(%s, $2)" % (
- self._prepare_qualified_param(table, 1),)
+ self._adapt_qualified_param(table, 1),)
q = self.db.query(q, (table, privilege))
ret = q.getresult()[0][0] == self._make_bool(True)
self._privileges[(table, privilege)] = ret # cache it
@@ -1024,7 +1194,7 @@
'Differing number of items in keyname and row')
row = dict(zip(keyname, row))
params = []
- param = partial(self._prepare_param, params=params)
+ param = partial(self._adapt_param, params=params)
col = self.escape_identifier
what = 'oid, *' if qoid else '*'
where = ' AND '.join('%s = %s' % (
@@ -1044,6 +1214,8 @@
for n, value in res[0].items():
if qoid and n == 'oid':
n = qoid
+ else:
+ value = attnames[n].cast(value)
row[n] = value
return row
@@ -1070,7 +1242,7 @@
attnames = self.get_attnames(table)
qoid = _oid_key(table) if 'oid' in attnames else None
params = []
- param = partial(self._prepare_param, params=params)
+ param = partial(self._adapt_param, params=params)
col = self.escape_identifier
names, values = [], []
for n in attnames:
@@ -1090,6 +1262,8 @@
for n, value in res[0].items():
if qoid and n == 'oid':
n = qoid
+ else:
+ value = attnames[n].cast(value)
row[n] = value
return row
@@ -1131,7 +1305,7 @@
else:
raise KeyError('Missing primary key in row')
params = []
- param = partial(self._prepare_param, params=params)
+ param = partial(self._adapt_param, params=params)
col = self.escape_identifier
where = ' AND '.join('%s = %s' % (
col(k), param(row[k], attnames[k])) for k in keyname)
@@ -1157,6 +1331,8 @@
for n, value in res[0].items():
if qoid and n == 'oid':
n = qoid
+ else:
+ value = attnames[n].cast(value)
row[n] = value
return row
@@ -1214,7 +1390,7 @@
attnames = self.get_attnames(table)
qoid = _oid_key(table) if 'oid' in attnames else None
params = []
- param = partial(self._prepare_param,params=params)
+ param = partial(self._adapt_param,params=params)
col = self.escape_identifier
names, values, updates = [], [], []
for n in attnames:
@@ -1258,6 +1434,8 @@
for n, value in res[0].items():
if qoid and n == 'oid':
n = qoid
+ else:
+ value = attnames[n].cast(value)
row[n] = value
else:
self.get(table, row)
@@ -1278,7 +1456,8 @@
for n, t in attnames.items():
if n == 'oid':
continue
- if t in self._num_types:
+ t = t.simple
+ if t in _PgType._num_types:
row[n] = 0
elif t == 'bool':
row[n] = self._make_bool(False)
@@ -1327,7 +1506,7 @@
else:
raise KeyError('Missing primary key in row')
params = []
- param = partial(self._prepare_param, params=params)
+ param = partial(self._adapt_param, params=params)
col = self.escape_identifier
where = ' AND '.join('%s = %s' % (
col(k), param(row[k], attnames[k])) for k in keyname)
Modified: trunk/tests/test_classic_dbwrapper.py
==============================================================================
--- trunk/tests/test_classic_dbwrapper.py Thu Jan 28 14:54:34 2016
(r792)
+++ trunk/tests/test_classic_dbwrapper.py Thu Jan 28 15:07:41 2016
(r793)
@@ -379,6 +379,8 @@
cls_set_up = False
+ regtypes = None
+
@classmethod
def setUpClass(cls):
db = DB()
@@ -401,6 +403,10 @@
def setUp(self):
self.assertTrue(self.cls_set_up)
self.db = DB()
+ if self.regtypes is None:
+ self.regtypes = self.db.use_regtypes()
+ else:
+ self.db.use_regtypes(self.regtypes)
query = self.db.query
query('set client_encoding=utf8')
query('set standard_conforming_strings=on')
@@ -985,18 +991,29 @@
self.db.get_attnames, 'has.too.many.dots')
r = get_attnames('test')
self.assertIsInstance(r, dict)
- self.assertEqual(r, dict(
- i2='int', i4='int', i8='int', d='num',
- f4='float', f8='float', m='money',
- v4='text', c4='text', t='text'))
+ if self.regtypes:
+ self.assertEqual(r, dict(
+ i2='smallint', i4='integer', i8='bigint', d='numeric',
+ f4='real', f8='double precision', m='money',
+ v4='character varying', c4='character', t='text'))
+ else:
+ self.assertEqual(r, dict(
+ i2='int', i4='int', i8='int', d='num',
+ f4='float', f8='float', m='money',
+ v4='text', c4='text', t='text'))
self.createTable('test_table',
'n int, alpha smallint, beta bool,'
' gamma char(5), tau text, v varchar(3)')
r = get_attnames('test_table')
self.assertIsInstance(r, dict)
- self.assertEqual(r, dict(
- n='int', alpha='int', beta='bool',
- gamma='text', tau='text', v='text'))
+ if self.regtypes:
+ self.assertEqual(r, dict(
+ n='integer', alpha='smallint', beta='boolean',
+ gamma='character', tau='text', v='character varying'))
+ else:
+ self.assertEqual(r, dict(
+ n='int', alpha='int', beta='bool',
+ gamma='text', tau='text', v='text'))
def testGetAttnamesWithQuotes(self):
get_attnames = self.db.get_attnames
@@ -1005,23 +1022,37 @@
'"Prime!" smallint, "much space" integer, "Questions?" text')
r = get_attnames(table)
self.assertIsInstance(r, dict)
- self.assertEqual(r, {
- 'Prime!': 'int', 'much space': 'int', 'Questions?': 'text'})
+ if self.regtypes:
+ self.assertEqual(r, {
+ 'Prime!': 'smallint', 'much space': 'integer',
+ 'Questions?': 'text'})
+ else:
+ self.assertEqual(r, {
+ 'Prime!': 'int', 'much space': 'int', 'Questions?': 'text'})
table = 'yet another test table for get_attnames()'
self.createTable(table,
'a smallint, b integer, c bigint,'
- ' e numeric, f float, f2 double precision, m money,'
+ ' e numeric, f real, f2 double precision, m money,'
' x smallint, y smallint, z smallint,'
' Normal_NaMe smallint, "Special Name" smallint,'
' t text, u char(2), v varchar(2),'
' primary key (y, u)', oids=True)
r = get_attnames(table)
self.assertIsInstance(r, dict)
- self.assertEqual(r, {'a': 'int', 'c': 'int', 'b': 'int',
- 'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
- 'normal_name': 'int', 'Special Name': 'int',
- 'u': 'text', 't': 'text', 'v': 'text',
- 'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'})
+ if self.regtypes:
+ self.assertEqual(r, {
+ 'a': 'smallint', 'b': 'integer', 'c': 'bigint',
+ 'e': 'numeric', 'f': 'real', 'f2': 'double precision',
+ 'm': 'money', 'normal_name': 'smallint',
+ 'Special Name': 'smallint', 'u': 'character',
+ 't': 'text', 'v': 'character varying', 'y': 'smallint',
+ 'x': 'smallint', 'z': 'smallint', 'oid': 'oid'})
+ else:
+ self.assertEqual(r, {'a': 'int', 'b': 'int', 'c': 'int',
+ 'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
+ 'normal_name': 'int', 'Special Name': 'int',
+ 'u': 'text', 't': 'text', 'v': 'text',
+ 'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'})
def testGetAttnamesWithRegtypes(self):
get_attnames = self.db.get_attnames
@@ -1030,7 +1061,7 @@
' gamma char(5), tau text, v varchar(3)')
use_regtypes = self.db.use_regtypes
regtypes = use_regtypes()
- self.assertFalse(regtypes)
+ self.assertEqual(regtypes, self.regtypes)
use_regtypes(True)
try:
r = get_attnames("test_table")
@@ -1041,27 +1072,47 @@
n='integer', alpha='smallint', beta='boolean',
gamma='character', tau='text', v='character varying'))
+ def testGetAttnamesWithoutRegtypes(self):
+ """Check that get_attnames() gives simple type names without regtypes."""
+ get_attnames = self.db.get_attnames
+ self.createTable('test_table',
+ ' n int, alpha smallint, beta bool,'
+ ' gamma char(5), tau text, v varchar(3)')
+ use_regtypes = self.db.use_regtypes
+ regtypes = use_regtypes()
+ self.assertEqual(regtypes, self.regtypes)
+ # switch regular type names off for the query and restore the
+ # previous setting afterwards, even if the query fails
+ use_regtypes(False)
+ try:
+ r = get_attnames("test_table")
+ self.assertIsInstance(r, dict)
+ finally:
+ use_regtypes(regtypes)
+ self.assertEqual(r, dict(
+ n='int', alpha='int', beta='bool',
+ gamma='text', tau='text', v='text'))
+
def testGetAttnamesIsCached(self):
get_attnames = self.db.get_attnames
+ int_type = 'integer' if self.regtypes else 'int'
+ text_type = 'text'
query = self.db.query
self.createTable('test_table', 'col int')
r = get_attnames("test_table")
self.assertIsInstance(r, dict)
- self.assertEqual(r, dict(col='int'))
+ self.assertEqual(r, dict(col=int_type))
query("alter table test_table alter column col type text")
query("alter table test_table add column col2 int")
r = get_attnames("test_table")
- self.assertEqual(r, dict(col='int'))
+ self.assertEqual(r, dict(col=int_type))
r = get_attnames("test_table", flush=True)
- self.assertEqual(r, dict(col='text', col2='int'))
+ self.assertEqual(r, dict(col=text_type, col2=int_type))
query("alter table test_table drop column col2")
r = get_attnames("test_table")
- self.assertEqual(r, dict(col='text', col2='int'))
+ self.assertEqual(r, dict(col=text_type, col2=int_type))
r = get_attnames("test_table", flush=True)
- self.assertEqual(r, dict(col='text'))
+ self.assertEqual(r, dict(col=text_type))
query("alter table test_table drop column col")
r = get_attnames("test_table")
- self.assertEqual(r, dict(col='text'))
+ self.assertEqual(r, dict(col=text_type))
r = get_attnames("test_table", flush=True)
self.assertEqual(r, dict())
@@ -1069,10 +1120,17 @@
get_attnames = self.db.get_attnames
r = get_attnames('test', flush=True)
self.assertIsInstance(r, OrderedDict)
- self.assertEqual(r, OrderedDict([
- ('i2', 'int'), ('i4', 'int'), ('i8', 'int'),
- ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'),
- ('v4', 'text'), ('c4', 'text'), ('t', 'text')]))
+ if self.regtypes:
+ self.assertEqual(r, OrderedDict([
+ ('i2', 'smallint'), ('i4', 'integer'), ('i8', 'bigint'),
+ ('d', 'numeric'), ('f4', 'real'), ('f8', 'double precision'),
+ ('m', 'money'), ('v4', 'character varying'),
+ ('c4', 'character'), ('t', 'text')]))
+ else:
+ self.assertEqual(r, OrderedDict([
+ ('i2', 'int'), ('i4', 'int'), ('i8', 'int'),
+ ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'),
+ ('v4', 'text'), ('c4', 'text'), ('t', 'text')]))
if OrderedDict is not dict:
r = ' '.join(list(r.keys()))
self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t')
@@ -1082,9 +1140,15 @@
' gamma char(5), tau text, beta bool')
r = get_attnames(table)
self.assertIsInstance(r, OrderedDict)
- self.assertEqual(r, OrderedDict([
- ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
- ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
+ if self.regtypes:
+ self.assertEqual(r, OrderedDict([
+ ('n', 'integer'), ('alpha', 'smallint'),
+ ('v', 'character varying'), ('gamma', 'character'),
+ ('tau', 'text'), ('beta', 'boolean')]))
+ else:
+ self.assertEqual(r, OrderedDict([
+ ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
+ ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
if OrderedDict is not dict:
r = ' '.join(list(r.keys()))
self.assertEqual(r, 'n alpha v gamma tau beta')
@@ -1096,10 +1160,17 @@
get_attnames = self.db.get_attnames
r = get_attnames('test', flush=True)
self.assertIsInstance(r, AttrDict)
- self.assertEqual(r, AttrDict([
- ('i2', 'int'), ('i4', 'int'), ('i8', 'int'),
- ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'),
- ('v4', 'text'), ('c4', 'text'), ('t', 'text')]))
+ if self.regtypes:
+ self.assertEqual(r, AttrDict([
+ ('i2', 'smallint'), ('i4', 'integer'), ('i8', 'bigint'),
+ ('d', 'numeric'), ('f4', 'real'), ('f8', 'double precision'),
+ ('m', 'money'), ('v4', 'character varying'),
+ ('c4', 'character'), ('t', 'text')]))
+ else:
+ self.assertEqual(r, AttrDict([
+ ('i2', 'int'), ('i4', 'int'), ('i8', 'int'),
+ ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'),
+ ('v4', 'text'), ('c4', 'text'), ('t', 'text')]))
r = ' '.join(list(r.keys()))
self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t')
table = 'test table for get_attnames'
@@ -1108,9 +1179,15 @@
' gamma char(5), tau text, beta bool')
r = get_attnames(table)
self.assertIsInstance(r, AttrDict)
- self.assertEqual(r, AttrDict([
- ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
- ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
+ if self.regtypes:
+ self.assertEqual(r, AttrDict([
+ ('n', 'integer'), ('alpha', 'smallint'),
+ ('v', 'character varying'), ('gamma', 'character'),
+ ('tau', 'text'), ('beta', 'boolean')]))
+ else:
+ self.assertEqual(r, AttrDict([
+ ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
+ ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
r = ' '.join(list(r.keys()))
self.assertEqual(r, 'n alpha v gamma tau beta')
@@ -2969,9 +3046,17 @@
' d numeric[], f4 real[], f8 double precision[], m money[],'
' b bool[], v4 varchar(4)[], c4 char(4)[], t text[]')
r = self.db.get_attnames('arraytest')
- self.assertEqual(r, dict(id='int', i2='int[]', i4='int[]', i8='int[]',
- d='num[]', f4='float[]', f8='float[]', m='money[]',
- b='bool[]', v4='text[]', c4='text[]', t='text[]'))
+ if self.regtypes:
+ self.assertEqual(r, dict(
+ id='smallint', i2='smallint[]', i4='integer[]', i8='bigint[]',
+ d='numeric[]', f4='real[]', f8='double precision[]',
+ m='money[]', b='boolean[]',
+ v4='character varying[]', c4='character[]', t='text[]'))
+ else:
+ self.assertEqual(r, dict(
+ id='int', i2='int[]', i4='int[]', i8='int[]',
+ d='num[]', f4='float[]', f8='float[]', m='money[]',
+ b='bool[]', v4='text[]', c4='text[]', t='text[]'))
decimal = pg.get_decimal()
if decimal is Decimal:
long_decimal = decimal('123456789.123456789')
@@ -3004,7 +3089,7 @@
r = self.db.query('select * from arraytest limit 1').dictresult()[0]
self.assertEqual(r, data)
- def testArrayInput(self):
+ def testArrayLiteral(self):
insert = self.db.insert
self.createTable('arraytest', 'i int[], t text[]', oids=True)
r = dict(i=[1, 2, 3], t=['a', 'b', 'c'])
@@ -3015,25 +3100,23 @@
self.db.insert('arraytest', r)
self.assertEqual(r['i'], [1, 2, 3])
self.assertEqual(r['t'], ['a', 'b', 'c'])
- r = dict(i="[1, 2, 3]", t="['a', 'b', 'c']")
- self.db.insert('arraytest', r)
- self.assertEqual(r['i'], [1, 2, 3])
- self.assertEqual(r['t'], ['a', 'b', 'c'])
- r = dict(i="array[1, 2, 3]", t="array['a', 'b', 'c']")
- self.db.insert('arraytest', r)
- self.assertEqual(r['i'], [1, 2, 3])
- self.assertEqual(r['t'], ['a', 'b', 'c'])
- r = dict(i="ARRAY[1, 2, 3]", t="ARRAY['a', 'b', 'c']")
+ L = pg._Literal
+ r = dict(i=L("ARRAY[1, 2, 3]"), t=L("ARRAY['a', 'b', 'c']"))
self.db.insert('arraytest', r)
self.assertEqual(r['i'], [1, 2, 3])
self.assertEqual(r['t'], ['a', 'b', 'c'])
r = dict(i="1, 2, 3", t="'a', 'b', 'c'")
- self.assertRaises(ValueError, self.db.insert, 'arraytest', r)
+ self.assertRaises(pg.ProgrammingError, self.db.insert, 'arraytest', r)
def testArrayOfIds(self):
self.createTable('arraytest', 'c cid[], o oid[], x xid[]', oids=True)
r = self.db.get_attnames('arraytest')
- self.assertEqual(r, dict(oid='int', c='int[]', o='int[]', x='int[]'))
+ if self.regtypes:
+ self.assertEqual(r, dict(
+ oid='oid', c='cid[]', o='oid[]', x='xid[]'))
+ else:
+ self.assertEqual(r, dict(
+ oid='int', c='int[]', o='int[]', x='int[]'))
data = dict(c=[11, 12, 13], o=[21, 22, 23], x=[31, 32, 33])
r = data.copy()
self.db.insert('arraytest', r)
@@ -3125,7 +3208,7 @@
self.skipTest('database does not support jsonb')
self.fail(str(error))
r = self.db.get_attnames('arraytest')
- self.assertEqual(r['data'], 'json[]')
+ self.assertEqual(r['data'], 'jsonb[]' if self.regtypes else 'json[]')
data = [dict(id=815, name='John Doe'), dict(id=816, name='Jane Roe')]
jsondecode = pg.get_jsondecode()
r = dict(data=data)
@@ -3166,6 +3249,170 @@
self.db.get('arraytest', r)
self.assertEqual(r['data'], data)
+ def testInsertUpdateGetRecord(self):
+ query = self.db.query
+ query('create type test_person_type as'
+ ' (name varchar, age smallint, married bool,'
+ ' weight real, salary money)')
+ self.addCleanup(query, 'drop type test_person_type')
+ self.createTable('test_person', 'person test_person_type',
+ temporary=False, oids=True)
+ attnames = self.db.get_attnames('test_person')
+ self.assertEqual(len(attnames), 2)
+ self.assertIn('oid', attnames)
+ self.assertIn('person', attnames)
+ person_typ = attnames['person']
+ if self.regtypes:
+ self.assertEqual(person_typ, 'test_person_type')
+ else:
+ self.assertEqual(person_typ, 'record')
+ if self.regtypes:
+ self.assertEqual(person_typ.attnames,
+ dict(name='character varying', age='smallint',
+ married='boolean', weight='real', salary='money'))
+ else:
+ self.assertEqual(person_typ.attnames,
+ dict(name='text', age='int', married='bool',
+ weight='float', salary='money'))
+ decimal = pg.get_decimal()
+ if pg.get_bool():
+ bool_class = bool
+ t, f = True, False
+ else:
+ bool_class = str
+ t, f = 't', 'f'
+ person = ('John Doe', 61, t, 99.5, decimal('93456.75'))
+ r = self.db.insert('test_person', None, person=person)
+ p = r['person']
+ self.assertIsInstance(p, tuple)
+ self.assertEqual(p, person)
+ self.assertEqual(p.name, 'John Doe')
+ self.assertIsInstance(p.name, str)
+ self.assertIsInstance(p.age, int)
+ self.assertIsInstance(p.married, bool_class)
+ self.assertIsInstance(p.weight, float)
+ self.assertIsInstance(p.salary, decimal)
+ person = ('Jane Roe', 59, f, 64.5, decimal('96543.25'))
+ r['person'] = person
+ self.db.update('test_person', r)
+ p = r['person']
+ self.assertIsInstance(p, tuple)
+ self.assertEqual(p, person)
+ self.assertEqual(p.name, 'Jane Roe')
+ self.assertIsInstance(p.name, str)
+ self.assertIsInstance(p.age, int)
+ self.assertIsInstance(p.married, bool_class)
+ self.assertIsInstance(p.weight, float)
+ self.assertIsInstance(p.salary, decimal)
+ r['person'] = None
+ self.db.get('test_person', r)
+ p = r['person']
+ self.assertIsInstance(p, tuple)
+ self.assertEqual(p, person)
+ self.assertEqual(p.name, 'Jane Roe')
+ self.assertIsInstance(p.name, str)
+ self.assertIsInstance(p.age, int)
+ self.assertIsInstance(p.married, bool_class)
+ self.assertIsInstance(p.weight, float)
+ self.assertIsInstance(p.salary, decimal)
+ person = (None,) * 5
+ r = self.db.insert('test_person', None, person=person)
+ p = r['person']
+ self.assertIsInstance(p, tuple)
+ self.assertIsNone(p.name)
+ self.assertIsNone(p.age)
+ self.assertIsNone(p.married)
+ self.assertIsNone(p.weight)
+ self.assertIsNone(p.salary)
+ r['person'] = None
+ self.db.get('test_person', r)
+ p = r['person']
+ self.assertIsInstance(p, tuple)
+ self.assertIsNone(p.name)
+ self.assertIsNone(p.age)
+ self.assertIsNone(p.married)
+ self.assertIsNone(p.weight)
+ self.assertIsNone(p.salary)
+ r = self.db.insert('test_person', None, person=None)
+ self.assertIsNone(r['person'])
+ r['person'] = None
+ self.db.get('test_person', r)
+ self.assertIsNone(r['person'])
+
+ def testRecordInsertBytea(self):
+ query = self.db.query
+ query('create type test_person_type as'
+ ' (name text, picture bytea)')
+ self.addCleanup(query, 'drop type test_person_type')
+ self.createTable('test_person', 'person test_person_type',
+ temporary=False, oids=True)
+ person_typ = self.db.get_attnames('test_person')['person']
+ self.assertEqual(person_typ.attnames,
+ dict(name='text', picture='bytea'))
+ person = ('John Doe', b'O\x00ps\xff!')
+ r = self.db.insert('test_person', None, person=person)
+ p = r['person']
+ self.assertIsInstance(p, tuple)
+ self.assertEqual(p, person)
+ self.assertEqual(p.name, 'John Doe')
+ self.assertIsInstance(p.name, str)
+ self.assertEqual(p.picture, person[1])
+ self.assertIsInstance(p.picture, bytes)
+
+ def testRecordInsertJson(self):
+ query = self.db.query
+ try:
+ query('create type test_person_type as'
+ ' (name text, data json)')
+ except pg.ProgrammingError as error:
+ if self.db.server_version < 90200:
+ self.skipTest('database does not support json')
+ self.fail(str(error))
+ self.addCleanup(query, 'drop type test_person_type')
+ self.createTable('test_person', 'person test_person_type',
+ temporary=False, oids=True)
+ person_typ = self.db.get_attnames('test_person')['person']
+ self.assertEqual(person_typ.attnames,
+ dict(name='text', data='json'))
+ person = ('John Doe', dict(age=61, married=True, weight=99.5))
+ r = self.db.insert('test_person', None, person=person)
+ p = r['person']
+ self.assertIsInstance(p, tuple)
+ if pg.get_jsondecode() is None:
+ p = p._replace(data=json.loads(p.data))
+ self.assertEqual(p, person)
+ self.assertEqual(p.name, 'John Doe')
+ self.assertIsInstance(p.name, str)
+ self.assertEqual(p.data, person[1])
+ self.assertIsInstance(p.data, dict)
+
+ def testRecordLiteral(self):
+ query = self.db.query
+ query('create type test_person_type as'
+ ' (name varchar, age smallint)')
+ self.addCleanup(query, 'drop type test_person_type')
+ self.createTable('test_person', 'person test_person_type',
+ temporary=False, oids=True)
+ person_typ = self.db.get_attnames('test_person')['person']
+ if self.regtypes:
+ self.assertEqual(person_typ, 'test_person_type')
+ else:
+ self.assertEqual(person_typ, 'record')
+ if self.regtypes:
+ self.assertEqual(person_typ.attnames,
+ dict(name='character varying', age='smallint'))
+ else:
+ self.assertEqual(person_typ.attnames,
+ dict(name='text', age='int'))
+ person = pg._Literal("('John Doe', 61)")
+ r = self.db.insert('test_person', None, person=person)
+ p = r['person']
+ self.assertIsInstance(p, tuple)
+ self.assertEqual(p.name, 'John Doe')
+ self.assertIsInstance(p.name, str)
+ self.assertEqual(p.age, 61)
+ self.assertIsInstance(p.age, int)
+
def testNotificationHandler(self):
# the notification handler itself is tested separately
f = self.db.notification_handler
@@ -3255,6 +3502,7 @@
cls.set_option('bool', not_bool)
cls.set_option('namedresult', None)
cls.set_option('jsondecode', None)
+ cls.regtypes = not DB().use_regtypes()
super(TestDBClassNonStdOpts, cls).setUpClass()
@classmethod
Modified: trunk/tests/test_classic_functions.py
==============================================================================
--- trunk/tests/test_classic_functions.py Thu Jan 28 14:54:34 2016 (r792)
+++ trunk/tests/test_classic_functions.py Thu Jan 28 15:07:41 2016 (r793)
@@ -216,14 +216,14 @@
self.assertRaises(TypeError, f)
self.assertRaises(TypeError, f, None)
self.assertRaises(TypeError, f, '{}', 1)
- self.assertRaises(TypeError, f, '{}', ',',)
+ self.assertRaises(TypeError, f, '{}', b',',)
self.assertRaises(TypeError, f, '{}', None, None)
self.assertRaises(TypeError, f, '{}', None, 1)
- self.assertRaises(TypeError, f, '{}', None, '')
- self.assertRaises(ValueError, f, '{}', None, '\\')
- self.assertRaises(ValueError, f, '{}', None, '{')
- self.assertRaises(ValueError, f, '{}', None, '}')
- self.assertRaises(TypeError, f, '{}', None, ',;')
+ self.assertRaises(TypeError, f, '{}', None, b'')
+ self.assertRaises(ValueError, f, '{}', None, b'\\')
+ self.assertRaises(ValueError, f, '{}', None, b'{')
+ self.assertRaises(ValueError, f, '{}', None, b'}')
+ self.assertRaises(TypeError, f, '{}', None, b',;')
self.assertEqual(f('{}'), [])
self.assertEqual(f('{}', None), [])
self.assertEqual(f('{}', None, b';'), [])
@@ -488,14 +488,14 @@
self.assertRaises(TypeError, f)
self.assertRaises(TypeError, f, None)
self.assertRaises(TypeError, f, '()', 1)
- self.assertRaises(TypeError, f, '()', ',',)
+ self.assertRaises(TypeError, f, '()', b',',)
self.assertRaises(TypeError, f, '()', None, None)
self.assertRaises(TypeError, f, '()', None, 1)
- self.assertRaises(TypeError, f, '()', None, '')
- self.assertRaises(ValueError, f, '()', None, '\\')
- self.assertRaises(ValueError, f, '()', None, '(')
- self.assertRaises(ValueError, f, '()', None, ')')
- self.assertRaises(TypeError, f, '{}', None, ',;')
+ self.assertRaises(TypeError, f, '()', None, b'')
+ self.assertRaises(ValueError, f, '()', None, b'\\')
+ self.assertRaises(ValueError, f, '()', None, b'(')
+ self.assertRaises(ValueError, f, '()', None, b')')
+ self.assertRaises(TypeError, f, '{}', None, b',;')
self.assertEqual(f('()'), (None,))
self.assertEqual(f('()', None), (None,))
self.assertEqual(f('()', None, b';'), (None,))
_______________________________________________
PyGreSQL mailing list
[email protected]
https://mail.vex.net/mailman/listinfo.cgi/pygresql