Sorry, the list didn't like the way I attached that patch. I'll try it again.
See my previous msg for a description.
--
Patrick TJ McPhee <[email protected]>
Index: module/test_pg.py
===================================================================
--- module/test_pg.py (revision 445)
+++ module/test_pg.py (working copy)
@@ -418,7 +418,15 @@
self.assertNotEqual(user, no_user)
def testMethodQuery(self):
- self.connection.query("select 1+1")
+ self.assertEqual(self.connection.query("select 1+1").getresult()[0][0], 2)
+ self.assertEqual(self.connection.query("select 1+$1", (1,)).getresult()[0][0], 2)
+ self.assertEqual(self.connection.query("select 1+$1", [1,]).getresult()[0][0], 2)
+ self.assertEqual(self.connection.query("select 1+$1::numeric", [1,]).getresult()[0][0], Decimal('2'))
+ self.assertEqual(self.connection.query("select $1::int+$2", (1,1)).getresult()[0][0], 2)
+ self.assertEqual(self.connection.query("select $1::text", (None,)).getresult()[0][0], None)
+ self.assertEqual(self.connection.query("create temp table x (a varchar(20));"
+ "insert into x values ('alpha');"
+ "select a from x").getresult()[0][0], "alpha")
def testMethodEndcopy(self):
try:
@@ -925,7 +933,15 @@
r"\x4f007073ff21"), 'O\x00ps\xff!')
def testMethodQuery(self):
- self.db.query("select 1+1")
+ self.assertEqual(self.db.query("select 1+1").getresult()[0][0], 2)
+ self.assertEqual(self.db.query("select 1+$1", (1,)).getresult()[0][0], 2)
+ self.assertEqual(self.db.query("select 1+$1", [1,]).getresult()[0][0], 2)
+ self.assertEqual(self.db.query("select 1+$1::numeric", [1,]).getresult()[0][0], Decimal('2'))
+ self.assertEqual(self.db.query("select $1::int+$2", (1,1)).getresult()[0][0], 2)
+ self.assertEqual(self.db.query("select $1::int+$2", 1,1).getresult()[0][0], 2)
+ self.assertEqual(self.db.query("select $1::text", None).getresult()[0][0], None)
+ s = r"'\{}+()-#[]oo324"
+ self.assertEqual(self.db.query("select $1::text AS dennis", s).dictresult()[0]['dennis'], s)
def testMethodQueryProgrammingError(self):
try:
Index: module/pgmodule.c
===================================================================
--- module/pgmodule.c (revision 445)
+++ module/pgmodule.c (working copy)
@@ -2434,16 +2434,23 @@
/* database query */
static char pg_query__doc__[] =
-"query(sql) -- creates a new query object for this connection,"
-" using sql (string) request.";
+"query(sql, [args]) -- creates a new query object for this connection,"
+" using sql (string) request and optionally a tuple or list with positional"
+" parameters.";
static PyObject *
pg_query(pgobject *self, PyObject *args)
{
- char *query;
+ char *query,
+ **parms = NULL;
+ PyObject *oargs = NULL,
+ *seq = NULL,
+ **str = NULL;
PGresult *result;
pgqueryobject *npgobj;
- int status;
+ int status,
+ nparms = 0;
+ register int i;
if (!self->cnx)
{
@@ -2452,17 +2459,133 @@
}
/* get query args */
- if (!PyArg_ParseTuple(args, "s", &query))
+ if (!PyArg_ParseTuple(args, "s|O", &query, &oargs))
{
- PyErr_SetString(PyExc_TypeError, "query(sql), with sql (string).");
+ PyErr_SetString(PyExc_TypeError, "query(sql, [args]), with sql (string).");
return NULL;
}
+ /* If oargs is passed, ensure it's a tuple or list. An empty sequence
+ * is treated the same as no argument, since that is what we get when
+ * the caller passes no arguments to db.query(), and historic
+ * behaviour was to call PQexec(). */
+ if (oargs)
+ {
+ if (!PyTuple_Check(oargs) && !PyList_Check(oargs))
+ {
+ PyErr_SetString(PyExc_TypeError, "query parameters must be a tuple or list.");
+ return NULL;
+ }
+
+ /* seq owns a reference to the parameter sequence until the
+ * query has been sent */
+ Py_INCREF(oargs);
+ seq = oargs;
+
+ /* if there's a single argument and it's a list or tuple, it
+ * contains the positional arguments; db.query(sql, (a, b))
+ * arrives here as ((a, b),) */
+ if (PySequence_Size(seq) == 1)
+ {
+ PyObject *first = PySequence_GetItem(seq, 0);
+
+ if (!first)
+ {
+ Py_DECREF(seq);
+ return NULL;
+ }
+ if (PyList_Check(first) || PyTuple_Check(first))
+ {
+ Py_DECREF(seq);
+ seq = first;
+ }
+ else
+ Py_DECREF(first);
+ }
+
+ nparms = PySequence_Size(seq);
+ }
+
+
/* gets result */
- Py_BEGIN_ALLOW_THREADS
- result = PQexec(self->cnx, query);
- Py_END_ALLOW_THREADS
+ if (!nparms)
+ {
+ Py_BEGIN_ALLOW_THREADS
+ result = PQexec(self->cnx, query);
+ Py_END_ALLOW_THREADS
+ }
+
+ else
+ {
+ int failed;
+
+ /* the arrays are heap allocated rather than alloca()ed because
+ * the number of parameters is controlled by the caller */
+ parms = (char **)PyMem_Malloc(nparms * sizeof(*parms));
+ str = (PyObject **)PyMem_Malloc(nparms * sizeof(*str));
+ if (!parms || !str)
+ {
+ PyMem_Free(parms);
+ PyMem_Free(str);
+ Py_DECREF(seq);
+ return PyErr_NoMemory();
+ }
+ for (i = 0; i < nparms; i++)
+ str[i] = NULL;
+
+ /* convert optional args to a list of strings -- this allows
+ * the caller to pass whatever they like, and prevents us
+ * from having to map types to OIDs. None is passed as NULL. */
+ for (i = 0; i < nparms; i++)
+ {
+ PyObject *obj = PySequence_GetItem(seq, i);
+
+ parms[i] = NULL;
+ if (!obj)
+ break;
+ if (obj == Py_None)
+ {
+ Py_DECREF(obj);
+ continue;
+ }
+
+ /* most objects can be converted to strings. Unicodes might
+ * not work with the default encoding, so encode as UTF-8 */
+ if (PyUnicode_Check(obj))
+ str[i] = PyUnicode_AsUTF8String(obj);
+ else
+ str[i] = PyObject_Str(obj);
+ Py_DECREF(obj);
+ if (!str[i] || !(parms[i] = PyString_AsString(str[i])))
+ break;
+ }
+
+ /* if the loop stopped early, a Python exception is set */
+ failed = (i < nparms);
+ if (!failed)
+ {
+ /* text-format parameters: no lengths or formats needed,
+ * and the server infers the parameter types */
+ Py_BEGIN_ALLOW_THREADS
+ result = PQexecParams(self->cnx, query, nparms,
+ NULL, (const char * const *)parms, NULL, NULL, 0);
+ Py_END_ALLOW_THREADS
+ }
+
+ for (i = 0; i < nparms; i++)
+ Py_XDECREF(str[i]);
+ PyMem_Free(str);
+ PyMem_Free(parms);
+ if (failed)
+ {
+ Py_DECREF(seq);
+ return NULL;
+ }
+ }
+ Py_XDECREF(seq);
+
+
/* checks result validity */
if (!result)
{
Index: module/pg.py
===================================================================
--- module/pg.py (revision 445)
+++ module/pg.py (working copy)
@@ -318,7 +318,7 @@
self.db.close()
self.db = db
- def query(self, qstr):
+ def query(self, qstr, *args):
"""Executes a SQL command string.
This method simply sends a SQL query to the database. If the query is
@@ -332,12 +332,16 @@
a pgqueryobject that can be accessed via getresult() or dictresult()
or simply printed. Otherwise, it returns `None`.
+ The query can contain numbered parameters of the form $1 in place
+ of any data constant. Arguments given after the query string will
+ be substituted for the corresponding numbered parameter. Parameter
+ values can also be given as a single list or tuple argument.
"""
# Wraps shared library function for debugging.
if not self.db:
raise _int_error('Connection is not valid')
self._do_debug(qstr)
- return self.db.query(qstr)
+ return self.db.query(qstr, args)
def pkey(self, cl, newpkey=None):
"""This method gets or sets the primary key of a class.
@@ -548,7 +552,8 @@
raise _db_error('%s not in arg' % qoid)
else:
arg = {qoid: arg}
- where = 'oid = %s' % arg[qoid]
+ where = 'oid = $1'
+ params = (arg[qoid],)
attnames = '*'
else:
attnames = self.get_attnames(qcl)
@@ -558,14 +563,15 @@
if len(keyname) > 1:
raise _prg_error('Composite key needs dict as arg')
arg = dict([(k, arg) for k in keyname])
- where = ' AND '.join(['%s = %s'
- % (k, self._quote(arg[k], attnames[k])) for k in keyname])
+ where = ' AND '.join(['%s = $%s'
+ % (k, i+1) for i,k in enumerate(keyname)])
+ params = tuple(arg[k] for k in keyname)
attnames = ', '.join(attnames)
q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (attnames, qcl, where)
- self._do_debug(q)
- res = self.db.query(q).dictresult()
+ self._do_debug(q + ('%% %r' % (params,)))
+ res = self.db.query(q, params).dictresult()
if not res:
- raise _db_error('No such record in %s where %s' % (qcl, where))
+ raise _db_error('No such record in %s where %s %% %r' % (qcl, where, params))
for att, value in res[0].iteritems():
arg[att == 'oid' and qoid or att] = value
return arg
@@ -590,11 +596,14 @@
d = {}
d.update(kw)
attnames = self.get_attnames(qcl)
- names, values = [], []
+ names, values, params = [], [], []
+ i = 1
for n in attnames:
if n != 'oid' and n in d:
names.append('"%s"' % n)
- values.append(self._quote(d[n], attnames[n]))
+ values.append('$%s' % (i,))
+ params.append(d[n])
+ i += 1
names, values = ', '.join(names), ', '.join(values)
selectable = self.has_table_privilege(qcl)
if selectable and self.server_version >= 80200:
@@ -602,8 +611,8 @@
else:
ret = ''
q = 'INSERT INTO %s (%s) VALUES (%s)%s' % (qcl, names, values, ret)
- self._do_debug(q)
- res = self.db.query(q)
+ self._do_debug(q + ("%% %r" % (params,)))
+ res = self.db.query(q, params)
if ret:
res = res.dictresult()
for att, value in res[0].iteritems():
@@ -645,7 +654,8 @@
d.update(kw)
attnames = self.get_attnames(qcl)
if qoid in d:
- where = 'oid = %s' % d[qoid]
+ where = 'oid = $1'
+ params = [d[qoid]]
keyname = ()
else:
try:
@@ -655,14 +665,18 @@
if isinstance(keyname, basestring):
keyname = (keyname,)
try:
- where = ' AND '.join(['%s = %s'
- % (k, self._quote(d[k], attnames[k])) for k in keyname])
+ where = ' AND '.join(['%s = $%s'
+ % (k, i+1) for i,k in enumerate(keyname)])
+ params = [d[k] for k in keyname]
except KeyError:
raise _prg_error('Update needs primary key or oid.')
values = []
+ i = len(params)
for n in attnames:
if n in d and n not in keyname:
- values.append('%s = %s' % (n, self._quote(d[n], attnames[n])))
+ i += 1
+ values.append('%s = $%s' % (n, i))
+ params.append(d[n])
if not values:
return d
values = ', '.join(values)
@@ -673,7 +687,7 @@
ret = ''
q = 'UPDATE %s SET %s WHERE %s%s' % (qcl, values, where, ret)
- self._do_debug(q)
- res = self.db.query(q)
+ self._do_debug(q + ("%% %r" % (params,)))
+ res = self.db.query(q, params)
if ret:
res = res.dictresult()[0]
for att, value in res.iteritems():
@@ -735,7 +749,8 @@
d = {}
d.update(kw)
if qoid in d:
- where = 'oid = %s' % d[qoid]
+ where = 'oid = $1'
+ params = (d[qoid],)
else:
try:
keyname = self.pkey(qcl)
@@ -745,13 +760,14 @@
keyname = (keyname,)
attnames = self.get_attnames(qcl)
try:
- where = ' AND '.join(['%s = %s'
- % (k, self._quote(d[k], attnames[k])) for k in keyname])
+ where = ' AND '.join(['%s = $%s'
+ % (k, i+1) for i,k in enumerate(keyname)])
+ params = tuple(d[k] for k in keyname)
except KeyError:
raise _prg_error('Delete needs primary key or oid.')
q = 'DELETE FROM %s WHERE %s' % (qcl, where)
- self._do_debug(q)
- return int(self.db.query(q))
+ self._do_debug(q + ("%% %r" % (params,)))
+ return int(self.db.query(q, params))
# if run as script, print some information
Index: docs/pg.txt
===================================================================
--- docs/pg.txt (revision 445)
+++ docs/pg.txt (working copy)
@@ -423,10 +423,11 @@
-------------------------------------
Syntax::
- query(command)
+ query(command, [args])
Parameters:
:command: SQL command (string)
+ :args: optional positional arguments
Return type:
:pgqueryobject, None: result values
@@ -449,10 +450,22 @@
method returns a `pgqueryobject` that can be accessed via the `getresult()`
or `dictresult()` method or simply printed. Otherwise, it returns `None`.
+ The query may optionally contain positional parameters of the form `$1`,
+ `$2`, etc. in place of literal data, with the values supplied as a tuple
+ or list (`None` is passed as SQL NULL). The values are sent to the
+ database separately from the query text, so they need not be escaped,
+ making this an effective way to pass arbitrary or unknown data without worrying about SQL injection or syntax errors.
+
When the database could not process the query, a `pg.ProgrammingError` or
a `pg.InternalError` is raised. You can check the "SQLSTATE" code of this
error by reading its `sqlstate` attribute.
+Example::
+
+ name = raw_input("Name? ")
+ phone = con.query("select phone from employees"
+ " where name=$1", (name, )).getresult()
+
reset - resets the connection
-----------------------------
Syntax::
@@ -1003,6 +1016,41 @@
either in the dictionary where the OID must be munged, or in the keywords
where it can be simply the string "oid".
+query - executes a SQL command string
+-------------------------------------
+Syntax::
+
+ query(command, [arg1, [arg2, ...]])
+
+Parameters:
+ :command: SQL command (string)
+ :arg*: optional positional arguments
+
+Return type:
+ :pgqueryobject, None: result values
+
+Exceptions raised:
+ :TypeError: bad argument type, or too many arguments
+ :pg.InternalError: invalid connection
+ :ValueError: empty SQL query or lost connection
+ :pg.ProgrammingError: error in query
+ :pg.InternalError: error during query processing
+
+Description:
+ Similar to the pgobject method with the same name, except that the query
+ parameters can be passed either as a single list or tuple, or as individual
+ positional arguments.
+
+Example::
+
+ name = raw_input("Name? ")
+ phone = raw_input("Phone? ")
+ rows = int(db.query("update employees set phone=$2"
+ " where name=$1", (name, phone)))
+ # or
+ rows = int(db.query("update employees set phone=$2"
+ " where name=$1", name, phone))
+
clear - clears row values in memory
-----------------------------------
Syntax::
_______________________________________________
PyGreSQL mailing list
[email protected]
https://mail.vex.net/mailman/listinfo.cgi/pygresql