Author: cito
Date: Tue Oct  2 12:12:19 2012
New Revision: 446

Log:
Support positional parameters in the query() method.

Modified:
   trunk/docs/changelog.txt
   trunk/docs/pg.txt
   trunk/module/TEST_PyGreSQL_classic.py
   trunk/module/pg.py
   trunk/module/pgmodule.c
   trunk/module/pgtypes.h
   trunk/module/test_pg.py

Modified: trunk/docs/changelog.txt
==============================================================================
--- trunk/docs/changelog.txt    Tue Sep  4 07:01:24 2012        (r445)
+++ trunk/docs/changelog.txt    Tue Oct  2 12:12:19 2012        (r446)
@@ -6,6 +6,9 @@
 -----------
 - Support the new PostgreSQL versions 9.0 and 9.1.
 - Particularly, support PQescapeLiteral() and PQescapeIdentifier().
+- The query method of the classic API now supports positional parameters.
+  This is an effective way to pass arbitrary or unknown data without worrying
+  about SQL injection or syntax errors (contribution by Patrick TJ McPhee).
 - The execute() and executemany() methods now return the cursor object,
   so you can now write statements like "for row in cursor.execute(...)"
   (as suggested by Adam Frederick).

Modified: trunk/docs/pg.txt
==============================================================================
--- trunk/docs/pg.txt   Tue Sep  4 07:01:24 2012        (r445)
+++ trunk/docs/pg.txt   Tue Oct  2 12:12:19 2012        (r446)
@@ -423,10 +423,11 @@
 -------------------------------------
 Syntax::
 
-  query(command)
+  query(command, [args])
 
 Parameters:
   :command: SQL command (string)
+  :args: optional positional parameter values (tuple or list)
 
 Return type:
   :pgqueryobject, None: result values
@@ -449,10 +450,22 @@
   method returns a `pgqueryobject` that can be accessed via the `getresult()`
   or `dictresult()` method or simply printed. Otherwise, it returns `None`.
 
+  The query may optionally contain positional parameters of the form `$1`,
+  `$2`, etc. in place of literal data, with the values supplied in a tuple or list.
+  The values are substituted by the database in such a way that they don't
+  need to be escaped, making this an effective way to pass arbitrary or
+  unknown data without worrying about SQL injection or syntax errors.
+
   When the database could not process the query, a `pg.ProgrammingError` or
   a `pg.InternalError` is raised. You can check the "SQLSTATE" code of this
   error by reading its `sqlstate` attribute.
 
+Example::
+
+  name = raw_input("Name? ")
+  phone = con.query("select phone from employees"
+    " where name=$1", (name, )).getresult()
+
 reset - resets the connection
 -----------------------------
 Syntax::
@@ -1003,6 +1016,41 @@
   either in the dictionary where the OID must be munged, or in the keywords
   where it can be simply the string "oid".
 
+query - executes a SQL command string
+-------------------------------------
+Syntax::
+
+  query(command, [arg1, [arg2, ...]])
+
+Parameters:
+  :command: SQL command (string)
+  :arg*: optional positional arguments
+
+Return type:
+  :pgqueryobject, None: result values
+
+Exceptions raised:
+  :TypeError: bad argument type, or too many arguments
+  :TypeError: invalid connection
+  :ValueError: empty SQL query or lost connection
+  :pg.ProgrammingError: error in query
+  :pg.InternalError: error during query processing
+
+Description:
+  Similar to the pgobject method with the same name, except that positional
+  arguments can be passed either as a single list or tuple, or as individual
+  positional arguments.
+
+Example::
+
+  name = raw_input("Name? ")
+  phone = raw_input("Phone? ")
+  rows = db.query("update employees set phone=$2"
+    " where name=$1", (name, phone))
+  # or
+  rows = db.query("update employees set phone=$2"
+    " where name=$1", name, phone)
+
 clear - clears row values in memory
 -----------------------------------
 Syntax::

Modified: trunk/module/TEST_PyGreSQL_classic.py
==============================================================================
--- trunk/module/TEST_PyGreSQL_classic.py       Tue Sep  4 07:01:24 2012        
(r445)
+++ trunk/module/TEST_PyGreSQL_classic.py       Tue Oct  2 12:12:19 2012        
(r446)
@@ -88,6 +88,11 @@
         self.failUnlessRaises(ProgrammingError, db.get, '_test_vschema', 1234)
         db.get('_test_vschema', 1234, keyname='_test')
 
+    def test_params(self):
+        db.query("INSERT INTO _test_schema VALUES ($1, $2, $3)", 12, None, 34)
+        d = db.get('_test_schema', 12)
+        self.assertEqual(d['dvar'], 34)
+
     def test_insert(self):
         d = dict(_test=1234)
         db.insert('_test_schema', d)

Modified: trunk/module/pg.py
==============================================================================
--- trunk/module/pg.py  Tue Sep  4 07:01:24 2012        (r445)
+++ trunk/module/pg.py  Tue Oct  2 12:12:19 2012        (r446)
@@ -318,7 +318,7 @@
                 self.db.close()
             self.db = db
 
-    def query(self, qstr):
+    def query(self, qstr, *args):
         """Executes a SQL command string.
 
         This method simply sends a SQL query to the database. If the query is
@@ -332,12 +332,21 @@
         a pgqueryobject that can be accessed via getresult() or dictresult()
         or simply printed. Otherwise, it returns `None`.
 
+        The query can contain numbered parameters of the form $1 in place
+        of any data constant. Arguments given after the query string will
+        be substituted for the corresponding numbered parameter. Parameter
+        values can also be given as a single list or tuple argument.
+
+        Note that the query string must not be passed as a unicode value,
+        but you can pass arguments as unicode values if they can be decoded
+        using the current client encoding.
+
         """
         # Wraps shared library function for debugging.
         if not self.db:
             raise _int_error('Connection is not valid')
         self._do_debug(qstr)
-        return self.db.query(qstr)
+        return self.db.query(qstr, args)
 
     def pkey(self, cl, newpkey=None):
         """This method gets or sets the primary key of a class.
@@ -548,7 +557,8 @@
                     raise _db_error('%s not in arg' % qoid)
             else:
                 arg = {qoid: arg}
-            where = 'oid = %s' % arg[qoid]
+            where = 'oid = $1'
+            params = (arg[qoid],)
             attnames = '*'
         else:
             attnames = self.get_attnames(qcl)
@@ -558,14 +568,16 @@
                 if len(keyname) > 1:
                     raise _prg_error('Composite key needs dict as arg')
                 arg = dict([(k, arg) for k in keyname])
-            where = ' AND '.join(['%s = %s'
-                % (k, self._quote(arg[k], attnames[k])) for k in keyname])
+            where = ' AND '.join(['%s = $%d'
+                % (k, i + 1) for i, k in enumerate(keyname)])
+            params = tuple(arg[k] for k in keyname)
             attnames = ', '.join(attnames)
         q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (attnames, qcl, where)
-        self._do_debug(q)
-        res = self.db.query(q).dictresult()
+        self._do_debug(q + ' %% %r' % (params,))
+        res = self.db.query(q, params).dictresult()
         if not res:
-            raise _db_error('No such record in %s where %s' % (qcl, where))
+            raise _db_error(
+                'No such record in %s where %s %% %r' % (qcl, where, params))
         for att, value in res[0].iteritems():
             arg[att == 'oid' and qoid or att] = value
         return arg
@@ -590,11 +602,14 @@
             d = {}
         d.update(kw)
         attnames = self.get_attnames(qcl)
-        names, values = [], []
+        names, values, params = [], [], []
+        i = 1
         for n in attnames:
             if n != 'oid' and n in d:
                 names.append('"%s"' % n)
-                values.append(self._quote(d[n], attnames[n]))
+                values.append('$%d' % (i,))
+                params.append(d[n])
+                i += 1
         names, values = ', '.join(names), ', '.join(values)
         selectable = self.has_table_privilege(qcl)
         if selectable and self.server_version >= 80200:
@@ -602,8 +617,8 @@
         else:
             ret = ''
         q = 'INSERT INTO %s (%s) VALUES (%s)%s' % (qcl, names, values, ret)
-        self._do_debug(q)
-        res = self.db.query(q)
+        self._do_debug(q + " %% %r" % (params,))
+        res = self.db.query(q, params)
         if ret:
             res = res.dictresult()
             for att, value in res[0].iteritems():
@@ -645,7 +660,8 @@
         d.update(kw)
         attnames = self.get_attnames(qcl)
         if qoid in d:
-            where = 'oid = %s' % d[qoid]
+            where = 'oid = $1'
+            params = [d[qoid]]
             keyname = ()
         else:
             try:
@@ -655,14 +671,18 @@
             if isinstance(keyname, basestring):
                 keyname = (keyname,)
             try:
-                where = ' AND '.join(['%s = %s'
-                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
+                where = ' AND '.join(['%s = $%d'
+                    % (k, i + 1) for i, k in enumerate(keyname)])
+                params = [d[k] for k in keyname]
             except KeyError:
                 raise _prg_error('Update needs primary key or oid.')
         values = []
+        i = len(params)
         for n in attnames:
             if n in d and n not in keyname:
-                values.append('%s = %s' % (n, self._quote(d[n], attnames[n])))
+                i += 1
+                values.append('%s = $%d' % (n, i))
+                params.append(d[n])
         if not values:
             return d
         values = ', '.join(values)
@@ -673,7 +693,7 @@
             ret = ''
         q = 'UPDATE %s SET %s WHERE %s%s' % (qcl, values, where, ret)
         self._do_debug(q)
-        res = self.db.query(q)
+        res = self.db.query(q, params)
         if ret:
             res = res.dictresult()[0]
             for att, value in res.iteritems():
@@ -735,7 +755,8 @@
             d = {}
         d.update(kw)
         if qoid in d:
-            where = 'oid = %s' % d[qoid]
+            where = 'oid = $1'
+            params = (d[qoid],)
         else:
             try:
                 keyname = self.pkey(qcl)
@@ -745,13 +766,14 @@
                 keyname = (keyname,)
             attnames = self.get_attnames(qcl)
             try:
-                where = ' AND '.join(['%s = %s'
-                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
+                where = ' AND '.join(['%s = $%d'
+                    % (k, i+ 1 ) for i, k in enumerate(keyname)])
+                params = tuple(d[k] for k in keyname)
             except KeyError:
                 raise _prg_error('Delete needs primary key or oid.')
         q = 'DELETE FROM %s WHERE %s' % (qcl, where)
-        self._do_debug(q)
-        return int(self.db.query(q))
+        self._do_debug(q + " %% %r" % (params,))
+        return int(self.db.query(q, params))
 
 
 # if run as script, print some information

Modified: trunk/module/pgmodule.c
==============================================================================
--- trunk/module/pgmodule.c     Tue Sep  4 07:01:24 2012        (r445)
+++ trunk/module/pgmodule.c     Tue Oct  2 12:12:19 2012        (r446)
@@ -786,7 +786,7 @@
                return NULL;
 
        /* builds result */
-       for (i = 0; i < size; ++i)
+       for (i = 0; i < size; i++)
        {
                if (!(rowtuple = PyTuple_New(self->num_fields)))
                {
@@ -2434,16 +2434,18 @@
 
 /* database query */
 static char pg_query__doc__[] =
-"query(sql) -- creates a new query object for this connection,"
-" using sql (string) request.";
+"query(sql, [args]) -- creates a new query object for this connection, using"
+" sql (string) request and optionally a tuple with positional parameters.";
 
 static PyObject *
 pg_query(pgobject *self, PyObject *args)
 {
-       char       *query;
-       PGresult   *result;
+       char            *query;
+       PyObject        *oargs = NULL;
+       PGresult        *result;
        pgqueryobject *npgobj;
-       int                     status;
+       int                     status,
+                               nparms = 0;
 
        if (!self->cnx)
        {
@@ -2452,16 +2454,119 @@
        }
 
        /* get query args */
-       if (!PyArg_ParseTuple(args, "s", &query))
+       if (!PyArg_ParseTuple(args, "s|O", &query, &oargs))
        {
-               PyErr_SetString(PyExc_TypeError, "query(sql), with sql 
(string).");
+               PyErr_SetString(PyExc_TypeError, "query(sql, [args]), with sql 
(string).");
                return NULL;
        }
 
+       /* If oargs is passed, ensure it's a non-empty tuple. We want to treat
+        * an empty tuple the same as no argument since we'll get that when the
+        * caller passes no arguments to db.query(), and historic behaviour was
+        * to call PQexec() in that case, which can execute multiple commands. 
*/
+       if (oargs)
+       {
+               if (!PyTuple_Check(oargs) && !PyList_Check(oargs))
+               {
+                       PyErr_SetString(PyExc_TypeError, "query parameters must 
be a tuple or list.");
+                       return NULL;
+               }
+
+               nparms = PySequence_Size(oargs);
+       }
+
        /* gets result */
-       Py_BEGIN_ALLOW_THREADS
-       result = PQexec(self->cnx, query);
-       Py_END_ALLOW_THREADS
+       if (nparms)
+       {
+               /* prepare arguments */
+               PyObject        **str, **s, *obj = PySequence_GetItem(oargs, 0);
+               char            **parms, **p, *enc=NULL;
+               int                     *lparms, *l;
+               register int i;
+
+               /* if there's a single argument and it's a list or tuple, it
+                * contains the positional aguments. */
+               if (nparms == 1 && (PyList_Check(obj) || PyTuple_Check(obj)))
+               {
+                       oargs = obj;
+                       nparms = PySequence_Size(oargs);
+               }
+               str = (PyObject **)alloca(nparms * sizeof(*str));
+               parms = (char **)alloca(nparms * sizeof(*parms));
+               lparms = (int *)alloca(nparms * sizeof(*lparms));
+
+               /* convert optional args to a list of strings -- this allows
+                * the caller to pass whatever they like, and prevents us
+                * from having to map types to OIDs */
+               for (i = 0, s=str, p=parms, l=lparms; i < nparms; i++, s++, 
p++, l++)
+               {
+                       obj = PySequence_GetItem(oargs, i);
+
+                       if (obj == Py_None)
+                       {
+                               *s = NULL;
+                               *p = NULL;
+                               *l = 0;
+                       }
+                       else if (PyUnicode_Check(obj))
+                       {
+                               if (!enc)
+                                       enc = (char *)pg_encoding_to_char(
+                                               PQclientEncoding(self->cnx));
+                               if (!strcmp(enc, "UTF8"))
+                                       *s = PyUnicode_AsUTF8String(obj);
+                               else if (!strcmp(enc, "LATIN1"))
+                                       *s = PyUnicode_AsLatin1String(obj);
+                               else if (!strcmp(enc, "SQL_ASCII"))
+                                       *s = PyUnicode_AsASCIIString(obj);
+                               else
+                                       *s = PyUnicode_AsEncodedString(obj, 
enc, "strict");
+                               if (*s == NULL) {
+                                       PyErr_SetString(PyExc_UnicodeError, 
"query parameter"
+                                               " could not be decoded (bad 
client encoding)");
+                                       while (i--) {
+                                               if (*--s)
+                                                       Py_DECREF(*s);
+                                       }
+                                       return NULL;
+                               }
+                               *p = PyString_AsString(*s);
+                               *l = PyString_Size(*s);
+                       }
+                       else
+                       {
+                               *s = PyObject_Str(obj);
+                               if (*s == NULL) {
+                                       PyErr_SetString(PyExc_TypeError,
+                                               "query parameter has no string 
representation");
+                                       while (i--) {
+                                               if (*--s)
+                                                       Py_DECREF(*s);
+                                       }
+                                       return NULL;
+                               }
+                               *p = PyString_AsString(*s);
+                               *l = PyString_Size(*s);
+                       }
+               }
+
+               Py_BEGIN_ALLOW_THREADS
+               result = PQexecParams(self->cnx, query, nparms,
+                       NULL, (const char * const *)parms, lparms, NULL, 0);
+               Py_END_ALLOW_THREADS
+
+               for (i = 0, s=str; i < nparms; i++, s++)
+               {
+                       if (*s)
+                               Py_DECREF(*s);
+               }
+       }
+       else
+       {
+               Py_BEGIN_ALLOW_THREADS
+               result = PQexec(self->cnx, query);
+               Py_END_ALLOW_THREADS
+       }
 
        /* checks result validity */
        if (!result)

Modified: trunk/module/pgtypes.h
==============================================================================
--- trunk/module/pgtypes.h      Tue Sep  4 07:01:24 2012        (r445)
+++ trunk/module/pgtypes.h      Tue Oct  2 12:12:19 2012        (r446)
@@ -85,4 +85,4 @@
 #define ANYNONARRAYOID 2776
 #define ANYENUMOID 3500
 
-#endif
+#endif /* PG_TYPE_H */

Modified: trunk/module/test_pg.py
==============================================================================
--- trunk/module/test_pg.py     Tue Sep  4 07:01:24 2012        (r445)
+++ trunk/module/test_pg.py     Tue Oct  2 12:12:19 2012        (r446)
@@ -1,4 +1,5 @@
 #!/usr/bin/env python
+# -*- coding: utf-8 -*-
 #
 # test_pg.py
 #
@@ -34,7 +35,7 @@
 german = True
 try:
     import locale
-    locale.setlocale(locale.LC_ALL, ('de', 'latin1'))
+    locale.setlocale(locale.LC_ALL, ('de', 'utf-8'))
 except Exception:
     try:
         locale.setlocale(locale.LC_ALL, 'german')
@@ -419,6 +420,12 @@
 
     def testMethodQuery(self):
         self.connection.query("select 1+1")
+        self.connection.query("select 1+$1", (1,))
+        self.connection.query("select 1+$1+$2", (2, 3))
+        self.connection.query("select 1+$1+$2", [2, 3])
+
+    def testMethodQueryEmpty(self):
+        self.assertRaises(ValueError, self.connection.query, '')
 
     def testMethodEndcopy(self):
         try:
@@ -643,6 +650,116 @@
             self.c.query('unlisten test_notify')
 
 
+class TestParamQueries(unittest.TestCase):
+    """"Test queries with parameters via a basic pg connection."""
+
+    def setUp(self):
+        dbname = 'test'
+        self.c = pg.connect(dbname)
+
+    def tearDown(self):
+        self.c.query("set client_encoding to UTF8")
+        self.c.close()
+
+    def testQueryWithNoneParam(self):
+        self.assertEqual(self.c.query("select $1::integer", (None,)
+            ).getresult(), [(None,)])
+        self.assertEqual(self.c.query("select $1::text", [None]
+            ).getresult(), [(None,)])
+
+    def testQueryWithIntParams(self):
+        query = self.c.query
+        self.assertEqual(query("select 1+1").getresult(), [(2,)])
+        self.assertEqual(query("select 1+$1", (1,)).getresult(), [(2,)])
+        self.assertEqual(query("select 1+$1", [1,]).getresult(), [(2,)])
+        self.assertEqual(query("select $1::integer", (2,)).getresult(), [(2,)])
+        self.assertEqual(query("select $1::text", (2,) ).getresult(), [('2',)])
+        self.assertEqual(query("select 1+$1::numeric", [1,]).getresult(),
+            [(Decimal('2'),)])
+        self.assertEqual(query("select 1, $1::integer", (2,)
+            ).getresult(), [(1, 2)])
+        self.assertEqual(query("select 1 union select $1", (2,)
+            ).getresult(), [(1,), (2,)])
+        self.assertEqual(query("select $1::integer+$2", (1, 2)
+            ).getresult(), [(3,)])
+        self.assertEqual(query("select $1::integer+$2", [1, 2]
+            ).getresult(), [(3,)])
+        self.assertEqual(query("select 0+$1+$2+$3+$4+$5+$6", range(6)
+            ).getresult(), [(15,)])
+
+    def testQueryWithStrParams(self):
+        query = self.c.query
+        self.assertEqual(query("select $1||', world!'", ('Hello',)
+            ).getresult(), [('Hello, world!',)])
+        self.assertEqual(query("select $1||', world!'", ['Hello']
+            ).getresult(), [('Hello, world!',)])
+        self.assertEqual(query("select $1||', '||$2||'!'", ('Hello', 'world'),
+            ).getresult(), [('Hello, world!',)])
+        self.assertEqual(query("select $1::text", ('Hello, world!',)
+            ).getresult(), [('Hello, world!',)])
+        self.assertEqual(query("select $1::text,$2::text", ('Hello', 'world')
+            ).getresult(), [('Hello', 'world')])
+        self.assertEqual(query("select $1::text,$2::text", ['Hello', 'world']
+            ).getresult(), [('Hello', 'world')])
+        self.assertEqual(query("select $1::text union select $2::text",
+            ('Hello', 'world')).getresult(), [('Hello',), ('world',)])
+        self.assertEqual(query("select $1||', '||$2||'!'", ('Hello',
+            'w\xc3\xb6rld')).getresult(), [('Hello, w\xc3\xb6rld!',)])
+
+    def testQueryWithUnicodeParams(self):
+        query = self.c.query
+        self.assertEqual(query("select $1||', '||$2||'!'",
+            ('Hello', u'w\xf6rld')).getresult(), [('Hello, w\xc3\xb6rld!',)])
+        self.assertEqual(query("select $1||', '||$2||'!'",
+            ('Hello', u'\u043c\u0438\u0440')).getresult(),
+            [('Hello, \xd0\xbc\xd0\xb8\xd1\x80!',)])
+        query('set client_encoding = latin1')
+        self.assertEqual(query("select $1||', '||$2||'!'",
+            ('Hello', u'w\xf6rld')).getresult(), [('Hello, w\xf6rld!',)])
+        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
+            ('Hello', u'\u043c\u0438\u0440'))
+        query('set client_encoding = iso_8859_1')
+        self.assertEqual(query("select $1||', '||$2||'!'",
+            ('Hello', u'w\xf6rld')).getresult(), [('Hello, w\xf6rld!',)])
+        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
+            ('Hello', u'\u043c\u0438\u0440'))
+        query('set client_encoding = iso_8859_5')
+        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
+            ('Hello', u'w\xf6rld'))
+        self.assertEqual(query("select $1||', '||$2||'!'",
+            ('Hello', u'\u043c\u0438\u0440')).getresult(),
+            [('Hello, \xdc\xd8\xe0!',)])
+        query('set client_encoding = sql_ascii')
+        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
+            ('Hello', u'w\xf6rld'))
+
+    def testQueryWithMixedParams(self):
+        self.assertEqual(self.c.query("select $1+2,$2||', world!'",
+            (1, 'Hello'),).getresult(), [(3, 'Hello, world!')])
+        self.assertEqual(self.c.query("select $1::integer,$2::date,$3::text",
+            (4711, None, 'Hello!'),).getresult(), [(4711, None, 'Hello!')])
+
+    def testQueryWithDuplicateParams(self):
+        self.assertRaises(pg.ProgrammingError,
+            self.c.query, "select $1+$1", (1,))
+        self.assertRaises(pg.ProgrammingError,
+            self.c.query, "select $1+$1", (1, 2))
+
+    def testQueryWithZeroParams(self):
+        self.assertEqual(self.c.query("select 1+1", []
+            ).getresult(), [(2,)])
+
+    def testQueryWithGarbage(self):
+        garbage = r"'\{}+()-#[]oo324"
+        self.assertEqual(self.c.query("select $1::text AS garbage", (garbage,)
+            ).dictresult(), [{'garbage': garbage}])
+
+    def testUnicodeQuery(self):
+        query = self.c.query
+        self.assertEqual(query(u"select 1+1").getresult(), [(2,)])
+        self.assertRaises(TypeError, query, u"select 'Hello, w\xf6rld!'")
+
+
 class TestInserttable(unittest.TestCase):
     """"Test inserttable method."""
 
@@ -926,6 +1043,13 @@
 
     def testMethodQuery(self):
         self.db.query("select 1+1")
+        self.db.query("select 1+$1", 1)
+        self.db.query("select 1+$1+$2", 2, 3)
+        self.db.query("select 1+$1+$2", (2, 3))
+        self.db.query("select 1+$1+$2", [2, 3])
+
+    def testMethodQueryEmpty(self):
+        self.assertRaises(ValueError, self.db.query, '')
 
     def testMethodQueryProgrammingError(self):
         try:
@@ -1148,6 +1272,41 @@
         self.assert_(isinstance(r, str))
         self.assertEqual(r, '5')
 
+    def testMultipleQueries(self):
+        self.assertEqual(self.db.query(
+            "create temporary table test_multi (n integer);"
+            "insert into test_multi values (4711);"
+            "select n from test_multi").getresult()[0][0], 4711)
+
+    def testQueryWithParams(self):
+        smart_ddl(self.db, "drop table test_table")
+        q = "create table test_table (n1 integer, n2 integer) with oids"
+        r = self.db.query(q)
+        q = "insert into test_table values ($1, $2)"
+        r = self.db.query(q, (1, 2))
+        self.assert_(isinstance(r, int))
+        r = self.db.query(q, [3, 4])
+        self.assert_(isinstance(r, int))
+        r = self.db.query(q, [5, 6])
+        self.assert_(isinstance(r, int))
+        q = "select * from test_table order by 1, 2"
+        self.assertEqual(self.db.query(q).getresult(),
+            [(1, 2), (3, 4), (5, 6)])
+        q = "select * from test_table where n1=$1 and n2=$2"
+        self.assertEqual(self.db.query(q, 3, 4).getresult(), [(3, 4)])
+        q = "update test_table set n2=$2 where n1=$1"
+        r = self.db.query(q, 3, 7)
+        self.assertEqual(r, '1')
+        q = "select * from test_table order by 1, 2"
+        self.assertEqual(self.db.query(q).getresult(),
+            [(1, 2), (3, 7), (5, 6)])
+        q = "delete from test_table where n2!=$1"
+        r = self.db.query(q, 4)
+        self.assertEqual(r, '3')
+
+    def testEmptyQuery(self):
+        self.assertRaises(ValueError, self.db.query, '')
+
     def testQueryProgrammingError(self):
         try:
             self.db.query("select 1/0")
@@ -1345,12 +1504,11 @@
                 "d numeric, f4 real, f8 double precision, m money, "
                 "v4 varchar(4), c4 char(4), t text,"
                 "b boolean, ts timestamp)" % table)
-            data = dict(i2 = 2**15 - 1,
-                i4 = int(2**31 - 1), i8 = long(2**31 - 1),
-                d = Decimal('123456789.9876543212345678987654321'),
-                f4 = 1.0 + 1.0/32, f8 = 1.0 + 1.0/32,
-                m = "1234.56", v4 = "1234", c4 = "1234", t = "1234" * 10,
-                b = 1, ts = 'current_date')
+            data = dict(i2=2**15 - 1, i4=int(2**31 - 1), i8=long(2**31 - 1),
+                d=Decimal('123456789.9876543212345678987654321'),
+                f4=1.0 + 1.0/32, f8 = 1.0 + 1.0/32,
+                m="1234.56", v4="1234", c4="1234",  t="1234" * 10,
+                b=1, ts='2012-12-21')
             r = self.db.insert(table, data)
             self.assertEqual(r, data)
             oid_table = table
@@ -1608,6 +1766,7 @@
         c.query("create database " + dbname
             + " template=template0")
         for s in ('client_min_messages = warning',
+            'client_encoding = UTF8',
             'lc_messages = C',
             'default_with_oids = on',
             'standard_conforming_strings = off',
@@ -1660,6 +1819,7 @@
         unittest.makeSuite(TestCanConnect),
         unittest.makeSuite(TestConnectObject),
         unittest.makeSuite(TestSimpleQueries),
+        unittest.makeSuite(TestParamQueries),
         unittest.makeSuite(TestDBClassBasic),
         ))
 
_______________________________________________
PyGreSQL mailing list
[email protected]
https://mail.vex.net/mailman/listinfo.cgi/pygresql

Reply via email to