Author: cito
Date: Tue Oct 2 12:12:19 2012
New Revision: 446
Log:
Support positional parameters in the query() method.
Modified:
trunk/docs/changelog.txt
trunk/docs/pg.txt
trunk/module/TEST_PyGreSQL_classic.py
trunk/module/pg.py
trunk/module/pgmodule.c
trunk/module/pgtypes.h
trunk/module/test_pg.py
Modified: trunk/docs/changelog.txt
==============================================================================
--- trunk/docs/changelog.txt Tue Sep 4 07:01:24 2012 (r445)
+++ trunk/docs/changelog.txt Tue Oct 2 12:12:19 2012 (r446)
@@ -6,6 +6,9 @@
-----------
- Support the new PostgreSQL versions 9.0 and 9.1.
- Particularly, support PQescapeLiteral() and PQescapeIdentifier().
+- The query method of the classic API now supports positional parameters.
+ This is an effective way to pass arbitrary or unknown data without worrying
+ about SQL injection or syntax errors (contribution by Patrick TJ McPhee).
- The execute() and executemany() methods now return the cursor object,
so you can now write statements like "for row in cursor.execute(...)"
(as suggested by Adam Frederick).
Modified: trunk/docs/pg.txt
==============================================================================
--- trunk/docs/pg.txt Tue Sep 4 07:01:24 2012 (r445)
+++ trunk/docs/pg.txt Tue Oct 2 12:12:19 2012 (r446)
@@ -423,10 +423,11 @@
-------------------------------------
Syntax::
- query(command)
+ query(command, [args])
Parameters:
:command: SQL command (string)
+ :args: optional positional arguments
Return type:
:pgqueryobject, None: result values
@@ -449,10 +450,22 @@
method returns a `pgqueryobject` that can be accessed via the `getresult()`
or `dictresult()` method or simply printed. Otherwise, it returns `None`.
+ The query may optionally contain positional parameters of the form `$1`,
+ `$2`, etc. instead of literal data, with the values supplied as a tuple or list.
+ The values are substituted by the database in such a way that they don't
+ need to be escaped, making this an effective way to pass arbitrary or
+ unknown data without worrying about SQL injection or syntax errors.
+
When the database could not process the query, a `pg.ProgrammingError` or
a `pg.InternalError` is raised. You can check the "SQLSTATE" code of this
error by reading its `sqlstate` attribute.
+Example::
+
+ name = raw_input("Name? ")
+ phone = con.query("select phone from employees"
+ " where name=$1", (name, )).getresult()
+
reset - resets the connection
-----------------------------
Syntax::
@@ -1003,6 +1016,41 @@
either in the dictionary where the OID must be munged, or in the keywords
where it can be simply the string "oid".
+query - executes a SQL command string
+-------------------------------------
+Syntax::
+
+ query(command, [arg1, [arg2, ...]])
+
+Parameters:
+ :command: SQL command (string)
+ :arg*: optional positional arguments
+
+Return type:
+ :pgqueryobject, None: result values
+
+Exceptions raised:
+ :TypeError: bad argument type, or too many arguments
+ :TypeError: invalid connection
+ :ValueError: empty SQL query or lost connection
+ :pg.ProgrammingError: error in query
+ :pg.InternalError: error during query processing
+
+Description:
+ Similar to the pgobject method with the same name, except that positional
+ arguments can be passed either as a single list or tuple, or as individual
+ positional arguments.
+
+Example::
+
+ name = raw_input("Name? ")
+ phone = raw_input("Phone? ")
+ rows = db.query("update employees set phone=$2"
+ " where name=$1", (name, phone))
+ # or
+ rows = db.query("update employees set phone=$2"
+ " where name=$1", name, phone)
+
clear - clears row values in memory
-----------------------------------
Syntax::
Modified: trunk/module/TEST_PyGreSQL_classic.py
==============================================================================
--- trunk/module/TEST_PyGreSQL_classic.py Tue Sep 4 07:01:24 2012
(r445)
+++ trunk/module/TEST_PyGreSQL_classic.py Tue Oct 2 12:12:19 2012
(r446)
@@ -88,6 +88,11 @@
self.failUnlessRaises(ProgrammingError, db.get, '_test_vschema', 1234)
db.get('_test_vschema', 1234, keyname='_test')
+ def test_params(self):
+ db.query("INSERT INTO _test_schema VALUES ($1, $2, $3)", 12, None, 34)
+ d = db.get('_test_schema', 12)
+ self.assertEqual(d['dvar'], 34)
+
def test_insert(self):
d = dict(_test=1234)
db.insert('_test_schema', d)
Modified: trunk/module/pg.py
==============================================================================
--- trunk/module/pg.py Tue Sep 4 07:01:24 2012 (r445)
+++ trunk/module/pg.py Tue Oct 2 12:12:19 2012 (r446)
@@ -318,7 +318,7 @@
self.db.close()
self.db = db
- def query(self, qstr):
+ def query(self, qstr, *args):
"""Executes a SQL command string.
This method simply sends a SQL query to the database. If the query is
@@ -332,12 +332,21 @@
a pgqueryobject that can be accessed via getresult() or dictresult()
or simply printed. Otherwise, it returns `None`.
+ The query can contain numbered parameters of the form $1 in place
+ of any data constant. Arguments given after the query string will
+ be substituted for the corresponding numbered parameter. Parameter
+ values can also be given as a single list or tuple argument.
+
+ Note that the query string must not be passed as a unicode value,
+ but you can pass arguments as unicode values if they can be encoded
+ using the current client encoding.
+
"""
# Wraps shared library function for debugging.
if not self.db:
raise _int_error('Connection is not valid')
self._do_debug(qstr)
- return self.db.query(qstr)
+ return self.db.query(qstr, args)
def pkey(self, cl, newpkey=None):
"""This method gets or sets the primary key of a class.
@@ -548,7 +557,8 @@
raise _db_error('%s not in arg' % qoid)
else:
arg = {qoid: arg}
- where = 'oid = %s' % arg[qoid]
+ where = 'oid = $1'
+ params = (arg[qoid],)
attnames = '*'
else:
attnames = self.get_attnames(qcl)
@@ -558,14 +568,16 @@
if len(keyname) > 1:
raise _prg_error('Composite key needs dict as arg')
arg = dict([(k, arg) for k in keyname])
- where = ' AND '.join(['%s = %s'
- % (k, self._quote(arg[k], attnames[k])) for k in keyname])
+ where = ' AND '.join(['%s = $%d'
+ % (k, i + 1) for i, k in enumerate(keyname)])
+ params = tuple(arg[k] for k in keyname)
attnames = ', '.join(attnames)
q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (attnames, qcl, where)
- self._do_debug(q)
- res = self.db.query(q).dictresult()
+ self._do_debug(q + ' %% %r' % (params,))
+ res = self.db.query(q, params).dictresult()
if not res:
- raise _db_error('No such record in %s where %s' % (qcl, where))
+ raise _db_error(
+ 'No such record in %s where %s %% %r' % (qcl, where, params))
for att, value in res[0].iteritems():
arg[att == 'oid' and qoid or att] = value
return arg
@@ -590,11 +602,14 @@
d = {}
d.update(kw)
attnames = self.get_attnames(qcl)
- names, values = [], []
+ names, values, params = [], [], []
+ i = 1
for n in attnames:
if n != 'oid' and n in d:
names.append('"%s"' % n)
- values.append(self._quote(d[n], attnames[n]))
+ values.append('$%d' % (i,))
+ params.append(d[n])
+ i += 1
names, values = ', '.join(names), ', '.join(values)
selectable = self.has_table_privilege(qcl)
if selectable and self.server_version >= 80200:
@@ -602,8 +617,8 @@
else:
ret = ''
q = 'INSERT INTO %s (%s) VALUES (%s)%s' % (qcl, names, values, ret)
- self._do_debug(q)
- res = self.db.query(q)
+ self._do_debug(q + " %% %r" % (params,))
+ res = self.db.query(q, params)
if ret:
res = res.dictresult()
for att, value in res[0].iteritems():
@@ -645,7 +660,8 @@
d.update(kw)
attnames = self.get_attnames(qcl)
if qoid in d:
- where = 'oid = %s' % d[qoid]
+ where = 'oid = $1'
+ params = [d[qoid]]
keyname = ()
else:
try:
@@ -655,14 +671,18 @@
if isinstance(keyname, basestring):
keyname = (keyname,)
try:
- where = ' AND '.join(['%s = %s'
- % (k, self._quote(d[k], attnames[k])) for k in keyname])
+ where = ' AND '.join(['%s = $%d'
+ % (k, i + 1) for i, k in enumerate(keyname)])
+ params = [d[k] for k in keyname]
except KeyError:
raise _prg_error('Update needs primary key or oid.')
values = []
+ i = len(params)
for n in attnames:
if n in d and n not in keyname:
- values.append('%s = %s' % (n, self._quote(d[n], attnames[n])))
+ i += 1
+ values.append('%s = $%d' % (n, i))
+ params.append(d[n])
if not values:
return d
values = ', '.join(values)
@@ -673,7 +693,7 @@
ret = ''
q = 'UPDATE %s SET %s WHERE %s%s' % (qcl, values, where, ret)
self._do_debug(q)
- res = self.db.query(q)
+ res = self.db.query(q, params)
if ret:
res = res.dictresult()[0]
for att, value in res.iteritems():
@@ -735,7 +755,8 @@
d = {}
d.update(kw)
if qoid in d:
- where = 'oid = %s' % d[qoid]
+ where = 'oid = $1'
+ params = (d[qoid],)
else:
try:
keyname = self.pkey(qcl)
@@ -745,13 +766,14 @@
keyname = (keyname,)
attnames = self.get_attnames(qcl)
try:
- where = ' AND '.join(['%s = %s'
- % (k, self._quote(d[k], attnames[k])) for k in keyname])
+ where = ' AND '.join(['%s = $%d'
+ % (k, i+ 1 ) for i, k in enumerate(keyname)])
+ params = tuple(d[k] for k in keyname)
except KeyError:
raise _prg_error('Delete needs primary key or oid.')
q = 'DELETE FROM %s WHERE %s' % (qcl, where)
- self._do_debug(q)
- return int(self.db.query(q))
+ self._do_debug(q + " %% %r" % (params,))
+ return int(self.db.query(q, params))
# if run as script, print some information
Modified: trunk/module/pgmodule.c
==============================================================================
--- trunk/module/pgmodule.c Tue Sep 4 07:01:24 2012 (r445)
+++ trunk/module/pgmodule.c Tue Oct 2 12:12:19 2012 (r446)
@@ -786,7 +786,7 @@
return NULL;
/* builds result */
- for (i = 0; i < size; ++i)
+ for (i = 0; i < size; i++)
{
if (!(rowtuple = PyTuple_New(self->num_fields)))
{
@@ -2434,16 +2434,18 @@
/* database query */
static char pg_query__doc__[] =
-"query(sql) -- creates a new query object for this connection,"
-" using sql (string) request.";
+"query(sql, [args]) -- creates a new query object for this connection, using"
+" sql (string) request and optionally a tuple with positional parameters.";
static PyObject *
pg_query(pgobject *self, PyObject *args)
{
- char *query;
- PGresult *result;
+ char *query;
+ PyObject *oargs = NULL;
+ PGresult *result;
pgqueryobject *npgobj;
- int status;
+ int status,
+ nparms = 0;
if (!self->cnx)
{
@@ -2452,16 +2454,119 @@
}
/* get query args */
- if (!PyArg_ParseTuple(args, "s", &query))
+ if (!PyArg_ParseTuple(args, "s|O", &query, &oargs))
{
- PyErr_SetString(PyExc_TypeError, "query(sql), with sql
(string).");
+ PyErr_SetString(PyExc_TypeError, "query(sql, [args]), with sql
(string).");
return NULL;
}
+ /* If oargs is passed, ensure it's a tuple or list. We want to treat
+ * an empty tuple the same as no argument since we'll get that when the
+ * caller passes no arguments to db.query(), and historic behaviour was
+ * to call PQexec() in that case, which can execute multiple commands.
*/
+ if (oargs)
+ {
+ if (!PyTuple_Check(oargs) && !PyList_Check(oargs))
+ {
+ PyErr_SetString(PyExc_TypeError, "query parameters must
be a tuple or list.");
+ return NULL;
+ }
+
+ nparms = PySequence_Size(oargs);
+ }
+
/* gets result */
- Py_BEGIN_ALLOW_THREADS
- result = PQexec(self->cnx, query);
- Py_END_ALLOW_THREADS
+ if (nparms)
+ {
+ /* prepare arguments */
+ PyObject **str, **s, *obj = PySequence_GetItem(oargs, 0);
+ char **parms, **p, *enc=NULL;
+ int *lparms, *l;
+ register int i;
+
+ /* if there's a single argument and it's a list or tuple, it
+ * contains the positional arguments. */
+ if (nparms == 1 && (PyList_Check(obj) || PyTuple_Check(obj)))
+ {
+ oargs = obj;
+ nparms = PySequence_Size(oargs);
+ }
+ str = (PyObject **)alloca(nparms * sizeof(*str));
+ parms = (char **)alloca(nparms * sizeof(*parms));
+ lparms = (int *)alloca(nparms * sizeof(*lparms));
+
+ /* convert optional args to a list of strings -- this allows
+ * the caller to pass whatever they like, and prevents us
+ * from having to map types to OIDs */
+ for (i = 0, s=str, p=parms, l=lparms; i < nparms; i++, s++,
p++, l++)
+ {
+ obj = PySequence_GetItem(oargs, i);
+
+ if (obj == Py_None)
+ {
+ *s = NULL;
+ *p = NULL;
+ *l = 0;
+ }
+ else if (PyUnicode_Check(obj))
+ {
+ if (!enc)
+ enc = (char *)pg_encoding_to_char(
+ PQclientEncoding(self->cnx));
+ if (!strcmp(enc, "UTF8"))
+ *s = PyUnicode_AsUTF8String(obj);
+ else if (!strcmp(enc, "LATIN1"))
+ *s = PyUnicode_AsLatin1String(obj);
+ else if (!strcmp(enc, "SQL_ASCII"))
+ *s = PyUnicode_AsASCIIString(obj);
+ else
+ *s = PyUnicode_AsEncodedString(obj,
enc, "strict");
+ if (*s == NULL) {
+ PyErr_SetString(PyExc_UnicodeError,
"query parameter"
+ " could not be decoded (bad
client encoding)");
+ while (i--) {
+ if (*--s)
+ Py_DECREF(*s);
+ }
+ return NULL;
+ }
+ *p = PyString_AsString(*s);
+ *l = PyString_Size(*s);
+ }
+ else
+ {
+ *s = PyObject_Str(obj);
+ if (*s == NULL) {
+ PyErr_SetString(PyExc_TypeError,
+ "query parameter has no string
representation");
+ while (i--) {
+ if (*--s)
+ Py_DECREF(*s);
+ }
+ return NULL;
+ }
+ *p = PyString_AsString(*s);
+ *l = PyString_Size(*s);
+ }
+ }
+
+ Py_BEGIN_ALLOW_THREADS
+ result = PQexecParams(self->cnx, query, nparms,
+ NULL, (const char * const *)parms, lparms, NULL, 0);
+ Py_END_ALLOW_THREADS
+
+ for (i = 0, s=str; i < nparms; i++, s++)
+ {
+ if (*s)
+ Py_DECREF(*s);
+ }
+ }
+ else
+ {
+ Py_BEGIN_ALLOW_THREADS
+ result = PQexec(self->cnx, query);
+ Py_END_ALLOW_THREADS
+ }
/* checks result validity */
if (!result)
Modified: trunk/module/pgtypes.h
==============================================================================
--- trunk/module/pgtypes.h Tue Sep 4 07:01:24 2012 (r445)
+++ trunk/module/pgtypes.h Tue Oct 2 12:12:19 2012 (r446)
@@ -85,4 +85,4 @@
#define ANYNONARRAYOID 2776
#define ANYENUMOID 3500
-#endif
+#endif /* PG_TYPE_H */
Modified: trunk/module/test_pg.py
==============================================================================
--- trunk/module/test_pg.py Tue Sep 4 07:01:24 2012 (r445)
+++ trunk/module/test_pg.py Tue Oct 2 12:12:19 2012 (r446)
@@ -1,4 +1,5 @@
#!/usr/bin/env python
+# -*- coding: utf-8 -*-
#
# test_pg.py
#
@@ -34,7 +35,7 @@
german = True
try:
import locale
- locale.setlocale(locale.LC_ALL, ('de', 'latin1'))
+ locale.setlocale(locale.LC_ALL, ('de', 'utf-8'))
except Exception:
try:
locale.setlocale(locale.LC_ALL, 'german')
@@ -419,6 +420,12 @@
def testMethodQuery(self):
self.connection.query("select 1+1")
+ self.connection.query("select 1+$1", (1,))
+ self.connection.query("select 1+$1+$2", (2, 3))
+ self.connection.query("select 1+$1+$2", [2, 3])
+
+ def testMethodQueryEmpty(self):
+ self.assertRaises(ValueError, self.connection.query, '')
def testMethodEndcopy(self):
try:
@@ -643,6 +650,116 @@
self.c.query('unlisten test_notify')
+class TestParamQueries(unittest.TestCase):
+ """Test queries with parameters via a basic pg connection."""
+
+ def setUp(self):
+ dbname = 'test'
+ self.c = pg.connect(dbname)
+
+ def tearDown(self):
+ self.c.query("set client_encoding to UTF8")
+ self.c.close()
+
+ def testQueryWithNoneParam(self):
+ self.assertEqual(self.c.query("select $1::integer", (None,)
+ ).getresult(), [(None,)])
+ self.assertEqual(self.c.query("select $1::text", [None]
+ ).getresult(), [(None,)])
+
+ def testQueryWithIntParams(self):
+ query = self.c.query
+ self.assertEqual(query("select 1+1").getresult(), [(2,)])
+ self.assertEqual(query("select 1+$1", (1,)).getresult(), [(2,)])
+ self.assertEqual(query("select 1+$1", [1,]).getresult(), [(2,)])
+ self.assertEqual(query("select $1::integer", (2,)).getresult(), [(2,)])
+ self.assertEqual(query("select $1::text", (2,) ).getresult(), [('2',)])
+ self.assertEqual(query("select 1+$1::numeric", [1,]).getresult(),
+ [(Decimal('2'),)])
+ self.assertEqual(query("select 1, $1::integer", (2,)
+ ).getresult(), [(1, 2)])
+ self.assertEqual(query("select 1 union select $1", (2,)
+ ).getresult(), [(1,), (2,)])
+ self.assertEqual(query("select $1::integer+$2", (1, 2)
+ ).getresult(), [(3,)])
+ self.assertEqual(query("select $1::integer+$2", [1, 2]
+ ).getresult(), [(3,)])
+ self.assertEqual(query("select 0+$1+$2+$3+$4+$5+$6", range(6)
+ ).getresult(), [(15,)])
+
+ def testQueryWithStrParams(self):
+ query = self.c.query
+ self.assertEqual(query("select $1||', world!'", ('Hello',)
+ ).getresult(), [('Hello, world!',)])
+ self.assertEqual(query("select $1||', world!'", ['Hello']
+ ).getresult(), [('Hello, world!',)])
+ self.assertEqual(query("select $1||', '||$2||'!'", ('Hello', 'world'),
+ ).getresult(), [('Hello, world!',)])
+ self.assertEqual(query("select $1::text", ('Hello, world!',)
+ ).getresult(), [('Hello, world!',)])
+ self.assertEqual(query("select $1::text,$2::text", ('Hello', 'world')
+ ).getresult(), [('Hello', 'world')])
+ self.assertEqual(query("select $1::text,$2::text", ['Hello', 'world']
+ ).getresult(), [('Hello', 'world')])
+ self.assertEqual(query("select $1::text union select $2::text",
+ ('Hello', 'world')).getresult(), [('Hello',), ('world',)])
+ self.assertEqual(query("select $1||', '||$2||'!'", ('Hello',
+ 'w\xc3\xb6rld')).getresult(), [('Hello, w\xc3\xb6rld!',)])
+
+ def testQueryWithUnicodeParams(self):
+ query = self.c.query
+ self.assertEqual(query("select $1||', '||$2||'!'",
+ ('Hello', u'w\xf6rld')).getresult(), [('Hello, w\xc3\xb6rld!',)])
+ self.assertEqual(query("select $1||', '||$2||'!'",
+ ('Hello', u'\u043c\u0438\u0440')).getresult(),
+ [('Hello, \xd0\xbc\xd0\xb8\xd1\x80!',)])
+ query('set client_encoding = latin1')
+ self.assertEqual(query("select $1||', '||$2||'!'",
+ ('Hello', u'w\xf6rld')).getresult(), [('Hello, w\xf6rld!',)])
+ self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
+ ('Hello', u'\u043c\u0438\u0440'))
+ query('set client_encoding = iso_8859_1')
+ self.assertEqual(query("select $1||', '||$2||'!'",
+ ('Hello', u'w\xf6rld')).getresult(), [('Hello, w\xf6rld!',)])
+ self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
+ ('Hello', u'\u043c\u0438\u0440'))
+ query('set client_encoding = iso_8859_5')
+ self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
+ ('Hello', u'w\xf6rld'))
+ self.assertEqual(query("select $1||', '||$2||'!'",
+ ('Hello', u'\u043c\u0438\u0440')).getresult(),
+ [('Hello, \xdc\xd8\xe0!',)])
+ query('set client_encoding = sql_ascii')
+ self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
+ ('Hello', u'w\xf6rld'))
+
+ def testQueryWithMixedParams(self):
+ self.assertEqual(self.c.query("select $1+2,$2||', world!'",
+ (1, 'Hello'),).getresult(), [(3, 'Hello, world!')])
+ self.assertEqual(self.c.query("select $1::integer,$2::date,$3::text",
+ (4711, None, 'Hello!'),).getresult(), [(4711, None, 'Hello!')])
+
+ def testQueryWithDuplicateParams(self):
+ self.assertRaises(pg.ProgrammingError,
+ self.c.query, "select $1+$1", (1,))
+ self.assertRaises(pg.ProgrammingError,
+ self.c.query, "select $1+$1", (1, 2))
+
+ def testQueryWithZeroParams(self):
+ self.assertEqual(self.c.query("select 1+1", []
+ ).getresult(), [(2,)])
+
+ def testQueryWithGarbage(self):
+ garbage = r"'\{}+()-#[]oo324"
+ self.assertEqual(self.c.query("select $1::text AS garbage", (garbage,)
+ ).dictresult(), [{'garbage': garbage}])
+
+ def testUnicodeQuery(self):
+ query = self.c.query
+ self.assertEqual(query(u"select 1+1").getresult(), [(2,)])
+ self.assertRaises(TypeError, query, u"select 'Hello, w\xf6rld!'")
+
+
class TestInserttable(unittest.TestCase):
""""Test inserttable method."""
@@ -926,6 +1043,13 @@
def testMethodQuery(self):
self.db.query("select 1+1")
+ self.db.query("select 1+$1", 1)
+ self.db.query("select 1+$1+$2", 2, 3)
+ self.db.query("select 1+$1+$2", (2, 3))
+ self.db.query("select 1+$1+$2", [2, 3])
+
+ def testMethodQueryEmpty(self):
+ self.assertRaises(ValueError, self.db.query, '')
def testMethodQueryProgrammingError(self):
try:
@@ -1148,6 +1272,41 @@
self.assert_(isinstance(r, str))
self.assertEqual(r, '5')
+ def testMultipleQueries(self):
+ self.assertEqual(self.db.query(
+ "create temporary table test_multi (n integer);"
+ "insert into test_multi values (4711);"
+ "select n from test_multi").getresult()[0][0], 4711)
+
+ def testQueryWithParams(self):
+ smart_ddl(self.db, "drop table test_table")
+ q = "create table test_table (n1 integer, n2 integer) with oids"
+ r = self.db.query(q)
+ q = "insert into test_table values ($1, $2)"
+ r = self.db.query(q, (1, 2))
+ self.assert_(isinstance(r, int))
+ r = self.db.query(q, [3, 4])
+ self.assert_(isinstance(r, int))
+ r = self.db.query(q, [5, 6])
+ self.assert_(isinstance(r, int))
+ q = "select * from test_table order by 1, 2"
+ self.assertEqual(self.db.query(q).getresult(),
+ [(1, 2), (3, 4), (5, 6)])
+ q = "select * from test_table where n1=$1 and n2=$2"
+ self.assertEqual(self.db.query(q, 3, 4).getresult(), [(3, 4)])
+ q = "update test_table set n2=$2 where n1=$1"
+ r = self.db.query(q, 3, 7)
+ self.assertEqual(r, '1')
+ q = "select * from test_table order by 1, 2"
+ self.assertEqual(self.db.query(q).getresult(),
+ [(1, 2), (3, 7), (5, 6)])
+ q = "delete from test_table where n2!=$1"
+ r = self.db.query(q, 4)
+ self.assertEqual(r, '3')
+
+ def testEmptyQuery(self):
+ self.assertRaises(ValueError, self.db.query, '')
+
def testQueryProgrammingError(self):
try:
self.db.query("select 1/0")
@@ -1345,12 +1504,11 @@
"d numeric, f4 real, f8 double precision, m money, "
"v4 varchar(4), c4 char(4), t text,"
"b boolean, ts timestamp)" % table)
- data = dict(i2 = 2**15 - 1,
- i4 = int(2**31 - 1), i8 = long(2**31 - 1),
- d = Decimal('123456789.9876543212345678987654321'),
- f4 = 1.0 + 1.0/32, f8 = 1.0 + 1.0/32,
- m = "1234.56", v4 = "1234", c4 = "1234", t = "1234" * 10,
- b = 1, ts = 'current_date')
+ data = dict(i2=2**15 - 1, i4=int(2**31 - 1), i8=long(2**31 - 1),
+ d=Decimal('123456789.9876543212345678987654321'),
+ f4=1.0 + 1.0/32, f8 = 1.0 + 1.0/32,
+ m="1234.56", v4="1234", c4="1234", t="1234" * 10,
+ b=1, ts='2012-12-21')
r = self.db.insert(table, data)
self.assertEqual(r, data)
oid_table = table
@@ -1608,6 +1766,7 @@
c.query("create database " + dbname
+ " template=template0")
for s in ('client_min_messages = warning',
+ 'client_encoding = UTF8',
'lc_messages = C',
'default_with_oids = on',
'standard_conforming_strings = off',
@@ -1660,6 +1819,7 @@
unittest.makeSuite(TestCanConnect),
unittest.makeSuite(TestConnectObject),
unittest.makeSuite(TestSimpleQueries),
+ unittest.makeSuite(TestParamQueries),
unittest.makeSuite(TestDBClassBasic),
))
_______________________________________________
PyGreSQL mailing list
[email protected]
https://mail.vex.net/mailman/listinfo.cgi/pygresql