Author: cito
Date: Mon Jan 18 18:21:44 2016
New Revision: 765
Log:
Improve support for access by primary key
Composite primary keys are now returned as tuples instead of frozensets,
where the ordering of the tuple reflects the primary key index.
Primary keys now take precedence if both OID and primary key are available
(in 4.x it was the other way around). Use of OIDs is thus slightly more
discouraged, though OIDs still work as before for tables that have OIDs
but no primary key.
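The reordering described above can be sketched as a small standalone function. This is a minimal sketch, assuming the catalog query returns rows of `(attname, attnum, indkey)` ordered by `attnum`, with `indkey` rendered as a space-separated string of column numbers (as the new `pkey()` query in `pg.py` selects). The helper name `order_pkey` is hypothetical and not part of the `pg` module API.

```python
def order_pkey(rows, composite=False):
    """Return a primary key name, or a tuple of names in index order.

    rows: (attname, attnum, indkey) tuples ordered by attnum
    (hypothetical input format modeled on the pkey() catalog query).
    """
    if not rows:
        raise KeyError('Table has no primary key')
    if len(rows) > 1:
        # indkey lists the column numbers in index order, e.g. '4 2'
        indkey = [int(k) for k in rows[0][2].split()]
        rows = sorted(rows, key=lambda row: indkey.index(row[1]))
        pkey = tuple(row[0] for row in rows)
    else:
        pkey = rows[0][0]  # single primary key as a plain string
    if composite and not isinstance(pkey, tuple):
        pkey = (pkey,)  # always return a tuple when requested
    return pkey


# Columns e, f, g, h, i (attnum 1..5) with "primary key (h, f)":
rows = [('f', 2, '4 2'), ('h', 4, '4 2')]
print(order_pkey(rows))  # ('h', 'f')
print(order_pkey([('b', 2, '2')]))  # b
print(order_pkey([('b', 2, '2')], composite=True))  # ('b',)
```

Sorting by the position in `indkey` rather than by `attnum` is what makes `primary key (h, f)` come back as `('h', 'f')` instead of the column order `('f', 'h')`.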
This changeset also clarifies some docstrings, makes the code a bit clearer,
and handles and tests some more edge cases (the pg module still has 100% coverage).
Modified:
trunk/docs/contents/changelog.rst
trunk/docs/contents/pg/db_wrapper.rst
trunk/pg.py
trunk/tests/test_classic_dbwrapper.py
Modified: trunk/docs/contents/changelog.rst
==============================================================================
--- trunk/docs/contents/changelog.rst Sun Jan 17 16:05:19 2016 (r764)
+++ trunk/docs/contents/changelog.rst Mon Jan 18 18:21:44 2016 (r765)
@@ -27,11 +27,14 @@
colnames and coltypes attributes, which are not part of DB-API 2 though.
- The tty parameter and attribute of database connections has been
removed since it is not supported any more since PostgreSQL 7.4.
+- The pkey() method of the classic interface now returns tuples instead
+ of frozensets for composite primary keys. The order of the tuple elements
+ follows the order of the columns in the primary key index.
- The table name that is affixed to the name of the OID column returned
by the get() method of the classic interface will not automatically
be fully qualified any more. This reduces overhead from the interface,
but it means you must always write the table name in the same way when
you call the methods using it and you are using tables with OIDs.
+ Also, OIDs are now only used when access via primary key is not possible.
Note that OIDs are considered deprecated anyway, and they are not created
by default any more in PostgreSQL 8.1 and later.
- The internal caching and automatic quoting of class names in the classic
Modified: trunk/docs/contents/pg/db_wrapper.rst
==============================================================================
--- trunk/docs/contents/pg/db_wrapper.rst Sun Jan 17 16:05:19 2016 (r764)
+++ trunk/docs/contents/pg/db_wrapper.rst Mon Jan 18 18:21:44 2016 (r765)
@@ -65,9 +65,10 @@
:rtype: str
:raises KeyError: the table does not have a primary key
-This method returns the primary key of a table. For composite primary
-keys, the return value will be a frozenset. Note that this raises a
-KeyError if the table does not have a primary key.
+This method returns the primary key of a table. Single primary keys are
+returned as strings unless you set the composite flag. Composite primary
+keys are always represented as tuples. Note that this raises a KeyError
+if the table does not have a primary key.
get_databases -- get list of databases in the system
----------------------------------------------------
@@ -295,17 +296,23 @@
:param str keyname: name of field to use as key (optional)
:returns: A dictionary - the keys are the attribute names,
the values are the row values.
- :raises ProgrammingError: no primary key or missing privilege
+ :raises ProgrammingError: table has no primary key or missing privilege
+ :raises KeyError: missing key value for the row
-This method is the basic mechanism to get a single row. It assumes
-that the key specifies a unique row. If *keyname* is not specified,
-then the primary key for the table is used. If *row* is a dictionary
-then the value for the key is taken from it and it is modified to
-include the new values, replacing existing values where necessary.
-For a composite key, *keyname* can also be a sequence of key names.
-The OID is also put into the dictionary if the table has one, but in
-order to allow the caller to work with multiple tables, it is munged
-as ``oid(table)``.
+This method is the basic mechanism to get a single row. It assumes
+that the *keyname* specifies a unique row. It must be the name of a
+single column or a tuple of column names. If *keyname* is not specified,
+then the primary key for the table is used.
+
+If *row* is a dictionary, then the value for the key is taken from it.
+Otherwise, the row must be a single value or a tuple of values
+corresponding to the passed *keyname* or primary key. The fetched row
+from the table will be returned as a new dictionary or used to replace
+the existing values when *row* was passed as a dictionary.
+
+The OID is also put into the dictionary if the table has one, but
+in order to allow the caller to work with multiple tables, it is
+munged as ``oid(table)`` using the actual name of the table.
insert -- insert a row into a database table
--------------------------------------------
@@ -344,17 +351,20 @@
:param col: optional keyword arguments for updating the dictionary
:returns: the new row in the database
:rtype: dict
- :raises ProgrammingError: no primary key or missing privilege
+ :raises ProgrammingError: table has no primary key or missing privilege
+ :raises KeyError: missing key value for the row
-Similar to insert but updates an existing row. The update is based on the
-OID value as munged by get or passed as keyword, or on the primary key of
-the table. The dictionary is modified to reflect any changes caused by the
+Similar to insert but updates an existing row. The update is based on
+the primary key of the table or the OID value as munged by :meth:`DB.get`
+or passed as keyword.
+
+The dictionary is then modified to reflect any changes caused by the
update due to triggers, rules, default values, etc.
Like insert, the dictionary is optional and updates will be performed
on the fields in the keywords. There must be an OID or primary key
either in the dictionary where the OID must be munged, or in the keywords
-where it can be simply the string 'oid'.
+where it can be simply the string ``'oid'``.
upsert -- insert a row with conflict resolution
-----------------------------------------------
@@ -368,7 +378,7 @@
:param col: optional keyword arguments for specifying the update
:returns: the new row in the database
:rtype: dict
- :raises ProgrammingError: no primary key or missing privilege
+ :raises ProgrammingError: table has no primary key or missing privilege
This method inserts a row into a table, but instead of raising a
ProgrammingError exception in case a row with the same primary key already
@@ -474,12 +484,20 @@
:param dict d: optional dictionary of values
:param col: optional keyword arguments for updating the dictionary
:rtype: None
+ :raises ProgrammingError: table has no primary key,
+ row is still referenced or missing privilege
+ :raises KeyError: missing key value for the row
+
+This method deletes the row from a table. It deletes based on the
+primary key of the table or the OID value as munged by :meth:`DB.get`
+or passed as keyword.
-This method deletes the row from a table. It deletes based on the OID value
-as munged by get or passed as keyword, or on the primary key of the table.
The return value is the number of deleted rows (i.e. 0 if the row did not
exist and 1 if the row was deleted).
+Note that if the row cannot be deleted, e.g. because it is still referenced
+by another table, this method will raise a ProgrammingError.
+
truncate -- Quickly empty database tables
-----------------------------------------
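The rules documented above for combining *row* and *keyname* in `get()` can be sketched without a database. This is a minimal sketch, assuming a plain dict stands in for the row and a string, tuple or list for the key. `normalize_key` is a hypothetical name; in `pg.py` this logic is inlined in `DB.get`, which also falls back to the OID when primary key values are missing (omitted here). The sketch uses Python 3 `str` where `pg.py` tests for `basestring`.

```python
def normalize_key(row, keyname):
    """Return (keyname tuple, row dict) following the get() rules.

    Hypothetical helper mirroring the logic inlined in DB.get().
    """
    if isinstance(keyname, str):
        keyname = (keyname,)  # the name of a single column
    else:
        keyname = tuple(keyname)  # a tuple or list of column names
    if isinstance(row, dict):
        # the key values are taken from the dictionary itself
        if not set(keyname).issubset(row):
            raise KeyError('Missing value in row for specified keyname')
        return keyname, row
    if not isinstance(row, (tuple, list)):
        row = [row]  # a single value for a single-column key
    if len(keyname) != len(row):
        raise KeyError('Differing number of items in keyname and row')
    return keyname, dict(zip(keyname, row))


print(normalize_key(2, 'n'))  # (('n',), {'n': 2})
print(normalize_key((1, 2), ('m', 'n')))  # (('m', 'n'), {'m': 1, 'n': 2})
```

Values are paired with key names positionally, so `get(table, (1, 2), ('m', 'n'))` and `get(table, (2, 1), ('n', 'm'))` select the same row.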
Modified: trunk/pg.py
==============================================================================
--- trunk/pg.py Sun Jan 17 16:05:19 2016 (r764)
+++ trunk/pg.py Mon Jan 18 18:21:44 2016 (r765)
@@ -355,6 +355,9 @@
raise ValueError
return d
+ _num_types = frozenset('int float num money'
+ ' int2 int4 int8 float4 float8 numeric money'.split())
+
def _prepare_num(self, d):
"""Prepare a numeric parameter."""
if not d and d != 0:
@@ -607,20 +610,20 @@
def query(self, qstr, *args):
"""Execute a SQL command string.
- This method simply sends a SQL query to the database. If the query is
+ This method simply sends a SQL query to the database. If the query is
an insert statement that inserted exactly one row into a table that
has OIDs, the return value is the OID of the newly inserted row.
If the query is an update or delete statement, or an insert statement
that did not insert exactly one row in a table with OIDs, then the
- number of rows affected is returned as a string. If it is a statement
+ number of rows affected is returned as a string. If it is a statement
that returns rows as a result (usually a select statement, but maybe
also an "insert/update ... returning" statement), this method returns
a Query object that can be accessed via getresult() or dictresult()
- or simply printed. Otherwise, it returns `None`.
+ or simply printed. Otherwise, it returns `None`.
The query can contain numbered parameters of the form $1 in place
- of any data constant. Arguments given after the query string will
- be substituted for the corresponding numbered parameter. Parameter
+ of any data constant. Arguments given after the query string will
+ be substituted for the corresponding numbered parameter. Parameter
values can also be given as a single list or tuple argument.
"""
# Wraps shared library function for debugging.
@@ -629,14 +632,16 @@
self._do_debug(qstr)
return self.db.query(qstr, args)
- def pkey(self, table, flush=False):
+ def pkey(self, table, composite=False, flush=False):
"""Get or set the primary key of a table.
- Composite primary keys are represented as frozensets. Note that
- this raises a KeyError if the table does not have a primary key.
+ Single primary keys are returned as strings unless you
+ set the composite flag. Composite primary keys are always
+ represented as tuples. Note that this raises a KeyError
+ if the table does not have a primary key.
If flush is set then the internal cache for primary keys will
- be flushed. This may be necessary after the database schema or
+ be flushed. This may be necessary after the database schema or
the search path has been changed.
"""
pkeys = self._pkeys
@@ -646,21 +651,27 @@
try: # cache lookup
pkey = pkeys[table]
except KeyError: # cache miss, check the database
- q = ("SELECT a.attname FROM pg_index i"
+ q = ("SELECT a.attname, a.attnum, i.indkey FROM pg_index i"
" JOIN pg_attribute a ON a.attrelid = i.indrelid"
" AND a.attnum = ANY(i.indkey)"
" AND NOT a.attisdropped"
" WHERE i.indrelid=%s::regclass"
- " AND i.indisprimary") % (
+ " AND i.indisprimary ORDER BY a.attnum") % (
self._prepare_qualified_param(table, 1),)
pkey = self.db.query(q, (table,)).getresult()
if not pkey:
raise KeyError('Table %s has no primary key' % table)
+ # we want to use the order defined in the primary key index here,
+ # not the order as defined by the columns in the table
if len(pkey) > 1:
- pkey = frozenset(k[0] for k in pkey)
+ indkey = [int(k) for k in pkey[0][2].split()]
+ pkey = sorted(pkey, key=lambda row: indkey.index(row[1]))
+ pkey = tuple(row[0] for row in pkey)
else:
pkey = pkey[0][0]
pkeys[table] = pkey # cache it
+ if composite and not isinstance(pkey, tuple):
+ pkey = (pkey,)
return pkey
def get_databases(self):
@@ -754,50 +765,64 @@
def get(self, table, row, keyname=None):
"""Get a row from a database table or view.
- This method is the basic mechanism to get a single row. The keyname
- that the key specifies a unique row. If keyname is not specified
- then the primary key for the table is used. If row is a dictionary
- then the value for the key is taken from it and it is modified to
- include the new values, replacing existing values where necessary.
- For a composite key, keyname can also be a sequence of key names.
+ This method is the basic mechanism to get a single row. It assumes
+ that the keyname specifies a unique row. It must be the name of a
+ single column or a tuple of column names. If the keyname is not
+ specified, then the primary key for the table is used.
+
+ If row is a dictionary, then the value for the key is taken from it.
+ Otherwise, the row must be a single value or a tuple of values
+ corresponding to the passed keyname or primary key. The fetched row
+ from the table will be returned as a new dictionary or used to replace
+ the existing values when row was passed as a dictionary.
+
The OID is also put into the dictionary if the table has one, but
in order to allow the caller to work with multiple tables, it is
- munged as "oid(table)".
+ munged as "oid(table)" using the actual name of the table.
"""
- if table.endswith('*'): # scan descendant tables?
- table = table[:-1].rstrip() # need parent table name
- if not keyname:
- # use the primary key by default
- try:
- keyname = self.pkey(table)
- except KeyError:
- raise _prg_error('Table %s has no primary key' % table)
+ if table.endswith('*'): # hint for descendant tables can be ignored
+ table = table[:-1].rstrip()
attnames = self.get_attnames(table)
+ qoid = _oid_key(table) if 'oid' in attnames else None
+ if keyname and isinstance(keyname, basestring):
+ keyname = (keyname,)
+ if qoid and isinstance(row, dict) and qoid in row and 'oid' not in row:
+ row['oid'] = row[qoid]
+ if not keyname:
+ try: # if keyname is not specified, try using the primary key
+ keyname = self.pkey(table, True)
+ except KeyError: # the table has no primary key
+ # try using the oid instead
+ if qoid and isinstance(row, dict) and 'oid' in row:
+ keyname = ('oid',)
+ else:
+ raise _prg_error('Table %s has no primary key' % table)
+ else: # the table has a primary key
+ # check whether all key columns have values
+ if isinstance(row, dict) and not set(keyname).issubset(row):
+ # try using the oid instead
+ if qoid and 'oid' in row:
+ keyname = ('oid',)
+ else:
+ raise KeyError(
+ 'Missing value in row for specified keyname')
+ if not isinstance(row, dict):
+ if not isinstance(row, (tuple, list)):
+ row = [row]
+ if len(keyname) != len(row):
+ raise KeyError(
+ 'Differing number of items in keyname and row')
+ row = dict(zip(keyname, row))
params = []
param = partial(self._prepare_param, params=params)
col = self.escape_identifier
- # We want the oid for later updates if that isn't the key.
- # To allow users to work with multiple tables, we munge
- # the name of the "oid" key by adding the name of the table.
- qoid = _oid_key(table)
- if keyname == 'oid':
- if isinstance(row, dict):
- if qoid not in row:
- raise _prg_error('%s not in row' % qoid)
- else:
- row = {qoid: row}
- what = '*'
- where = 'oid = %s' % param(row[qoid], 'int')
- else:
- keyname = [keyname] if isinstance(
- keyname, basestring) else sorted(keyname)
- if not isinstance(row, dict):
- if len(keyname) > 1:
- raise _prg_error('Composite key needs dict as row')
- row = dict((k, row) for k in keyname)
- what = ', '.join(col(k) for k in attnames)
- where = ' AND '.join('%s = %s' % (
- col(k), param(row[k], attnames[k])) for k in keyname)
+ what = 'oid, *' if qoid else '*'
+ where = ' AND '.join('%s = %s' % (
+ col(k), param(row[k], attnames[k])) for k in keyname)
+ if 'oid' in row:
+ if qoid:
+ row[qoid] = row['oid']
+ del row['oid']
q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
what, self._escape_qualified_name(table), where)
self._do_debug(q, params)
@@ -807,9 +832,9 @@
raise _db_error('No such record in %s\nwhere %s\nwith %s' % (
table, where, self._list_params(params)))
for n, value in res[0].items():
- if n == 'oid':
+ if qoid and n == 'oid':
n = qoid
- elif attnames.get(n) == 'bytea':
+ elif value is not None and attnames.get(n) == 'bytea':
value = self.unescape_bytea(value)
row[n] = value
return row
@@ -830,12 +855,15 @@
Note: The method currently doesn't support insert into views
although PostgreSQL does.
"""
- if 'oid' in kw:
- del kw['oid']
+ if table.endswith('*'): # hint for descendant tables can be ignored
+ table = table[:-1].rstrip()
if row is None:
row = {}
row.update(kw)
+ if 'oid' in row:
+ del row['oid'] # do not insert oid
attnames = self.get_attnames(table)
+ qoid = _oid_key(table) if 'oid' in attnames else None
params = []
param = partial(self._prepare_param, params=params)
col = self.escape_identifier
@@ -845,67 +873,76 @@
names.append(col(n))
values.append(param(row[n], attnames[n]))
names, values = ', '.join(names), ', '.join(values)
- ret = 'oid, *' if 'oid' in attnames else '*'
+ ret = 'oid, *' if qoid else '*'
q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % (
self._escape_qualified_name(table), names, values, ret)
self._do_debug(q, params)
q = self.db.query(q, params)
- res = q.dictresult() # this will always return a row
- for n, value in res[0].items():
- if n == 'oid':
- n = _oid_key(table)
- elif attnames.get(n) == 'bytea' and value is not None:
- value = self.unescape_bytea(value)
- row[n] = value
+ res = q.dictresult()
+ if res: # this should always be true
+ for n, value in res[0].items():
+ if qoid and n == 'oid':
+ n = qoid
+ elif value is not None and attnames.get(n) == 'bytea':
+ value = self.unescape_bytea(value)
+ row[n] = value
return row
def update(self, table, row=None, **kw):
"""Update an existing row in a database table.
Similar to insert but updates an existing row. The update is based
- on the OID value as munged by get or passed as keyword, or on the
- primary key of the table. The dictionary is modified to reflect
- any changes caused by the update due to triggers, rules, default
- values, etc.
- """
- # Update always works on the oid which get() returns if available,
- # otherwise use the primary key. Fail if neither.
- # Note that we only accept oid key from named args for safety.
- qoid = _oid_key(table)
- if 'oid' in kw:
- kw[qoid] = kw.pop('oid')
+ on the primary key of the table or the OID value as munged by get
+ or passed as keyword.
+
+ The dictionary is then modified to reflect any changes caused by the
+ update due to triggers, rules, default values, etc.
+ """
+ if table.endswith('*'):
+ table = table[:-1].rstrip() # need parent table name
+ attnames = self.get_attnames(table)
+ qoid = _oid_key(table) if 'oid' in attnames else None
if row is None:
row = {}
+ elif 'oid' in row:
+ del row['oid'] # only accept oid key from named args for safety
row.update(kw)
- attnames = self.get_attnames(table)
+ if qoid and qoid in row and 'oid' not in row:
+ row['oid'] = row[qoid]
+ try: # try using the primary key
+ keyname = self.pkey(table, True)
+ except KeyError: # the table has no primary key
+ # try using the oid instead
+ if qoid and 'oid' in row:
+ keyname = ('oid',)
+ else:
+ raise _prg_error('Table %s has no primary key' % table)
+ else: # the table has a primary key
+ # check whether all key columns have values
+ if not set(keyname).issubset(row):
+ # try using the oid instead
+ if qoid and 'oid' in row:
+ keyname = ('oid',)
+ else:
+ raise KeyError('Missing primary key in row')
params = []
param = partial(self._prepare_param, params=params)
col = self.escape_identifier
- if qoid in row:
- where = 'oid = %s' % param(row[qoid], 'int')
- keyname = []
- else:
- try:
- keyname = self.pkey(table)
- except KeyError:
- raise _prg_error('Table %s has no primary key' % table)
- keyname = [keyname] if isinstance(
- keyname, basestring) else sorted(keyname)
- try:
- where = ' AND '.join('%s = %s' % (
- col(k), param(row[k], attnames[k])) for k in keyname)
- except KeyError:
- raise _prg_error('Update operation needs primary key or oid')
- keyname = set(keyname)
- keyname.add('oid')
+ where = ' AND '.join('%s = %s' % (
+ col(k), param(row[k], attnames[k])) for k in keyname)
+ if 'oid' in row:
+ if qoid:
+ row[qoid] = row['oid']
+ del row['oid']
values = []
+ keyname = set(keyname)
for n in attnames:
if n in row and n not in keyname:
values.append('%s = %s' % (col(n), param(row[n], attnames[n])))
if not values:
return row
values = ', '.join(values)
- ret = 'oid, *' if 'oid' in attnames else '*'
+ ret = 'oid, *' if qoid else '*'
q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % (
self._escape_qualified_name(table), values, where, ret)
self._do_debug(q, params)
@@ -913,9 +950,9 @@
res = q.dictresult()
if res: # may be empty when row does not exist
for n, value in res[0].items():
- if n == 'oid':
+ if qoid and n == 'oid':
n = qoid
- elif attnames.get(n) == 'bytea' and value is not None:
+ elif value is not None and attnames.get(n) == 'bytea':
value = self.unescape_bytea(value)
row[n] = value
return row
@@ -963,11 +1000,16 @@
Note: The method uses the PostgreSQL "upsert" feature which is
only available since PostgreSQL 9.5.
"""
- if 'oid' in kw:
- del kw['oid']
+ if table.endswith('*'): # hint for descendant tables can be ignored
+ table = table[:-1].rstrip()
if row is None:
row = {}
+ if 'oid' in row:
+ del row['oid'] # do not insert oid
+ if 'oid' in kw:
+ del kw['oid'] # do not update oid
attnames = self.get_attnames(table)
+ qoid = _oid_key(table) if 'oid' in attnames else None
params = []
param = partial(self._prepare_param,params=params)
col = self.escape_identifier
@@ -978,11 +1020,9 @@
values.append(param(row[n], attnames[n]))
names, values = ', '.join(names), ', '.join(values)
try:
- keyname = self.pkey(table)
+ keyname = self.pkey(table, True)
except KeyError:
raise _prg_error('Table %s has no primary key' % table)
- keyname = [keyname] if isinstance(
- keyname, basestring) else sorted(keyname)
target = ', '.join(col(k) for k in keyname)
update = []
keyname = set(keyname)
@@ -997,7 +1037,7 @@
if not values:
return row
do = 'update set %s' % ', '.join(update) if update else 'nothing'
- ret = 'oid, *' if 'oid' in attnames else '*'
+ ret = 'oid, *' if qoid else '*'
q = ('INSERT INTO %s AS included (%s) VALUES (%s)'
' ON CONFLICT (%s) DO %s RETURNING %s') % (
self._escape_qualified_name(table), names, values,
@@ -1011,11 +1051,11 @@
'Upsert operation is not supported by PostgreSQL version')
raise # re-raise original error
res = q.dictresult()
- if update: # may be empty with "do nothing"
+ if res: # may be empty with "do nothing"
for n, value in res[0].items():
- if n == 'oid':
- n = _oid_key(table)
- elif attnames.get(n) == 'bytea' and value is not None:
+ if qoid and n == 'oid':
+ n = qoid
+ elif value is not None and attnames.get(n) == 'bytea':
value = self.unescape_bytea(value)
row[n] = value
else:
@@ -1037,11 +1077,9 @@
for n, t in attnames.items():
if n == 'oid':
continue
- if t in ('int', 'integer', 'smallint', 'bigint',
- 'float', 'real', 'double precision',
- 'num', 'numeric', 'money'):
+ if t in self._num_types:
row[n] = 0
- elif t in ('bool', 'boolean'):
+ elif t == 'bool':
row[n] = self._make_bool(False)
else:
row[n] = ''
@@ -1051,38 +1089,51 @@
"""Delete an existing row in a database table.
This method deletes the row from a table. It deletes based on the
- OID value as munged by get or passed as keyword, or on the primary
- key of the table. The return value is the number of deleted rows
- (i.e. 0 if the row did not exist and 1 if the row was deleted).
- """
- # Like update, delete works on the oid.
- # One day we will be testing that the record to be deleted
- # isn't referenced somewhere (or else PostgreSQL will).
- # Note that we only accept oid key from named args for safety.
- qoid = _oid_key(table)
- if 'oid' in kw:
- kw[qoid] = kw.pop('oid')
+ primary key of the table or the OID value as munged by get() or
+ passed as keyword.
+
+ The return value is the number of deleted rows (i.e. 0 if the row
+ did not exist and 1 if the row was deleted).
+
+ Note that if the row cannot be deleted, e.g. because it is still
+ referenced by another table, this method raises a ProgrammingError.
+ """
+ if table.endswith('*'): # hint for descendant tables can be ignored
+ table = table[:-1].rstrip()
+ attnames = self.get_attnames(table)
+ qoid = _oid_key(table) if 'oid' in attnames else None
if row is None:
row = {}
+ elif 'oid' in row:
+ del row['oid'] # only accept oid key from named args for safety
row.update(kw)
+ if qoid and qoid in row and 'oid' not in row:
+ row['oid'] = row[qoid]
+ try: # try using the primary key
+ keyname = self.pkey(table, True)
+ except KeyError: # the table has no primary key
+ # try using the oid instead
+ if qoid and 'oid' in row:
+ keyname = ('oid',)
+ else:
+ raise _prg_error('Table %s has no primary key' % table)
+ else: # the table has a primary key
+ # check whether all key columns have values
+ if not set(keyname).issubset(row):
+ # try using the oid instead
+ if qoid and 'oid' in row:
+ keyname = ('oid',)
+ else:
+ raise KeyError('Missing primary key in row')
params = []
param = partial(self._prepare_param, params=params)
- if qoid in row:
- where = 'oid = %s' % param(row[qoid], 'int')
- else:
- try:
- keyname = self.pkey(table)
- except KeyError:
- raise _prg_error('Table %s has no primary key' % table)
- keyname = [keyname] if isinstance(
- keyname, basestring) else sorted(keyname)
- attnames = self.get_attnames(table)
- col = self.escape_identifier
- try:
- where = ' AND '.join('%s = %s' % (
- col(k), param(row[k], attnames[k])) for k in keyname)
- except KeyError:
- raise _prg_error('Delete operation needs primary key or oid')
+ col = self.escape_identifier
+ where = ' AND '.join('%s = %s' % (
+ col(k), param(row[k], attnames[k])) for k in keyname)
+ if 'oid' in row:
+ if qoid:
+ row[qoid] = row['oid']
+ del row['oid']
q = 'DELETE FROM %s WHERE %s' % (
self._escape_qualified_name(table), where)
self._do_debug(q, params)
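The precedence rule that `update()` and `delete()` now share (primary key first, OID only as a fallback) can be sketched as follows. This is a minimal sketch, assuming the primary key is already known as a tuple (or `None` when the table has none) and the row is a plain dict. `choose_key` and its parameters are hypothetical, and a local `ProgrammingError` class stands in for `pg.ProgrammingError` so the block runs on its own. It omits the step in `pg.py` that discards an `'oid'` entry passed inside the row dictionary.

```python
class ProgrammingError(Exception):
    """Stand-in for pg.ProgrammingError so the sketch runs on its own."""


def choose_key(table, row, pkey, has_oid):
    """Return the key columns used to locate a row in update()/delete().

    pkey: tuple of primary key columns, or None if the table has none.
    has_oid: whether the table has OIDs (its columns include 'oid').
    """
    qoid = 'oid(%s)' % table if has_oid else None
    if qoid and qoid in row and 'oid' not in row:
        row['oid'] = row[qoid]  # accept the munged OID name from get()
    if pkey is None:  # no primary key, fall back to the OID
        if qoid and 'oid' in row:
            return ('oid',)
        raise ProgrammingError('Table %s has no primary key' % table)
    if not set(pkey).issubset(row):  # primary key values incomplete
        if qoid and 'oid' in row:
            return ('oid',)
        raise KeyError('Missing primary key in row')
    return pkey  # the primary key takes precedence over the OID


print(choose_key('t', {'n': 1, 'oid(t)': 42}, ('n',), True))  # ('n',)
print(choose_key('t', {'oid(t)': 42}, ('n',), True))  # ('oid',)
print(choose_key('t', {'oid(t)': 42}, None, True))  # ('oid',)
```

Checking the primary key first is what reverses the 4.x behavior: a row returned by `get()` carries both its key values and `oid(table)`, and the key values now win.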
Modified: trunk/tests/test_classic_dbwrapper.py
==============================================================================
--- trunk/tests/test_classic_dbwrapper.py Sun Jan 17 16:05:19 2016 (r764)
+++ trunk/tests/test_classic_dbwrapper.py Mon Jan 18 18:21:44 2016 (r765)
@@ -708,8 +708,9 @@
def testPkey(self):
query = self.db.query
pkey = self.db.pkey
+ self.assertRaises(KeyError, pkey, 'test')
for t in ('pkeytest', 'primary key test'):
- for n in range(7):
+ for n in range(8):
query('drop table if exists "%s%d"' % (t, n))
self.addCleanup(query, 'drop table "%s%d"' % (t, n))
query('create table "%s0" ('
@@ -723,10 +724,14 @@
" h smallint, i smallint,"
" primary key (f, h))" % t)
query('create table "%s4" ('
- "more_than_one_letter varchar primary key)" % t)
+ "e smallint, f smallint, g smallint,"
+ " h smallint, i smallint,"
+ " primary key (h, f))" % t)
query('create table "%s5" ('
- '"with space" date primary key)' % t)
+ "more_than_one_letter varchar primary key)" % t)
query('create table "%s6" ('
+ '"with space" date primary key)' % t)
+ query('create table "%s7" ('
'a_very_long_column_name varchar,'
' "with space" date,'
' "42" int,'
@@ -734,16 +739,26 @@
' "with space", "42"))' % t)
self.assertRaises(KeyError, pkey, '%s0' % t)
self.assertEqual(pkey('%s1' % t), 'b')
+ self.assertEqual(pkey('%s1' % t, True), ('b',))
+ self.assertEqual(pkey('%s1' % t, composite=False), 'b')
+ self.assertEqual(pkey('%s1' % t, composite=True), ('b',))
self.assertEqual(pkey('%s2' % t), 'd')
+ self.assertEqual(pkey('%s2' % t, composite=True), ('d',))
r = pkey('%s3' % t)
- self.assertIsInstance(r, frozenset)
- self.assertEqual(r, frozenset('fh'))
- self.assertEqual(pkey('%s4' % t), 'more_than_one_letter')
- self.assertEqual(pkey('%s5' % t), 'with space')
- r = pkey('%s6' % t)
- self.assertIsInstance(r, frozenset)
- self.assertEqual(r, frozenset([
- 'a_very_long_column_name', 'with space', '42']))
+ self.assertIsInstance(r, tuple)
+ self.assertEqual(r, ('f', 'h'))
+ r = pkey('%s3' % t, composite=False)
+ self.assertIsInstance(r, tuple)
+ self.assertEqual(r, ('f', 'h'))
+ r = pkey('%s4' % t)
+ self.assertIsInstance(r, tuple)
+ self.assertEqual(r, ('h', 'f'))
+ self.assertEqual(pkey('%s5' % t), 'more_than_one_letter')
+ self.assertEqual(pkey('%s6' % t), 'with space')
+ r = pkey('%s7' % t)
+ self.assertIsInstance(r, tuple)
+ self.assertEqual(r, (
+ 'a_very_long_column_name', 'with space', '42'))
# a newly added primary key will be detected
query('alter table "%s0" add primary key (a)' % t)
self.assertEqual(pkey('%s0' % t), 'a')
@@ -955,6 +970,66 @@
get = self.db.get
query = self.db.query
table = 'get_test_table'
+ self.assertRaises(TypeError, get)
+ self.assertRaises(TypeError, get, table)
+ query('drop table if exists "%s"' % table)
+ self.addCleanup(query, 'drop table "%s"' % table)
+ query('create table "%s" ('
+ "n integer, t text) without oids" % table)
+ for n, t in enumerate('xyz'):
+ query('insert into "%s" values('"%d, '%s')"
+ % (table, n + 1, t))
+ self.assertRaises(pg.ProgrammingError, get, table, 2)
+ r = get(table, 2, 'n')
+ self.assertIsInstance(r, dict)
+ self.assertEqual(r, dict(n=2, t='y'))
+ r = get(table, 1, 'n')
+ self.assertEqual(r, dict(n=1, t='x'))
+ r = get(table, (3,), ('n',))
+ self.assertEqual(r, dict(n=3, t='z'))
+ r = get(table, 'y', 't')
+ self.assertEqual(r, dict(n=2, t='y'))
+ self.assertRaises(pg.DatabaseError, get, table, 4)
+ self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
+ self.assertRaises(pg.DatabaseError, get, table, 'y')
+ self.assertRaises(pg.DatabaseError, get, table, 2, 't')
+ s = dict(n=3)
+ self.assertRaises(pg.ProgrammingError, get, table, s)
+ r = get(table, s, 'n')
+ self.assertIs(r, s)
+ self.assertEqual(r, dict(n=3, t='z'))
+ s.update(t='x')
+ r = get(table, s, 't')
+ self.assertIs(r, s)
+ self.assertEqual(s, dict(n=1, t='x'))
+ r = get(table, s, ('n', 't'))
+ self.assertIs(r, s)
+ self.assertEqual(r, dict(n=1, t='x'))
+ query('alter table "%s" alter n set not null' % table)
+ query('alter table "%s" add primary key (n)' % table)
+ r = get(table, 2)
+ self.assertIsInstance(r, dict)
+ self.assertEqual(r, dict(n=2, t='y'))
+ self.assertEqual(get(table, 1)['t'], 'x')
+ self.assertEqual(get(table, 3)['t'], 'z')
+ self.assertEqual(get(table + '*', 2)['t'], 'y')
+ self.assertEqual(get(table + ' *', 2)['t'], 'y')
+ self.assertRaises(KeyError, get, table, (2, 2))
+ s = dict(n=3)
+ r = get(table, s)
+ self.assertIs(r, s)
+ self.assertEqual(r, dict(n=3, t='z'))
+ s.update(n=1)
+ self.assertEqual(get(table, s)['t'], 'x')
+ s.update(n=2)
+ self.assertEqual(get(table, r)['t'], 'y')
+ s.pop('n')
+ self.assertRaises(KeyError, get, table, s)
+
+ def testGetWithOid(self):
+ get = self.db.get
+ query = self.db.query
+ table = 'get_with_oid_test_table'
query('drop table if exists "%s"' % table)
self.addCleanup(query, 'drop table "%s"' % table)
query('create table "%s" ('
@@ -963,14 +1038,25 @@
query('insert into "%s" values('"%d, '%s')"
% (table, n + 1, t))
self.assertRaises(pg.ProgrammingError, get, table, 2)
- self.assertRaises(pg.ProgrammingError, get, table, {}, 'oid')
+ self.assertRaises(KeyError, get, table, {}, 'oid')
r = get(table, 2, 'n')
- oid_table = 'oid(%s)' % table
- self.assertIn(oid_table, r)
- oid = r[oid_table]
+ qoid = 'oid(%s)' % table
+ self.assertIn(qoid, r)
+ oid = r[qoid]
self.assertIsInstance(oid, int)
- result = {'t': 'y', 'n': 2, oid_table: oid}
+ result = {'t': 'y', 'n': 2, qoid: oid}
+ self.assertEqual(r, result)
+ r = get(table, oid, 'oid')
self.assertEqual(r, result)
+ r = get(table, dict(oid=oid))
+ self.assertEqual(r, result)
+ r = get(table, dict(oid=oid), 'oid')
+ self.assertEqual(r, result)
+ r = get(table, {qoid: oid})
+ self.assertEqual(r, result)
+ r = get(table, {qoid: oid}, 'oid')
+ self.assertEqual(r, result)
+ self.assertEqual(get(table + '*', 2, 'n'), r)
self.assertEqual(get(table + ' *', 2, 'n'), r)
self.assertEqual(get(table, oid, 'oid')['t'], 'y')
self.assertEqual(get(table, 1, 'n')['t'], 'x')
@@ -980,6 +1066,7 @@
r['n'] = 3
self.assertEqual(get(table, r, 'n')['t'], 'z')
self.assertEqual(get(table, 1, 'n')['t'], 'x')
+ self.assertEqual(get(table, r, 'oid')['t'], 'z')
query('alter table "%s" alter n set not null' % table)
query('alter table "%s" add primary key (n)' % table)
self.assertEqual(get(table, 3)['t'], 'z')
@@ -991,6 +1078,22 @@
self.assertEqual(get(table, r)['t'], 'z')
r['n'] = 2
self.assertEqual(get(table, r)['t'], 'y')
+ r = get(table, oid, 'oid')
+ self.assertEqual(r, result)
+ r = get(table, dict(oid=oid))
+ self.assertEqual(r, result)
+ r = get(table, dict(oid=oid), 'oid')
+ self.assertEqual(r, result)
+ r = get(table, {qoid: oid})
+ self.assertEqual(r, result)
+ r = get(table, {qoid: oid}, 'oid')
+ self.assertEqual(r, result)
+ r = get(table, dict(oid=oid, n=1))
+ self.assertEqual(r['n'], 1)
+ self.assertNotEqual(r[qoid], oid)
+ r = get(table, dict(oid=oid, t='z'), 't')
+ self.assertEqual(r['n'], 3)
+ self.assertNotEqual(r[qoid], oid)
def testGetWithCompositeKey(self):
get = self.db.get
@@ -1004,6 +1107,13 @@
query('insert into "%s" values('
"%d, '%s')" % (table, n + 1, t))
self.assertEqual(get(table, 2)['t'], 'b')
+ self.assertEqual(get(table, 1, 'n')['t'], 'a')
+ self.assertEqual(get(table, 2, ('n',))['t'], 'b')
+ self.assertEqual(get(table, 3, ['n'])['t'], 'c')
+ self.assertEqual(get(table, (2,), ('n',))['t'], 'b')
+ self.assertEqual(get(table, 'b', 't')['n'], 2)
+ self.assertEqual(get(table, ('a',), ('t',))['n'], 1)
+ self.assertEqual(get(table, ['c'], ['t'])['n'], 3)
table = 'get_test_table_2'
query('drop table if exists "%s"' % table)
self.addCleanup(query, 'drop table "%s"' % table)
@@ -1014,12 +1124,18 @@
t = chr(ord('a') + 2 * n + m)
query('insert into "%s" values('
"%d, %d, '%s')" % (table, n + 1, m + 1, t))
- self.assertRaises(pg.ProgrammingError, get, table, 2)
+ self.assertRaises(KeyError, get, table, 2)
+ self.assertEqual(get(table, (1, 1))['t'], 'a')
+ self.assertEqual(get(table, (1, 2))['t'], 'b')
+ self.assertEqual(get(table, (2, 1))['t'], 'c')
+ self.assertEqual(get(table, (1, 2), ('n', 'm'))['t'], 'b')
+ self.assertEqual(get(table, (1, 2), ('m', 'n'))['t'], 'c')
+ self.assertEqual(get(table, (3, 1), ('n', 'm'))['t'], 'e')
+ self.assertEqual(get(table, (1, 3), ('m', 'n'))['t'], 'e')
self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
- r = get(table, dict(n=1, m=2), ('n', 'm'))
- self.assertEqual(r['t'], 'b')
- r = get(table, dict(n=3, m=2), frozenset(['n', 'm']))
- self.assertEqual(r['t'], 'f')
+ self.assertEqual(get(table, dict(n=1, m=2), ('n', 'm'))['t'], 'b')
+ self.assertEqual(get(table, dict(n=2, m=1), ['n', 'm'])['t'], 'c')
+ self.assertEqual(get(table, dict(n=3, m=2), ('m', 'n'))['t'], 'f')
def testGetWithQuotedNames(self):
get = self.db.get
@@ -1198,21 +1314,66 @@
r = insert('test_table', n=1)
self.assertIsInstance(r, dict)
self.assertEqual(r['n'], 1)
+ self.assertNotIn('oid', r)
qoid = 'oid(test_table)'
self.assertIn(qoid, r)
- r = insert('test_table', n=2, oid='invalid')
+ oid = r[qoid]
+ self.assertEqual(sorted(r.keys()), ['n', qoid])
+ r = insert('test_table', n=2, oid=oid)
self.assertIsInstance(r, dict)
self.assertEqual(r['n'], 2)
- r['n'] = 3
- r = insert('test_table', r)
+ self.assertIn(qoid, r)
+ self.assertNotEqual(r[qoid], oid)
+ self.assertNotIn('oid', r)
+ r = insert('test_table', None, n=3)
self.assertIsInstance(r, dict)
self.assertEqual(r['n'], 3)
+ s = r
+ r = insert('test_table', r)
+ self.assertIs(r, s)
+ self.assertEqual(r['n'], 3)
+ r = insert('test_table *', r)
+ self.assertIs(r, s)
+ self.assertEqual(r['n'], 3)
r = insert('test_table', r, n=4)
- self.assertIsInstance(r, dict)
+ self.assertIs(r, s)
self.assertEqual(r['n'], 4)
- q = 'select n from test_table order by 1 limit 5'
+ self.assertNotIn('oid', r)
+ self.assertIn(qoid, r)
+ oid = r[qoid]
+ r = insert('test_table', r, n=5, oid=oid)
+ self.assertIs(r, s)
+ self.assertEqual(r['n'], 5)
+ self.assertIn(qoid, r)
+ self.assertNotEqual(r[qoid], oid)
+ self.assertNotIn('oid', r)
+ r['oid'] = oid = r[qoid]
+ r = insert('test_table', r, n=6)
+ self.assertIs(r, s)
+ self.assertEqual(r['n'], 6)
+ self.assertIn(qoid, r)
+ self.assertNotEqual(r[qoid], oid)
+ self.assertNotIn('oid', r)
+ q = 'select n from test_table order by 1 limit 9'
+ r = ' '.join(str(row[0]) for row in query(q).getresult())
+ self.assertEqual(r, '1 2 3 3 3 4 5 6')
+ query("truncate test_table")
+ query("alter table test_table add unique (n)")
+ r = insert('test_table', dict(n=7))
+ self.assertIsInstance(r, dict)
+ self.assertEqual(r['n'], 7)
+ self.assertRaises(pg.ProgrammingError, insert, 'test_table', r)
+ r['n'] = 6
+ self.assertRaises(pg.ProgrammingError, insert, 'test_table', r, n=7)
+ self.assertIsInstance(r, dict)
+ self.assertEqual(r['n'], 7)
+ r['n'] = 6
+ r = insert('test_table', r)
+ self.assertIsInstance(r, dict)
+ self.assertEqual(r['n'], 6)
r = query(q).getresult()
- self.assertEqual(r, [(1,), (2,), (3,), (4,)])
+ r = ' '.join(str(row[0]) for row in query(q).getresult())
+ self.assertEqual(r, '6 7')
def testInsertWithQuotedNames(self):
insert = self.db.insert
@@ -1266,31 +1427,71 @@
self.addCleanup(query, "drop table test_table")
query("create table test_table (n int) with oids")
query("insert into test_table values (1)")
- r = get('test_table', 1, 'n')
- self.assertIsInstance(r, dict)
- self.assertEqual(r['n'], 1)
- r['n'] = 2
- r = update('test_table', r)
- self.assertIsInstance(r, dict)
+ s = get('test_table', 1, 'n')
+ self.assertIsInstance(s, dict)
+ self.assertEqual(s['n'], 1)
+ s['n'] = 2
+ r = update('test_table', s)
+ self.assertIs(r, s)
self.assertEqual(r['n'], 2)
qoid = 'oid(test_table)'
self.assertIn(qoid, r)
+ self.assertNotIn('oid', r)
+ self.assertEqual(sorted(r.keys()), ['n', qoid])
r['n'] = 3
- r = update('test_table', r, oid=r.pop(qoid))
- self.assertIsInstance(r, dict)
+ oid = r.pop(qoid)
+ r = update('test_table', r, oid=oid)
+ self.assertIs(r, s)
self.assertEqual(r['n'], 3)
r.pop(qoid)
self.assertRaises(pg.ProgrammingError, update, 'test_table', r)
- r = get('test_table', 3, 'n')
- self.assertIsInstance(r, dict)
- self.assertEqual(r['n'], 3)
- r.pop('n')
- r = update('test_table', r)
- r.pop(qoid)
+ s = get('test_table', 3, 'n')
+ self.assertIsInstance(s, dict)
+ self.assertEqual(s['n'], 3)
+ s.pop('n')
+ r = update('test_table', s)
+ oid = r.pop(qoid)
self.assertEqual(r, {})
- q = 'select n from test_table limit 2'
+ q = "select n from test_table limit 2"
r = query(q).getresult()
self.assertEqual(r, [(3,)])
+ query("insert into test_table values (1)")
+ self.assertRaises(pg.ProgrammingError,
+ update, 'test_table', dict(oid=oid, n=4))
+ r = update('test_table', dict(n=4), oid=oid)
+ self.assertEqual(r['n'], 4)
+ r = update('test_table *', dict(n=5), oid=oid)
+ self.assertEqual(r['n'], 5)
+ query("alter table test_table add column m int")
+ query("alter table test_table add primary key (n)")
+ self.assertIn('m', self.db.get_attnames('test_table', flush=True))
+ self.assertEqual('n', self.db.pkey('test_table', flush=True))
+ s = dict(n=1, m=4)
+ r = update('test_table', s)
+ self.assertIs(r, s)
+ self.assertEqual(r['n'], 1)
+ self.assertEqual(r['m'], 4)
+ s = dict(m=7)
+ r = update('test_table', s, n=5)
+ self.assertIs(r, s)
+ self.assertEqual(r['n'], 5)
+ self.assertEqual(r['m'], 7)
+ q = "select n, m from test_table order by 1 limit 3"
+ r = query(q).getresult()
+ self.assertEqual(r, [(1, 4), (5, 7)])
+ s = dict(m=9, oid=oid)
+ self.assertRaises(KeyError, update, 'test_table', s)
+ r = update('test_table', s, oid=oid)
+ self.assertIs(r, s)
+ self.assertEqual(r['n'], 5)
+ self.assertEqual(r['m'], 9)
+ s = dict(n=1, m=3, oid=oid)
+ r = update('test_table', s)
+ self.assertIs(r, s)
+ self.assertEqual(r['n'], 1)
+ self.assertEqual(r['m'], 3)
+ r = query(q).getresult()
+ self.assertEqual(r, [(1, 3), (5, 9)])
def testUpdateWithCompositeKey(self):
update = self.db.update
@@ -1303,8 +1504,7 @@
for n, t in enumerate('abc'):
query('insert into "%s" values('
"%d, '%s')" % (table, n + 1, t))
- self.assertRaises(pg.ProgrammingError, update,
- table, dict(t='b'))
+ self.assertRaises(KeyError, update, table, dict(t='b'))
s = dict(n=2, t='d')
r = update(table, s)
self.assertIs(r, s)
@@ -1333,10 +1533,9 @@
t = chr(ord('a') + 2 * n + m)
query('insert into "%s" values('
"%d, %d, '%s')" % (table, n + 1, m + 1, t))
- self.assertRaises(pg.ProgrammingError, update,
- table, dict(n=2, t='b'))
+ self.assertRaises(KeyError, update, table, dict(n=2, t='b'))
self.assertEqual(update(table,
- dict(n=2, m=2, t='x'))['t'], 'x')
+ dict(n=2, m=2, t='x'))['t'], 'x')
q = 'select t from "%s" where n=2 order by m' % table
r = [r[0] for r in query(q).getresult()]
self.assertEqual(r, ['c', 'x'])
@@ -1440,6 +1639,88 @@
r = upsert(table, s, oid='invalid')
self.assertIs(r, s)
+ def testUpsertWithOid(self):
+ upsert = self.db.upsert
+ get = self.db.get
+ query = self.db.query
+ query("drop table if exists test_table")
+ self.addCleanup(query, "drop table test_table")
+ query("create table test_table (n int) with oids")
+ query("insert into test_table values (1)")
+ self.assertRaises(pg.ProgrammingError,
+ upsert, 'test_table', dict(n=2))
+ r = get('test_table', 1, 'n')
+ self.assertIsInstance(r, dict)
+ self.assertEqual(r['n'], 1)
+ qoid = 'oid(test_table)'
+ self.assertIn(qoid, r)
+ self.assertNotIn('oid', r)
+ oid = r[qoid]
+ self.assertRaises(pg.ProgrammingError,
+ upsert, 'test_table', dict(n=2, oid=oid))
+ query("alter table test_table add column m int")
+ query("alter table test_table add primary key (n)")
+ self.assertIn('m', self.db.get_attnames('test_table', flush=True))
+ self.assertEqual('n', self.db.pkey('test_table', flush=True))
+ s = dict(n=2)
+ r = upsert('test_table', s)
+ self.assertIs(r, s)
+ self.assertEqual(r['n'], 2)
+ self.assertIsNone(r['m'])
+ q = query("select n, m from test_table order by n limit 3")
+ self.assertEqual(q.getresult(), [(1, None), (2, None)])
+ r['oid'] = oid
+ r = upsert('test_table', r)
+ self.assertIs(r, s)
+ self.assertEqual(r['n'], 2)
+ self.assertIsNone(r['m'])
+ self.assertIn(qoid, r)
+ self.assertNotIn('oid', r)
+ self.assertNotEqual(r[qoid], oid)
+ r['m'] = 7
+ r = upsert('test_table', r)
+ self.assertIs(r, s)
+ self.assertEqual(r['n'], 2)
+ self.assertEqual(r['m'], 7)
+ r.update(n=1, m=3)
+ r = upsert('test_table', r)
+ self.assertIs(r, s)
+ self.assertEqual(r['n'], 1)
+ self.assertEqual(r['m'], 3)
+ q = query("select n, m from test_table order by n limit 3")
+ self.assertEqual(q.getresult(), [(1, 3), (2, 7)])
+ r = upsert('test_table', r, oid='invalid')
+ self.assertIs(r, s)
+ self.assertEqual(r['n'], 1)
+ self.assertEqual(r['m'], 3)
+ r['m'] = 5
+ r = upsert('test_table', r, m=False)
+ self.assertIs(r, s)
+ self.assertEqual(r['n'], 1)
+ self.assertEqual(r['m'], 3)
+ r['m'] = 5
+ r = upsert('test_table', r, m=True)
+ self.assertIs(r, s)
+ self.assertEqual(r['n'], 1)
+ self.assertEqual(r['m'], 5)
+ r.update(n=2, m=1)
+ r = upsert('test_table', r, m='included.m')
+ self.assertIs(r, s)
+ self.assertEqual(r['n'], 2)
+ self.assertEqual(r['m'], 7)
+ r['m'] = 9
+ r = upsert('test_table', r, m='excluded.m')
+ self.assertIs(r, s)
+ self.assertEqual(r['n'], 2)
+ self.assertEqual(r['m'], 9)
+ r['m'] = 8
+ r = upsert('test_table *', r, m='included.m + 1')
+ self.assertIs(r, s)
+ self.assertEqual(r['n'], 2)
+ self.assertEqual(r['m'], 10)
+ q = query("select n, m from test_table order by n limit 3")
+ self.assertEqual(q.getresult(), [(1, 5), (2, 10)])
+
def testUpsertWithCompositeKey(self):
upsert = self.db.upsert
query = self.db.query
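The upsert assertions in testUpsertWithOid pin down three behaviors for a non-key column. m=False keeps the stored value. m=True takes the new value, which is also what happens when no keyword is given (see the r['m'] = 7 step). A string such as 'included.m + 1' is evaluated as SQL, with "included" naming the existing row and "excluded" naming the proposed row. Those behaviors can be modeled as SQL generation. Below is a minimal sketch, assuming PostgreSQL 9.5+ `INSERT ... ON CONFLICT` with the target table aliased as `included`. `build_upsert_sql` is a hypothetical helper, not PyGreSQL's implementation. It also does not quote identifiers, so it is unsafe for untrusted names.

```python
def build_upsert_sql(table, row, keyname, **updates):
    """Hypothetical: build INSERT ... ON CONFLICT DO UPDATE for one row dict.

    updates maps a column to True (take the new value), False (keep the
    stored value) or a string SQL expression; unlisted columns default to True.
    """
    if isinstance(keyname, str):
        keyname = (keyname,)
    cols = list(row)
    assignments = []
    for col in cols:
        if col in keyname:
            continue  # key columns identify the row and are never updated
        how = updates.get(col, True)
        if how is True:
            # "excluded" is PostgreSQL's name for the proposed row
            assignments.append('%s = excluded.%s' % (col, col))
        elif isinstance(how, str) and how:
            # free SQL; "included" is the alias of the existing row
            assignments.append('%s = %s' % (col, how))
        # False (or None): the existing value is kept
    if assignments:
        conflict = 'do update set ' + ', '.join(assignments)
    else:
        conflict = 'do nothing'
    sql = ('insert into %s as included (%s) values (%s) on conflict (%s) %s'
           % (table, ', '.join(cols), ', '.join(['%s'] * len(cols)),
              ', '.join(keyname), conflict))
    # parameters in DB-API "format" style, in column order
    return sql, [row[col] for col in cols]
```

For example, `build_upsert_sql('test_table', {'n': 2, 'm': 8}, 'n', m='included.m + 1')` yields an update clause `set m = included.m + 1`. This matches the last upsert in the test, where a stored m of 9 becomes 10.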
@@ -1625,40 +1906,91 @@
query("drop table if exists test_table")
self.addCleanup(query, "drop table test_table")
query("create table test_table (n int) with oids")
- query("insert into test_table values (1)")
- query("insert into test_table values (2)")
- query("insert into test_table values (3)")
+ for i in range(6):
+ query("insert into test_table values (%d)" % (i + 1))
r = dict(n=3)
self.assertRaises(pg.ProgrammingError, delete, 'test_table', r)
- r = get('test_table', 1, 'n')
- self.assertIsInstance(r, dict)
- self.assertEqual(r['n'], 1)
+ s = get('test_table', 1, 'n')
qoid = 'oid(test_table)'
- self.assertIn(qoid, r)
- oid = r[qoid]
- self.assertIsInstance(oid, int)
- s = delete('test_table', r)
- self.assertEqual(s, 1)
- s = delete('test_table', r)
- self.assertEqual(s, 0)
- r = get('test_table', 2, 'n')
- self.assertIsInstance(r, dict)
- self.assertEqual(r['n'], 2)
- qoid = 'oid(test_table)'
- self.assertIn(qoid, r)
- oid = r[qoid]
- self.assertIsInstance(oid, int)
- r['oid'] = r.pop(qoid)
- self.assertRaises(pg.ProgrammingError, delete, 'test_table', r)
- s = delete('test_table', r, oid=oid)
- self.assertEqual(s, 1)
- s = delete('test_table', r)
- self.assertEqual(s, 0)
- s = delete('test_table', r, n=3)
- self.assertEqual(s, 0)
- q = 'select n from test_table order by 1 limit 3'
- r = query(q).getresult()
- self.assertEqual(r, [(3,)])
+ self.assertIn(qoid, s)
+ r = delete('test_table', s)
+ self.assertEqual(r, 1)
+ r = delete('test_table', s)
+ self.assertEqual(r, 0)
+ q = "select min(n),count(n) from test_table"
+ self.assertEqual(query(q).getresult()[0], (2, 5))
+ oid = get('test_table', 2, 'n')[qoid]
+ s = dict(oid=oid, n=2)
+ self.assertRaises(pg.ProgrammingError, delete, 'test_table', s)
+ r = delete('test_table', None, oid=oid)
+ self.assertEqual(r, 1)
+ r = delete('test_table', None, oid=oid)
+ self.assertEqual(r, 0)
+ self.assertEqual(query(q).getresult()[0], (3, 4))
+ s = dict(oid=oid, n=2)
+ oid = get('test_table', 3, 'n')[qoid]
+ self.assertRaises(pg.ProgrammingError, delete, 'test_table', s)
+ r = delete('test_table', s, oid=oid)
+ self.assertEqual(r, 1)
+ r = delete('test_table', s, oid=oid)
+ self.assertEqual(r, 0)
+ self.assertEqual(query(q).getresult()[0], (4, 3))
+ s = get('test_table', 4, 'n')
+ r = delete('test_table *', s)
+ self.assertEqual(r, 1)
+ r = delete('test_table *', s)
+ self.assertEqual(r, 0)
+ self.assertEqual(query(q).getresult()[0], (5, 2))
+ oid = get('test_table', 5, 'n')[qoid]
+ s = {qoid: oid, 'm': 4}
+ r = delete('test_table', s, m=6)
+ self.assertEqual(r, 1)
+ r = delete('test_table *', s)
+ self.assertEqual(r, 0)
+ self.assertEqual(query(q).getresult()[0], (6, 1))
+ query("alter table test_table add column m int")
+ query("alter table test_table add primary key (n)")
+ self.assertIn('m', self.db.get_attnames('test_table', flush=True))
+ self.assertEqual('n', self.db.pkey('test_table', flush=True))
+ for i in range(5):
+ query("insert into test_table values (%d, %d)" % (i + 1, i + 2))
+ s = dict(m=2)
+ self.assertRaises(KeyError, delete, 'test_table', s)
+ s = dict(m=2, oid=oid)
+ self.assertRaises(KeyError, delete, 'test_table', s)
+ r = delete('test_table', dict(m=2), oid=oid)
+ self.assertEqual(r, 0)
+ oid = get('test_table', 1, 'n')[qoid]
+ s = dict(oid=oid)
+ self.assertRaises(KeyError, delete, 'test_table', s)
+ r = delete('test_table', s, oid=oid)
+ self.assertEqual(r, 1)
+ r = delete('test_table', s, oid=oid)
+ self.assertEqual(r, 0)
+ self.assertEqual(query(q).getresult()[0], (2, 5))
+ s = get('test_table', 2, 'n')
+ del s['n']
+ r = delete('test_table', s)
+ self.assertEqual(r, 1)
+ r = delete('test_table', s)
+ self.assertEqual(r, 0)
+ self.assertEqual(query(q).getresult()[0], (3, 4))
+ r = delete('test_table', n=3)
+ self.assertEqual(r, 1)
+ r = delete('test_table', n=3)
+ self.assertEqual(r, 0)
+ self.assertEqual(query(q).getresult()[0], (4, 3))
+ r = delete('test_table', None, n=4)
+ self.assertEqual(r, 1)
+ r = delete('test_table', None, n=4)
+ self.assertEqual(r, 0)
+ self.assertEqual(query(q).getresult()[0], (5, 2))
+ s = dict(n=6)
+ r = delete('test_table', s, n=5)
+ self.assertEqual(r, 1)
+ r = delete('test_table', s, n=5)
+ self.assertEqual(r, 0)
+ self.assertEqual(query(q).getresult()[0], (6, 1))
def testDeleteWithCompositeKey(self):
query = self.db.query
@@ -1670,15 +2002,12 @@
for n, t in enumerate('abc'):
query("insert into %s values("
"%d, '%s')" % (table, n + 1, t))
- self.assertRaises(pg.ProgrammingError, self.db.delete,
- table, dict(t='b'))
+ self.assertRaises(KeyError, self.db.delete, table, dict(t='b'))
self.assertEqual(self.db.delete(table, dict(n=2)), 1)
- r = query('select t from "%s" where n=2' % table
- ).getresult()
+ r = query('select t from "%s" where n=2' % table).getresult()
self.assertEqual(r, [])
self.assertEqual(self.db.delete(table, dict(n=2)), 0)
- r = query('select t from "%s" where n=3' % table
- ).getresult()[0][0]
+ r = query('select t from "%s" where n=3' % table).getresult()[0][0]
self.assertEqual(r, 'c')
table = 'delete_test_table_2'
query('drop table if exists "%s"' % table)
@@ -1690,8 +2019,7 @@
t = chr(ord('a') + 2 * n + m)
query('insert into "%s" values('
"%d, %d, '%s')" % (table, n + 1, m + 1, t))
- self.assertRaises(pg.ProgrammingError, self.db.delete,
- table, dict(n=2, t='b'))
+ self.assertRaises(KeyError, self.db.delete, table, dict(n=2, t='b'))
self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
r = [r[0] for r in query('select t from "%s" where n=2'
' order by m' % table).getresult()]
@@ -1727,6 +2055,49 @@
r = query('select count(*) from "%s"' % table).getresult()
self.assertEqual(r[0][0], 0)
+ def testDeleteReferenced(self):
+ delete = self.db.delete
+ query = self.db.query
+ query("drop table if exists test_child")
+ query("drop table if exists test_parent")
+ self.addCleanup(query, "drop table test_parent")
+ query("create table test_parent (n smallint primary key)")
+ self.addCleanup(query, "drop table test_child")
+ query("create table test_child ("
+ " n smallint primary key references test_parent (n))")
+ for n in range(3):
+ query("insert into test_parent (n) values (%d)" % n)
+ query("insert into test_child (n) values (%d)" % n)
+ q = ("select (select count(*) from test_parent),"
+ " (select count(*) from test_child)")
+ self.assertEqual(query(q).getresult()[0], (3, 3))
+ self.assertRaises(pg.ProgrammingError,
+ delete, 'test_parent', None, n=2)
+ self.assertRaises(pg.ProgrammingError,
+ delete, 'test_parent *', None, n=2)
+ r = delete('test_child', None, n=2)
+ self.assertEqual(r, 1)
+ self.assertEqual(query(q).getresult()[0], (3, 2))
+ r = delete('test_parent', None, n=2)
+ self.assertEqual(r, 1)
+ self.assertEqual(query(q).getresult()[0], (2, 2))
+ self.assertRaises(pg.ProgrammingError,
+ delete, 'test_parent', dict(n=0))
+ self.assertRaises(pg.ProgrammingError,
+ delete, 'test_parent *', dict(n=0))
+ r = delete('test_child', dict(n=0))
+ self.assertEqual(r, 1)
+ self.assertEqual(query(q).getresult()[0], (2, 1))
+ r = delete('test_child', dict(n=0))
+ self.assertEqual(r, 0)
+ r = delete('test_parent', dict(n=0))
+ self.assertEqual(r, 1)
+ self.assertEqual(query(q).getresult()[0], (1, 1))
+ r = delete('test_parent', None, n=0)
+ self.assertEqual(r, 0)
+ q = "select n from test_parent natural join test_child limit 2"
+ self.assertEqual(query(q).getresult(), [(1,)])
+
def testTruncate(self):
truncate = self.db.truncate
self.assertRaises(TypeError, truncate, None)
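The get() tests in this changeset pin down how the row to fetch is identified. An explicit keyname always wins. Otherwise, a dict is matched on the primary key when all of its columns are present, and falls back to the OID (given as 'oid' or as the qualified 'oid(table)' name) when they are not. Scalars and tuples are zipped with the key names in order, so a scalar against a composite key raises KeyError. Below is a plain-Python model of that decision. `where_columns` is hypothetical. It raises ValueError where the real tests expect pg.ProgrammingError. It also does not model update() and delete(), which in these tests reject a plain 'oid' dict entry on a table that has a primary key.

```python
def where_columns(row, keyname=None, pkey=None, qoid=None):
    """Hypothetical model: return {column: value} identifying one row.

    row     -- a scalar, a tuple/list of values, or a dict
    keyname -- explicit key column name(s), or None to choose automatically
    pkey    -- primary key: one name, or a tuple in index order, or None
    qoid    -- qualified OID name such as 'oid(test_table)', or None if
               the table has no OIDs
    """
    if isinstance(pkey, str):
        pkey = (pkey,)  # a single-column key may be given as a plain name
    if keyname is None:
        if isinstance(row, dict):
            if pkey and all(col in row for col in pkey):
                keyname = pkey  # the primary key takes precedence over the OID
            elif qoid and (qoid in row or 'oid' in row):
                return {'oid': row[qoid] if qoid in row else row['oid']}
        if keyname is None:
            if not pkey:
                # stands in for pg.ProgrammingError in the real interface
                raise ValueError('no primary key; pass a keyname')
            keyname = pkey
    if isinstance(keyname, str):
        keyname = (keyname,)
    keyname = tuple(keyname)
    if isinstance(row, dict):
        if keyname == ('oid',) and qoid and qoid in row:
            return {'oid': row[qoid]}  # accept the qualified OID name too
        missing = [col for col in keyname if col not in row]
        if missing:
            raise KeyError(missing[0])
        return {col: row[col] for col in keyname}
    if not isinstance(row, (tuple, list)):
        row = (row,)
    if len(row) != len(keyname):
        raise KeyError('expected %d key values, got %d'
                       % (len(keyname), len(row)))
    return dict(zip(keyname, row))  # values follow the key name order
```

Zipping tuples with the key names is why the changelog's switch from frozensets to ordered tuples matters. With a composite key ('n', 'm'), the tuple (1, 2) only has a meaning once the column order is fixed.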
_______________________________________________
PyGreSQL mailing list
https://mail.vex.net/mailman/listinfo.cgi/pygresql