Sorry, the list didn't like the way I attached that patch. I'll try it again. 
See my previous msg for a description.

-- 
Patrick TJ McPhee <[email protected]>
Index: module/test_pg.py
===================================================================
--- module/test_pg.py	(revision 445)
+++ module/test_pg.py	(working copy)
@@ -418,7 +418,16 @@
         self.assertNotEqual(user, no_user)
 
     def testMethodQuery(self):
-        self.connection.query("select 1+1")
+        self.assertEqual(self.connection.query("select 1+1").getresult()[0][0], 2)
+        self.assertEqual(self.connection.query("select 1+$1", (1,)).getresult()[0][0], 2)
+        self.assertEqual(self.connection.query("select 1+$1", [1,]).getresult()[0][0], 2)
+        self.assertEqual(self.connection.query("select 1+$1::numeric", [1,]).getresult()[0][0], Decimal('2'))
+        self.assertEqual(self.connection.query("select $1::int+$2", (1,1)).getresult()[0][0], 2)
+        self.assertEqual(self.connection.query("select $1::text", (None,)).getresult()[0][0], None)
+        self.assertRaises(TypeError, self.connection.query, "select $1::text", 1)
+        self.assertEqual(self.connection.query("create temp table x (a varchar(20));"
+                                               "insert into x values ('alpha');"
+                                               "select a from x").getresult()[0][0], "alpha")
 
     def testMethodEndcopy(self):
         try:
@@ -925,7 +934,15 @@
             r"\x4f007073ff21"), 'O\x00ps\xff!')
 
     def testMethodQuery(self):
-        self.db.query("select 1+1")
+        self.assertEqual(self.db.query("select 1+1").getresult()[0][0], 2)
+        self.assertEqual(self.db.query("select 1+$1", (1,)).getresult()[0][0], 2)
+        self.assertEqual(self.db.query("select 1+$1", [1,]).getresult()[0][0], 2)
+        self.assertEqual(self.db.query("select 1+$1::numeric", [1,]).getresult()[0][0], Decimal('2'))
+        self.assertEqual(self.db.query("select $1::int+$2", (1,1)).getresult()[0][0], 2)
+        self.assertEqual(self.db.query("select $1::int+$2", 1,1).getresult()[0][0], 2)
+        self.assertEqual(self.db.query("select $1::text", None).getresult()[0][0], None)
+        s = r"'\{}+()-#[]oo324"
+        self.assertEqual(self.db.query("select $1::text AS dennis", s).dictresult()[0]['dennis'], s)
 
     def testMethodQueryProgrammingError(self):
         try:
Index: module/pgmodule.c
===================================================================
--- module/pgmodule.c	(revision 445)
+++ module/pgmodule.c	(working copy)
@@ -2434,16 +2434,23 @@
 
 /* database query */
 static char pg_query__doc__[] =
-"query(sql) -- creates a new query object for this connection,"
-" using sql (string) request.";
+"query(sql, [args]) -- creates a new query object for this connection,"
+" using sql (string) request and optionally a tuple or list with"
+" positional parameters.";
 
 static PyObject *
 pg_query(pgobject *self, PyObject *args)
 {
-	char	   *query;
+	char	   *query,
+			  **parms;
+	PyObject   *oargs = NULL,
+			  **str;
 	PGresult   *result;
 	pgqueryobject *npgobj;
-	int			status;
+	int			status,
+				nparms = 0,
+			   *lparms;
+	register	int i;
 
 	if (!self->cnx)
 	{
@@ -2452,17 +2459,117 @@
 	}
 
 	/* get query args */
-	if (!PyArg_ParseTuple(args, "s", &query))
+	if (!PyArg_ParseTuple(args, "s|O", &query, &oargs))
 	{
-		PyErr_SetString(PyExc_TypeError, "query(sql), with sql (string).");
+		PyErr_SetString(PyExc_TypeError, "query(sql, [args]), with sql (string).");
 		return NULL;
 	}
 
+	/* If oargs is passed, ensure it's a tuple or list. An empty sequence
+	 * is treated the same as no argument, since that is what we get when
+	 * the caller passes no arguments to db.query(), and the historic
+	 * behaviour was to call PQexec(). */
+	if (oargs)
+	{
+		if (!PyTuple_Check(oargs) && !PyList_Check(oargs))
+		{
+			PyErr_SetString(PyExc_TypeError, "query parameters must be a tuple or list.");
+			return NULL;
+		}
+
+		nparms = PySequence_Size(oargs);
+	}
+
 	/* gets result */
-	Py_BEGIN_ALLOW_THREADS
-	result = PQexec(self->cnx, query);
-	Py_END_ALLOW_THREADS
+	if (!nparms)
+	{
+		Py_BEGIN_ALLOW_THREADS
+		result = PQexec(self->cnx, query);
+		Py_END_ALLOW_THREADS
+	}
+	else
+	{
+		PyObject   *seq;	/* owned reference to the parameter values */
+		PyObject   *obj;
+
+		/* if there's a single argument and it's a list or tuple, it
+		 * contains the positional arguments. PySequence_GetItem() returns
+		 * a new reference, which must be released on every path. */
+		obj = PySequence_GetItem(oargs, 0);
+		if (!obj)
+			return NULL;
+		if (nparms == 1 && (PyList_Check(obj) || PyTuple_Check(obj)))
+		{
+			seq = obj;
+			nparms = PySequence_Size(seq);
+		}
+		else
+		{
+			Py_DECREF(obj);
+			seq = oargs;
+			Py_INCREF(seq);
+		}
+
+		parms = (char **)alloca(nparms * sizeof(*parms));
+		lparms = (int *)alloca(nparms * sizeof(*lparms));
+		str = (PyObject **)alloca(nparms * sizeof(*str));
+
+		/* convert optional args to a list of strings -- this allows
+		 * the caller to pass whatever they like, and prevents us
+		 * from having to map types to OIDs */
+		for (i = 0; i < nparms; i++)
+		{
+			obj = PySequence_GetItem(seq, i);
+			if (!obj)
+				break;
+
+			if (obj == Py_None)
+			{
+				/* a NULL value pointer makes libpq send SQL NULL */
+				str[i] = NULL;
+				lparms[i] = 0;
+				parms[i] = NULL;
+			}
+			else
+			{
+				/* most objects can be converted to strings. Unicodes might
+				 * not work with the default encoding, so encode as UTF-8 */
+				if (PyUnicode_Check(obj))
+					str[i] = PyUnicode_AsUTF8String(obj);
+				else
+					str[i] = PyObject_Str(obj);
+				if (!str[i])
+				{
+					Py_DECREF(obj);
+					break;
+				}
+				lparms[i] = PyString_Size(str[i]);
+				parms[i] = PyString_AsString(str[i]);
+			}
+			Py_DECREF(obj);
+		}
+
+		/* a conversion failed: release the strings built so far and
+		 * propagate the Python exception that is already set */
+		if (i < nparms)
+		{
+			while (i-- > 0)
+				Py_XDECREF(str[i]);
+			Py_DECREF(seq);
+			return NULL;
+		}
+
+		Py_BEGIN_ALLOW_THREADS
+		result = PQexecParams(self->cnx, query, nparms,
+							  NULL, parms, lparms, NULL, 0);
+		Py_END_ALLOW_THREADS
+
+		/* the converted strings were only needed for the call */
+		for (i = 0; i < nparms; i++)
+			Py_XDECREF(str[i]);
+		Py_DECREF(seq);
+	}
 
 	/* checks result validity */
 	if (!result)
 	{
Index: module/pg.py
===================================================================
--- module/pg.py	(revision 445)
+++ module/pg.py	(working copy)
@@ -318,7 +318,7 @@
                 self.db.close()
             self.db = db
 
-    def query(self, qstr):
+    def query(self, qstr, *args):
         """Executes a SQL command string.
 
         This method simply sends a SQL query to the database. If the query is
@@ -332,12 +332,16 @@
         a pgqueryobject that can be accessed via getresult() or dictresult()
         or simply printed. Otherwise, it returns `None`.
 
+        The query can contain numbered parameters of the form $1 in place
+        of any data constant. Arguments given after the query string will
+        be substituted for the corresponding numbered parameter. Parameter
+        values can also be given as a single list or tuple argument.
         """
         # Wraps shared library function for debugging.
         if not self.db:
             raise _int_error('Connection is not valid')
         self._do_debug(qstr)
-        return self.db.query(qstr)
+        return self.db.query(qstr, args)
 
     def pkey(self, cl, newpkey=None):
         """This method gets or sets the primary key of a class.
@@ -548,7 +552,8 @@
                     raise _db_error('%s not in arg' % qoid)
             else:
                 arg = {qoid: arg}
-            where = 'oid = %s' % arg[qoid]
+            where = 'oid = $1'
+            params = (arg[qoid],)
             attnames = '*'
         else:
             attnames = self.get_attnames(qcl)
@@ -558,14 +563,15 @@
                 if len(keyname) > 1:
                     raise _prg_error('Composite key needs dict as arg')
                 arg = dict([(k, arg) for k in keyname])
-            where = ' AND '.join(['%s = %s'
-                % (k, self._quote(arg[k], attnames[k])) for k in keyname])
+            where = ' AND '.join(['%s = $%s'
+                % (k, i + 1) for i, k in enumerate(keyname)])
+            params = tuple(arg[k] for k in keyname)
             attnames = ', '.join(attnames)
         q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (attnames, qcl, where)
-        self._do_debug(q)
-        res = self.db.query(q).dictresult()
+        self._do_debug('%s %% %r' % (q, params))
+        res = self.db.query(q, params).dictresult()
         if not res:
-            raise _db_error('No such record in %s where %s' % (qcl, where))
+            raise _db_error('No such record in %s where %s %% %r' % (qcl, where, params))
         for att, value in res[0].iteritems():
             arg[att == 'oid' and qoid or att] = value
         return arg
@@ -590,11 +596,14 @@
             d = {}
         d.update(kw)
         attnames = self.get_attnames(qcl)
-        names, values = [], []
+        names, values, params = [], [], []
+        i = 1
         for n in attnames:
             if n != 'oid' and n in d:
                 names.append('"%s"' % n)
-                values.append(self._quote(d[n], attnames[n]))
+                values.append('$%s' % (i,))
+                params.append(d[n])
+                i += 1
         names, values = ', '.join(names), ', '.join(values)
         selectable = self.has_table_privilege(qcl)
         if selectable and self.server_version >= 80200:
@@ -602,8 +611,8 @@
         else:
             ret = ''
         q = 'INSERT INTO %s (%s) VALUES (%s)%s' % (qcl, names, values, ret)
-        self._do_debug(q)
-        res = self.db.query(q)
+        self._do_debug('%s %% %r' % (q, params))
+        res = self.db.query(q, params)
         if ret:
             res = res.dictresult()
             for att, value in res[0].iteritems():
@@ -645,7 +654,8 @@
         d.update(kw)
         attnames = self.get_attnames(qcl)
         if qoid in d:
-            where = 'oid = %s' % d[qoid]
+            where = 'oid = $1'
+            params = [d[qoid]]
             keyname = ()
         else:
             try:
@@ -655,14 +665,18 @@
             if isinstance(keyname, basestring):
                 keyname = (keyname,)
             try:
-                where = ' AND '.join(['%s = %s'
-                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
+                where = ' AND '.join(['%s = $%s'
+                    % (k, i + 1) for i, k in enumerate(keyname)])
+                params = [d[k] for k in keyname]
             except KeyError:
                 raise _prg_error('Update needs primary key or oid.')
         values = []
+        i = len(params)
         for n in attnames:
             if n in d and n not in keyname:
-                values.append('%s = %s' % (n, self._quote(d[n], attnames[n])))
+                i += 1
+                values.append('%s = $%s' % (n, i))
+                params.append(d[n])
         if not values:
             return d
         values = ', '.join(values)
@@ -673,7 +687,7 @@
             ret = ''
         q = 'UPDATE %s SET %s WHERE %s%s' % (qcl, values, where, ret)
-        self._do_debug(q)
-        res = self.db.query(q)
+        self._do_debug('%s %% %r' % (q, params))
+        res = self.db.query(q, params)
         if ret:
             res = res.dictresult()[0]
             for att, value in res.iteritems():
@@ -735,7 +749,8 @@
             d = {}
         d.update(kw)
         if qoid in d:
-            where = 'oid = %s' % d[qoid]
+            where = 'oid = $1'
+            params = (d[qoid],)
         else:
             try:
                 keyname = self.pkey(qcl)
@@ -745,13 +760,13 @@
                 keyname = (keyname,)
-            attnames = self.get_attnames(qcl)
             try:
-                where = ' AND '.join(['%s = %s'
-                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
+                where = ' AND '.join(['%s = $%s'
+                    % (k, i + 1) for i, k in enumerate(keyname)])
+                params = tuple(d[k] for k in keyname)
             except KeyError:
                 raise _prg_error('Delete needs primary key or oid.')
         q = 'DELETE FROM %s WHERE %s' % (qcl, where)
-        self._do_debug(q)
-        return int(self.db.query(q))
+        self._do_debug('%s %% %r' % (q, params))
+        return int(self.db.query(q, params))
 
 
 # if run as script, print some information
Index: docs/pg.txt
===================================================================
--- docs/pg.txt	(revision 445)
+++ docs/pg.txt	(working copy)
@@ -423,10 +423,11 @@
 -------------------------------------
 Syntax::
 
-  query(command)
+  query(command, [args])
 
 Parameters:
   :command: SQL command (string)
+  :args: optional tuple or list of parameter values
 
 Return type:
   :pgqueryobject, None: result values
@@ -449,10 +450,22 @@
   method returns a `pgqueryobject` that can be accessed via the `getresult()`
   or `dictresult()` method or simply printed. Otherwise, it returns `None`.
 
+  The query may optionally contain positional parameters of the form `$1`,
+  `$2`, etc. in place of literal data, with the values supplied as a tuple
+  or list. The database substitutes the values itself, so they never need
+  to be escaped; this is an effective way to pass arbitrary or unknown data
+  without worrying about SQL injection or syntax errors.
+
   When the database could not process the query, a `pg.ProgrammingError` or
   a `pg.InternalError` is raised. You can check the "SQLSTATE" code of this
   error by reading its `sqlstate` attribute.
 
+Example::
+
+  name = raw_input("Name? ")
+  phone = con.query("select phone from employees"
+    " where name=$1", (name, )).getresult()
+
 reset - resets the connection
 -----------------------------
 Syntax::
@@ -1003,6 +1016,41 @@
   either in the dictionary where the OID must be munged, or in the keywords
   where it can be simply the string "oid".
 
+query - executes a SQL command string
+-------------------------------------
+Syntax::
+
+  query(command, [arg1, [arg2, ...]])
+
+Parameters:
+  :command: SQL command (string)
+  :arg*: optional parameter values, given individually or as one tuple or list
+
+Return type:
+  :pgqueryobject, None: result values
+
+Exceptions raised:
+  :TypeError: bad argument type, or too many arguments
+  :TypeError: invalid connection
+  :ValueError: empty SQL query or lost connection
+  :pg.ProgrammingError: error in query
+  :pg.InternalError: error during query processing
+
+Description:
+  Similar to the pgobject method with the same name, except that parameter
+  values can be passed either as a single list or tuple, or as individual
+  positional arguments after the command.
+
+Example::
+
+  name = raw_input("Name? ")
+  phone = raw_input("Phone? ")
+  rows = db.query("update employees set phone=$2"
+    " where name=$1", (name, phone))
+  # or, passing the values as individual arguments:
+  rows = db.query("update employees set phone=$2"
+    " where name=$1", name, phone)
+
 clear - clears row values in memory
 -----------------------------------
 Syntax::
_______________________________________________
PyGreSQL mailing list
[email protected]
https://mail.vex.net/mailman/listinfo.cgi/pygresql

Reply via email to