Update of /usr/cvs/Public/pygresql/module
In directory druid.net:/tmp/cvs-serv11982/module
Modified Files:
pg.py test_pg.py
Log Message:
The insert() and update() methods now check if the table is selectable and use
the "returning" clause if possible. The delete() method now also works with
primary keys and returns the number of deleted rows (0 if the row did not exist).
To see the diffs for this commit:
http://www.druid.net/pygresql/viewcvs.cgi/cvs/pygresql/module/pg.py.diff?r1=1.75&r2=1.76
Index: pg.py
===================================================================
RCS file: /usr/cvs/Public/pygresql/module/pg.py,v
retrieving revision 1.75
retrieving revision 1.76
diff -u -r1.75 -r1.76
--- pg.py 5 Dec 2008 02:08:15 -0000 1.75
+++ pg.py 5 Dec 2008 15:00:19 -0000 1.76
@@ -5,7 +5,7 @@
# Written by D'Arcy J.M. Cain
# Improved by Christoph Zwerschke
#
-# $Id: pg.py,v 1.75 2008/12/05 02:08:15 cito Exp $
+# $Id: pg.py,v 1.76 2008/12/05 15:00:19 cito Exp $
#
"""PyGreSQL classic interface.
@@ -127,6 +127,7 @@
self.dbname = db.db
self._attnames = {}
self._pkeys = {}
+ self._privileges = {}
self._args = args, kw
self.debug = None # For debugging scripts, this can be set
# * to a string format specification (e.g. in CGI set to "%s<BR>"),
@@ -226,19 +227,19 @@
else:
cl = s[0]
# determine search path
- query = 'SELECT current_schemas(TRUE)'
- schemas = self.db.query(query).getresult()[0][0][1:-1].split(',')
+ q = 'SELECT current_schemas(TRUE)'
+ schemas = self.db.query(q).getresult()[0][0][1:-1].split(',')
if schemas: # non-empty path
# search schema for this object in the current search path
- query = ' UNION '.join(
+ q = ' UNION '.join(
["SELECT %d::integer AS n, '%s'::name AS nspname"
% s for s in enumerate(schemas)])
- query = ("SELECT nspname FROM pg_class"
+ q = ("SELECT nspname FROM pg_class"
" JOIN pg_namespace ON pg_class.relnamespace = pg_namespace.oid"
" JOIN (%s) AS p USING (nspname)"
" WHERE pg_class.relname = '%s'"
- " ORDER BY n LIMIT 1" % (query, cl))
- schema = self.db.query(query).getresult()
+ " ORDER BY n LIMIT 1" % (q, cl))
+ schema = self.db.query(q).getresult()
if schema: # schema found
schema = schema[0][0]
else: # object not found in current search path
@@ -294,11 +295,15 @@
"""Executes a SQL command string.
This method simply sends a SQL query to the database. If the query is
- an insert statement, the return value is the OID of the newly
- inserted row. If it is otherwise a query that does not return a result
- (ie. is not a some kind of SELECT statement), it returns None.
- Otherwise, it returns a pgqueryobject that can be accessed via the
- getresult or dictresult method or simply printed.
+ an insert statement that inserted exactly one row into a table that
+ has OIDs, the return value is the OID of the newly inserted row.
+ If the query is an update or delete statement, or an insert statement
+ that did not insert exactly one row in a table with OIDs, then the
+ number of rows affected is returned as a string. If it is a statement
+ that returns rows as a result (usually a select statement, but maybe
+ also an "insert/update ... returning" statement), this method returns
+ a pgqueryobject that can be accessed via getresult() or dictresult()
+ or simply printed. Otherwise, it returns `None`.
"""
# Wraps shared library function for debugging.
@@ -446,6 +451,18 @@
self._attnames[qcl] = t # cache it
return self._attnames[qcl]
+ def has_table_privilege(self, cl, privilege='select'):
+ """Check whether current user has specified table privilege."""
+ qcl = self._add_schema(cl)
+ privilege = privilege.lower()
+ try:
+ return self._privileges[(qcl, privilege)]
+ except KeyError:
+ q = "SELECT has_table_privilege('%s', '%s')" % (qcl, privilege)
+ ret = self.db.query(q).getresult()[0][0] == 't'
+ self._privileges[(qcl, privilege)] = ret
+ return ret
+
def get(self, cl, arg, keyname=None):
"""Get a tuple from a database table or view.
@@ -502,27 +519,25 @@
arg[att == 'oid' and qoid or att] = value
return arg
- def insert(self, cl, d=None, return_changes=True, **kw):
+ def insert(self, cl, d=None, **kw):
"""Insert a tuple into a database table.
This method inserts a row into a table. If a dictionary is
supplied it starts with that. Otherwise it uses a blank dictionary.
Either way the dictionary is updated from the keywords.
- The dictionary is then reloaded with the values actually inserted
- in order to pick up values modified by rules, triggers, etc. If
- the optional flag return_changes is set to False this reload will
- be skipped.
+ The dictionary is then, if possible, reloaded with the values actually
+ inserted in order to pick up values modified by rules, triggers, etc.
Note: The method currently doesn't support insert into views
although PostgreSQL does.
"""
+ qcl = self._add_schema(cl)
+ qoid = _oid_key(qcl)
if d is None:
d = {}
d.update(kw)
- qcl = self._add_schema(cl)
- qoid = _oid_key(qcl)
attnames = self.get_attnames(qcl)
names, values = [], []
for n in attnames:
@@ -530,44 +545,62 @@
names.append('"%s"' % n)
values.append(self._quote(d[n], attnames[n]))
names, values = ', '.join(names), ', '.join(values)
- q = 'INSERT INTO %s (%s) VALUES (%s)' % (qcl, names, values)
+ selectable = self.has_table_privilege(qcl)
+ if selectable and self.server_version >= 80200:
+ ret = ' RETURNING %s*' % ('oid' in attnames and 'oid, ' or '')
+ else:
+ ret = ''
+ q = 'INSERT INTO %s (%s) VALUES (%s)%s' % (qcl, names, values, ret)
self._do_debug(q)
- d[qoid] = self.db.query(q)
- # Reload the dictionary to catch things modified by engine.
- # Note that get() changes 'oid' below to oid(schema.table).
- if return_changes:
- self.get(qcl, d, 'oid')
+ res = self.db.query(q)
+ if ret:
+ res = res.dictresult()
+ for att, value in res[0].iteritems():
+ d[att == 'oid' and qoid or att] = value
+ elif isinstance(res, int):
+ d[qoid] = res
+ if selectable:
+ self.get(qcl, d, 'oid')
+ elif selectable:
+ if qoid in d:
+ self.get(qcl, d, 'oid')
+ else:
+ try:
+ self.get(qcl, d)
+ except ProgrammingError:
+ pass # table has no primary key
return d
def update(self, cl, d=None, **kw):
"""Update an existing row in a database table.
Similar to insert but updates an existing row. The update is based
- on the OID value as munged by get. The array returned is the
- one sent modified to reflect any changes caused by the update due
- to triggers, rules, defaults, etc.
+ on the OID value as munged by get or passed as keyword, or on the
+ primary key of the table. The dictionary is modified, if possible,
+ to reflect any changes caused by the update due to triggers, rules,
+ default values, etc.
"""
# Update always works on the oid which get returns if available,
# otherwise use the primary key. Fail if neither.
# Note that we only accept oid key from named args for safety
+ qcl = self._add_schema(cl)
+ qoid = _oid_key(qcl)
if 'oid' in kw:
kw[qoid] = kw['oid']
del kw['oid']
if d is None:
d = {}
d.update(kw)
- qcl = self._add_schema(cl)
- qoid = _oid_key(qcl)
attnames = self.get_attnames(qcl)
if qoid in d:
where = 'oid = %s' % d[qoid]
keyname = ()
else:
try:
- keyname = self.pkey(qcl)
+ keyname = self.pkey(qcl)
except KeyError:
- raise ProgrammingError('Class %s has no primary key' % qcl)
+ raise ProgrammingError('Class %s has no primary key' % qcl)
if isinstance(keyname, basestring):
keyname = (keyname,)
try:
@@ -582,14 +615,26 @@
if not values:
return d
values = ', '.join(values)
- q = 'UPDATE %s SET %s WHERE %s' % (qcl, values, where)
+ selectable = self.has_table_privilege(qcl)
+ if selectable and self.server_version >= 80200:
+ ret = ' RETURNING %s*' % ('oid' in attnames and 'oid, ' or '')
+ else:
+ ret = ''
+ q = 'UPDATE %s SET %s WHERE %s%s' % (qcl, values, where, ret)
self._do_debug(q)
- self.db.query(q)
- # Reload the dictionary to catch things modified by engine:
- if qoid in d:
- return self.get(qcl, d, 'oid')
+ res = self.db.query(q)
+ if ret:
+ res = res.dictresult()
+ for att, value in res[0].iteritems():
+ d[att == 'oid' and qoid or att] = value
else:
- return self.get(qcl, d)
+ # update already executed above; only reload the row to catch engine changes
+ if selectable:
+ if qoid in d:
+ self.get(qcl, d, 'oid')
+ else:
+ self.get(qcl, d)
+ return d
def clear(self, cl, a=None):
"""
@@ -602,9 +647,9 @@
"""
# At some point we will need a way to get defaults from a table.
+ qcl = self._add_schema(cl)
if a is None:
a = {} # empty if argument is not present
- qcl = self._add_schema(cl)
attnames = self.get_attnames(qcl)
for n, t in attnames.iteritems():
if n == 'oid':
@@ -620,25 +665,42 @@
def delete(self, cl, d=None, **kw):
"""Delete an existing row in a database table.
- This method deletes the row from a table.
- It deletes based on the OID munged as described above.
+ This method deletes the row from a table. It deletes based on the
+ OID value as munged by get or passed as keyword, or on the primary
+ key of the table. The return value is the number of deleted rows
+ (i.e. 0 if the row did not exist and 1 if the row was deleted).
"""
# Like update, delete works on the oid.
# One day we will be testing that the record to be deleted
# isn't referenced somewhere (or else PostgreSQL will).
# Note that we only accept oid key from named args for safety
+ qcl = self._add_schema(cl)
+ qoid = _oid_key(qcl)
if 'oid' in kw:
kw[qoid] = kw['oid']
del kw['oid']
if d is None:
d = {}
d.update(kw)
- qcl = self._add_schema(cl)
- qoid = _oid_key(qcl)
- q = 'DELETE FROM %s WHERE oid=%s' % (qcl, d[qoid])
+ if qoid in d:
+ where = 'oid = %s' % d[qoid]
+ else:
+ try:
+ keyname = self.pkey(qcl)
+ except KeyError:
+ raise ProgrammingError('Class %s has no primary key' % qcl)
+ if isinstance(keyname, basestring):
+ keyname = (keyname,)
+ attnames = self.get_attnames(qcl)
+ try:
+ where = ' AND '.join(['%s = %s'
+ % (k, self._quote(d[k], attnames[k])) for k in keyname])
+ except KeyError:
+ raise ProgrammingError('Delete needs primary key or oid.')
+ q = 'DELETE FROM %s WHERE %s' % (qcl, where)
self._do_debug(q)
- self.db.query(q)
+ return int(self.db.query(q))
# if run as script, print some information
http://www.druid.net/pygresql/viewcvs.cgi/cvs/pygresql/module/test_pg.py.diff?r1=1.26&r2=1.27
Index: test_pg.py
===================================================================
RCS file: /usr/cvs/Public/pygresql/module/test_pg.py,v
retrieving revision 1.26
retrieving revision 1.27
diff -u -r1.26 -r1.27
--- test_pg.py 5 Dec 2008 02:05:28 -0000 1.26
+++ test_pg.py 5 Dec 2008 15:00:19 -0000 1.27
@@ -4,7 +4,7 @@
#
# Written by Christoph Zwerschke
#
-# $Id: test_pg.py,v 1.26 2008/12/05 02:05:28 cito Exp $
+# $Id: test_pg.py,v 1.27 2008/12/05 15:00:19 cito Exp $
#
"""Test the classic PyGreSQL interface in the pg module.
@@ -12,15 +12,16 @@
The testing is done against a real local PostgreSQL database.
There are a few drawbacks:
-* A local PostgreSQL database must be up and running, and
-the user who is running the tests must be a trusted superuser.
-* The performance of the API is not tested.
-* Connecting to a remote host is not tested.
-* Passing user, password and options is not tested.
-* Status and error messages from the connection are not tested.
-* It would be more reasonable to create a test for the underlying
-shared library functions in the _pg module and assume they are ok.
-The pg and pgdb modules should be tested against _pg mock functions.
+ * A local PostgreSQL database must be up and running, and
+ the user who is running the tests must be a trusted superuser.
+ * The performance of the API is not tested.
+ * Connecting to a remote host is not tested.
+ * Passing user, password and options is not tested.
+ * Table privilege problems (e.g. insert but no select) are not tested.
+ * Status and error messages from the connection are not tested.
+ * It would be more reasonable to create a test for the underlying
+ shared library functions in the _pg module and assume they are ok.
+ The pg and pgdb modules should be tested against _pg mock functions.
"""
@@ -355,17 +356,16 @@
self.connection.close()
def testAllConnectAttributes(self):
- attributes = ['db', 'error', 'host', 'options', 'port',
- 'protocol_version', 'server_version', 'status', 'tty', 'user']
+ attributes = '''db error host options port
+ protocol_version server_version status tty user'''.split()
connection_attributes = [a for a in dir(self.connection)
if not callable(eval("self.connection." + a))]
self.assertEqual(attributes, connection_attributes)
def testAllConnectMethods(self):
- methods = ['cancel', 'close', 'endcopy', 'escape_bytea',
- 'escape_string', 'fileno', 'getline', 'getlo', 'getnotify',
- 'inserttable', 'locreate', 'loimport', 'parameter', 'putline',
- 'query', 'reset', 'source', 'transaction']
+ methods = '''cancel close endcopy escape_bytea escape_string fileno
+ getline getlo getnotify inserttable locreate loimport parameter
+ putline query reset source transaction'''.split()
connection_methods = [a for a in dir(self.connection)
if callable(eval("self.connection." + a))]
self.assertEqual(methods, connection_methods)
@@ -442,7 +442,7 @@
""""Test simple queries via a basic pg connection."""
def setUp(self):
- dbname = 'template1'
+ dbname = 'test'
self.c = pg.connect(dbname)
def tearDown(self):
@@ -692,14 +692,13 @@
self.db.close()
def testAllDBAttributes(self):
- attributes = ['cancel', 'clear', 'close', 'db', 'dbname', 'debug',
- 'delete', 'endcopy', 'error', 'escape_bytea', 'escape_string',
- 'fileno', 'get', 'get_attnames', 'get_databases', 'get_relations',
- 'get_tables', 'getline', 'getlo', 'getnotify', 'host', 'insert',
- 'inserttable', 'locreate', 'loimport', 'options', 'parameter',
- 'pkey', 'port', 'protocol_version', 'putline', 'query', 'reopen',
- 'reset', 'server_version', 'source', 'status', 'transaction',
- 'tty', 'unescape_bytea', 'update', 'user']
+ attributes = '''cancel clear close db dbname debug delete endcopy
+ error escape_bytea escape_string fileno get get_attnames
+ get_databases get_relations get_tables getline getlo getnotify
+ has_table_privilege host insert inserttable locreate loimport
+ options parameter pkey port protocol_version putline query
+ reopen reset server_version source status transaction tty
+ unescape_bytea update user'''.split()
db_attributes = [a for a in dir(self.db)
if not a.startswith('_')]
self.assertEqual(attributes, db_attributes)
@@ -1060,6 +1059,18 @@
'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int' }
self.assertEqual(attributes, result)
+ def testHasTablePrivilege(self):
+ can = self.db.has_table_privilege
+ self.assertEqual(can('test'), True)
+ self.assertEqual(can('test', 'select'), True)
+ self.assertEqual(can('test', 'SeLeCt'), True)
+ self.assertEqual(can('test', 'SELECT'), True)
+ self.assertEqual(can('test', 'insert'), True)
+ self.assertEqual(can('test', 'update'), True)
+ self.assertEqual(can('test', 'delete'), True)
+ self.assertRaises(pg.ProgrammingError, can, 'test', 'foobar')
+ self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist')
+
def testGet(self):
for table in ('get_test_table', 'test table for get'):
smart_ddl(self.db, 'drop table "%s"' % table)
@@ -1178,7 +1189,7 @@
self.assertEqual(r, 'u')
def testUpdateWithCompositeKey(self):
- table = 'get_update_table_1'
+ table = 'update_test_table_1'
smart_ddl(self.db, "drop table %s" % table)
smart_ddl(self.db, "create table %s ("
"n integer, t text, primary key (n))" % table)
@@ -1191,7 +1202,7 @@
r = self.db.query('select t from "%s" where n=2' % table
).getresult()[0][0]
self.assertEqual(r, 'd')
- table = 'get_test_update_2'
+ table = 'update_test_table_2'
smart_ddl(self.db, "drop table %s" % table)
smart_ddl(self.db, "create table %s ("
"n integer, m integer, t text, primary key (n, m))" % table)
@@ -1235,8 +1246,12 @@
self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
r = self.db.get(table, 1, 'n')
s = self.db.delete(table, r)
+ self.assertEqual(s, 1)
r = self.db.get(table, 3, 'n')
s = self.db.delete(table, r)
+ self.assertEqual(s, 1)
+ s = self.db.delete(table, r)
+ self.assertEqual(s, 0)
r = self.db.query('select * from "%s"' % table).dictresult()
self.assertEqual(len(r), 1)
r = r[0]
@@ -1244,8 +1259,53 @@
self.assertEqual(r, result)
r = self.db.get(table, 2, 'n')
s = self.db.delete(table, r)
+ self.assertEqual(s, 1)
+ s = self.db.delete(table, r)
+ self.assertEqual(s, 0)
self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
+ def testDeleteWithCompositeKey(self):
+ table = 'delete_test_table_1'
+ smart_ddl(self.db, "drop table %s" % table)
+ smart_ddl(self.db, "create table %s ("
+ "n integer, t text, primary key (n))" % table)
+ for n, t in enumerate('abc'):
+ self.db.query("insert into %s values("
+ "%d, '%s')" % (table, n+1, t))
+ self.assertRaises(pg.ProgrammingError, self.db.delete,
+ table, dict(t='b'))
+ self.assertEqual(self.db.delete(table, dict(n=2)), 1)
+ r = self.db.query('select t from "%s" where n=2' % table
+ ).getresult()
+ self.assertEqual(r, [])
+ self.assertEqual(self.db.delete(table, dict(n=2)), 0)
+ r = self.db.query('select t from "%s" where n=3' % table
+ ).getresult()[0][0]
+ self.assertEqual(r, 'c')
+ table = 'delete_test_table_2'
+ smart_ddl(self.db, "drop table %s" % table)
+ smart_ddl(self.db, "create table %s ("
+ "n integer, m integer, t text, primary key (n, m))" % table)
+ for n in range(3):
+ for m in range(2):
+ t = chr(ord('a') + 2*n +m)
+ self.db.query("insert into %s values("
+ "%d, %d, '%s')" % (table, n+1, m+1, t))
+ self.assertRaises(pg.ProgrammingError, self.db.delete,
+ table, dict(n=2, t='b'))
+ self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
+ r = [r[0] for r in self.db.query('select t from "%s" where n=2'
+ ' order by m' % table).getresult()]
+ self.assertEqual(r, ['c'])
+ self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
+ r = [r[0] for r in self.db.query('select t from "%s" where n=3'
+ ' order by m' % table).getresult()]
+ self.assertEqual(r, ['e', 'f'])
+ self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
+ r = [r[0] for r in self.db.query('select t from "%s" where n=3'
+ ' order by m' % table).getresult()]
+ self.assertEqual(r, ['f'])
+
def testBytea(self):
smart_ddl(self.db, 'drop table bytea_test')
smart_ddl(self.db, 'create table bytea_test ('
_______________________________________________
PyGreSQL mailing list
[email protected]
http://mailman.vex.net/mailman/listinfo/pygresql