Diff
Modified: trunk/pg.py (952 => 953)
--- trunk/pg.py 2019-01-03 15:20:00 UTC (rev 952)
+++ trunk/pg.py 2019-01-03 20:38:57 UTC (rev 953)
@@ -1869,6 +1869,57 @@
return self.query(*self.adapter.format_query(
command, parameters, types, inline))
+    def query_prepared(self, name, *args):
+        """Execute a prepared SQL statement.
+
+        This works like the query() method, but instead of an SQL command
+        you pass the name of a statement already created with prepare().
+        Any additional positional arguments are passed on, as one tuple,
+        as the values for the numbered parameters of that statement.
+        """
+        if not self.db:
+            raise _int_error('Connection is not valid')
+        if not args:
+            # no parameters given: run the statement as it is
+            self._do_debug('EXECUTE', name)
+            return self.db.query_prepared(name)
+        self._do_debug('EXECUTE', name, args)
+        return self.db.query_prepared(name, args)
+
+    def prepare(self, name, command):
+        """Create a prepared SQL statement.
+
+        The statement is created under the given name from the given SQL
+        command and can later be run with the query_prepared() method.
+        Passing "" as the name creates the unnamed statement, replacing any
+        unnamed statement that already exists; for every other name it is
+        an error if a statement of that name is already defined in the
+        current database session.
+
+        Parameters inside the command are referenced positionally with
+        numbered placeholders of the form $1, $2 and so on.
+        """
+        db = self.db
+        if not db:
+            raise _int_error('Connection is not valid')
+        self._do_debug('prepare', name, command)
+        return db.prepare(name, command)
+
+    def describe_prepared(self, name):
+        """Describe a prepared SQL statement.
+
+        This method returns a Query object describing the result columns of
+        the prepared statement with the given name; for instance, its
+        listfields() method gives the names of the result columns.
+
+        Like the other prepared-statement methods, this raises an internal
+        error if the connection is not valid.
+        """
+        # guard added for consistency with prepare() and query_prepared();
+        # without it a closed connection fails with an AttributeError on None
+        if not self.db:
+            raise _int_error('Connection is not valid')
+        return self.db.describe_prepared(name)
+
+    def delete_prepared(self, name=None):
+        """Delete a prepared SQL statement.
+
+        This deallocates a previously prepared SQL statement with the given
+        name, or deallocates all prepared statements if no name (or an
+        empty name) is given.  Prepared statements are also deallocated
+        automatically when the current session ends.
+        """
+        if not self.db:
+            raise _int_error('Connection is not valid')
+        if name:
+            # prepare() hands the name to the server verbatim, so it must be
+            # quoted as an identifier here: otherwise mixed-case names are
+            # case-folded, names like 'does-not-exist' are a syntax error,
+            # and an untrusted name could inject arbitrary SQL.
+            target = self.db.escape_identifier(name)
+        else:
+            target = 'ALL'
+        q = "DEALLOCATE %s" % (target,)
+        self._do_debug(q)
+        return self.db.query(q)
+
def pkey(self, table, composite=False, flush=False):
"""Get or set the primary key of a table.
Modified: trunk/pgmodule.c (952 => 953)
--- trunk/pgmodule.c 2019-01-03 15:20:00 UTC (rev 952)
+++ trunk/pgmodule.c 2019-01-03 20:38:57 UTC (rev 953)
@@ -2141,13 +2141,10 @@
}
/* database query */
-static char connQuery__doc__[] =
-"query(sql, [arg]) -- create a new query object for this connection\n\n"
-"You must pass the SQL (string) request and you can optionally pass\n"
-"a tuple with positional parameters.\n";
+/* base method for execution of both unprepared and prepared queries */
static PyObject *
-connQuery(connObject *self, PyObject *args)
+_connQuery(connObject *self, PyObject *args, int prepared)
{
PyObject *query_obj;
PyObject *param_obj = NULL;
@@ -2287,8 +2284,11 @@
}
Py_BEGIN_ALLOW_THREADS
- result = PQexecParams(self->cnx, query, nparms,
- NULL, parms, NULL, NULL, 0);
+ result = prepared ?
+ PQexecPrepared(self->cnx, query, nparms,
+ parms, NULL, NULL, 0) :
+ PQexecParams(self->cnx, query, nparms,
+ NULL, parms, NULL, NULL, 0);
Py_END_ALLOW_THREADS
PyMem_Free((void *)parms);
@@ -2298,7 +2298,10 @@
else
{
Py_BEGIN_ALLOW_THREADS
- result = PQexec(self->cnx, query);
+ result = prepared ?
+ PQexecPrepared(self->cnx, query, 0,
+ NULL, NULL, NULL, 0) :
+ PQexec(self->cnx, query);
Py_END_ALLOW_THREADS
}
@@ -2376,6 +2379,123 @@
return (PyObject *) npgobj;
}
+/* database query */
+static char connQuery__doc__[] =
+"query(sql, [arg]) -- create a new query object for this connection\n\n"
+"You must pass the SQL (string) request and you can optionally pass\n"
+"a tuple with positional parameters.\n";
+
+/* run an SQL command given as text; delegates all work to _connQuery()
+   with the "prepared" flag cleared, so PQexec()/PQexecParams() is used */
+static PyObject *
+connQuery(connObject *self, PyObject *args)
+{
+    const int prepared = 0;  /* first argument is SQL text, not a name */
+
+    return _connQuery(self, args, prepared);
+}
+
+/* execute prepared statement */
+static char connQueryPrepared__doc__[] =
+"query_prepared(name, [arg]) -- execute a prepared statement\n\n"
+"You must pass the name (string) of the prepared statement and you can\n"
+"optionally pass a tuple with positional parameters.\n";
+
+/* run a statement created with prepare(); argument handling is shared
+   with query(), but _connQuery() passes the first argument as statement
+   name to PQexecPrepared() instead of executing it as SQL text */
+static PyObject *
+connQueryPrepared(connObject *self, PyObject *args)
+{
+    const int prepared = 1;  /* first argument names a prepared statement */
+
+    return _connQuery(self, args, prepared);
+}
+
+/* create prepared statement */
+static char connPrepare__doc__[] =
+"prepare(name, sql) -- create a prepared statement\n\n"
+"You must pass the name (string) of the prepared statement and the\n"
+"SQL (string) request for later execution.\n";
+
+/* Create a server-side prepared statement from a name and an SQL string.
+   Returns None on success. Raises TypeError if the connection is not
+   valid or the arguments are not two strings; if the server rejects the
+   statement, the error is raised via set_error() (given ProgrammingError). */
+static PyObject *
+connPrepare(connObject *self, PyObject *args)
+{
+    char *name, *query;
+    /* "s#" stores the lengths as Py_ssize_t when PY_SSIZE_T_CLEAN is in
+       effect, which would overrun an int; Py_ssize_t is wide enough in
+       either case. The lengths are not used otherwise, since libpq takes
+       NUL-terminated strings. */
+    Py_ssize_t name_length, query_length;
+    PGresult *result;
+
+    if (!self->cnx)
+    {
+        PyErr_SetString(PyExc_TypeError, "Connection is not valid");
+        return NULL;
+    }
+
+    /* reads args */
+    if (!PyArg_ParseTuple(args, "s#s#",
+        &name, &name_length, &query, &query_length))
+    {
+        PyErr_SetString(PyExc_TypeError,
+            "Method prepare() takes two string arguments");
+        return NULL;
+    }
+
+    /* create prepared statement; no parameter types are pre-specified
+       (nParams = 0, paramTypes = NULL) */
+    Py_BEGIN_ALLOW_THREADS
+    result = PQprepare(self->cnx, name, query, 0, NULL);
+    Py_END_ALLOW_THREADS
+    if (result && PQresultStatus(result) == PGRES_COMMAND_OK)
+    {
+        PQclear(result);
+        Py_INCREF(Py_None);
+        return Py_None; /* success */
+    }
+    set_error(ProgrammingError, "Cannot create prepared statement",
+        self->cnx, result);
+    if (result)
+        PQclear(result);
+    return NULL; /* error */
+}
+
+/* describe prepared statement */
+static char connDescribePrepared__doc__[] =
+"describe_prepared(name) -- describe a prepared statement\n\n"
+"You must pass the name (string) of the prepared statement.\n";
+
+/* Describe the prepared statement with the given name.
+   On success returns a new query object wrapping the description result
+   (owned by that object). Raises TypeError if the connection is not valid
+   or the argument is not a string, MemoryError if the query object cannot
+   be allocated, and otherwise the error raised via set_error(). */
+static PyObject *
+connDescribePrepared(connObject *self, PyObject *args)
+{
+    char *name;
+    /* Py_ssize_t matches what "s#" writes when PY_SSIZE_T_CLEAN is in
+       effect; the length itself is not used */
+    Py_ssize_t name_length;
+    PGresult *result;
+
+    if (!self->cnx)
+    {
+        PyErr_SetString(PyExc_TypeError, "Connection is not valid");
+        return NULL;
+    }
+
+    /* reads args */
+    if (!PyArg_ParseTuple(args, "s#",
+        &name, &name_length))
+    {
+        PyErr_SetString(PyExc_TypeError,
+            "Method describe_prepared() takes a string argument");
+        return NULL;
+    }
+
+    /* describe prepared statement */
+    Py_BEGIN_ALLOW_THREADS
+    result = PQdescribePrepared(self->cnx, name);
+    Py_END_ALLOW_THREADS
+    if (result && PQresultStatus(result) == PGRES_COMMAND_OK)
+    {
+        queryObject *npgobj = PyObject_NEW(queryObject, &queryType);
+        if (!npgobj)
+        {
+            /* nothing owns the result yet, so free it before bailing out */
+            PQclear(result);
+            return PyErr_NoMemory();
+        }
+        /* NOTE(review): only pgcnx and result are set here and
+           PyObject_NEW does not zero the object; confirm queryObject has
+           no further fields that _connQuery() initializes. */
+        Py_XINCREF(self);
+        npgobj->pgcnx = self;
+        npgobj->result = result;
+        return (PyObject *) npgobj;
+    }
+    set_error(ProgrammingError, "Cannot describe prepared statement",
+        self->cnx, result);
+    if (result)
+        PQclear(result);
+    return NULL; /* error */
+}
+
#ifdef DIRECT_ACCESS
static char connPutLine__doc__[] =
"putline(line) -- send a line directly to the backend";
@@ -3414,6 +3534,11 @@
{"source", (PyCFunction) connSource, METH_NOARGS, connSource__doc__},
{"query", (PyCFunction) connQuery, METH_VARARGS, connQuery__doc__},
+ {"query_prepared", (PyCFunction) connQueryPrepared, METH_VARARGS,
+ connQueryPrepared__doc__},
+ {"prepare", (PyCFunction) connPrepare, METH_VARARGS, connPrepare__doc__},
+ {"describe_prepared", (PyCFunction) connDescribePrepared, METH_VARARGS,
+ connDescribePrepared__doc__},
{"reset", (PyCFunction) connReset, METH_NOARGS, connReset__doc__},
{"cancel", (PyCFunction) connCancel, METH_NOARGS, connCancel__doc__},
{"close", (PyCFunction) connClose, METH_NOARGS, connClose__doc__},
Modified: trunk/tests/test_classic_connection.py (952 => 953)
--- trunk/tests/test_classic_connection.py 2019-01-03 15:20:00 UTC (rev 952)
+++ trunk/tests/test_classic_connection.py 2019-01-03 20:38:57 UTC (rev 953)
@@ -121,10 +121,11 @@
self.assertEqual(attributes, connection_attributes)
def testAllConnectMethods(self):
- methods = '''cancel close date_format endcopy
+ methods = '''cancel close date_format describe_prepared endcopy
escape_bytea escape_identifier escape_literal escape_string
fileno get_cast_hook get_notice_receiver getline getlo getnotify
- inserttable locreate loimport parameter putline query reset
+ inserttable locreate loimport parameter
+ prepare putline query query_prepared reset
set_cast_hook set_notice_receiver source transaction'''.split()
connection_methods = [a for a in dir(self.connection)
if not a.startswith('__') and self.is_method(a)]
@@ -932,6 +933,82 @@
).dictresult(), [{'garbage': garbage}])
+class TestPreparedQueries(unittest.TestCase):
+    """Test prepared queries via a basic pg connection."""
+
+    def setUp(self):
+        # each test gets its own connection using UTF-8 client encoding
+        self.c = connect()
+        self.c.query('set client_encoding=utf8')
+
+    def tearDown(self):
+        self.c.close()
+
+    def testEmptyPreparedStatement(self):
+        # an empty command can be prepared, but executing it is rejected
+        self.c.prepare('', '')
+        with self.assertRaises(ValueError):
+            self.c.query_prepared('')
+
+    def testInvalidPreparedStatement(self):
+        with self.assertRaises(pg.ProgrammingError):
+            self.c.prepare('', 'bad')
+
+    def testNonExistentPreparedStatement(self):
+        with self.assertRaises(pg.OperationalError):
+            self.c.query_prepared('does-not-exist')
+
+    def testAnonymousQueryWithoutParams(self):
+        con = self.c
+        self.assertIsNone(con.prepare('', "select 'anon'"))
+        rows = con.query_prepared('').getresult()
+        self.assertEqual(rows, [('anon',)])
+
+    def testNamedQueryWithoutParams(self):
+        con = self.c
+        self.assertIsNone(con.prepare('hello', "select 'world'"))
+        rows = con.query_prepared('hello').getresult()
+        self.assertEqual(rows, [('world',)])
+
+    def testMultipleNamedQueriesWithoutParams(self):
+        con = self.c
+        self.assertIsNone(con.prepare('query17', "select 17"))
+        self.assertIsNone(con.prepare('query42', "select 42"))
+        for name, value in (('query17', 17), ('query42', 42)):
+            self.assertEqual(con.query_prepared(name).getresult(),
+                             [(value,)])
+
+    def testAnonymousQueryWithParams(self):
+        con = self.c
+        # the second unnamed statement replaces the first one
+        self.assertIsNone(con.prepare('', "select $1 || ', ' || $2"))
+        rows = con.query_prepared('', ['hello', 'world']).getresult()
+        self.assertEqual(rows, [('hello, world',)])
+        self.assertIsNone(con.prepare('', "select 1+ $1 + $2 + $3"))
+        rows = con.query_prepared('', [17, -5, 29]).getresult()
+        self.assertEqual(rows, [(42,)])
+
+    def testMultipleNamedQueriesWithParams(self):
+        con = self.c
+        self.assertIsNone(con.prepare('q1', "select $1 || '!'"))
+        self.assertIsNone(con.prepare('q2', "select $1 || '-' || $2"))
+        self.assertEqual(con.query_prepared('q1', ['hello']).getresult(),
+                         [('hello!',)])
+        self.assertEqual(con.query_prepared('q2', ['he', 'lo']).getresult(),
+                         [('he-lo',)])
+
+    def testDescribeNonExistentQuery(self):
+        with self.assertRaises(pg.OperationalError):
+            self.c.describe_prepared('does-not-exist')
+
+    def testDescribeAnonymousQuery(self):
+        self.c.prepare('', "select 1::int, 'a'::char")
+        fields = self.c.describe_prepared('').listfields()
+        # the expected column names of these unaliased columns are type names
+        self.assertEqual(fields, ('int4', 'bpchar'))
+
+    def testDescribeNamedQuery(self):
+        self.c.prepare('myquery', "select 1 as first, 2 as second")
+        fields = self.c.describe_prepared('myquery').listfields()
+        self.assertEqual(fields, ('first', 'second'))
+
+    def testDescribeMultipleNamedQueries(self):
+        con = self.c
+        con.prepare('query1', "select 1::int")
+        con.prepare('query2', "select 1::int, 2::int")
+        self.assertEqual(con.describe_prepared('query1').listfields(),
+                         ('int4',))
+        self.assertEqual(con.describe_prepared('query2').listfields(),
+                         ('int4', 'int4'))
+
+
class TestQueryResultTypes(unittest.TestCase):
"""Test proper result types via a basic pg connection."""
Modified: trunk/tests/test_classic_dbwrapper.py (952 => 953)
--- trunk/tests/test_classic_dbwrapper.py 2019-01-03 15:20:00 UTC (rev 952)
+++ trunk/tests/test_classic_dbwrapper.py 2019-01-03 20:38:57 UTC (rev 953)
@@ -198,6 +198,7 @@
'cancel', 'clear', 'close', 'commit',
'date_format', 'db', 'dbname', 'dbtypes',
'debug', 'decode_json', 'delete',
+ 'delete_prepared', 'describe_prepared',
'encode_json', 'end', 'endcopy', 'error',
'escape_bytea', 'escape_identifier',
'escape_literal', 'escape_string',
@@ -213,8 +214,8 @@
'notification_handler',
'options',
'parameter', 'pkey', 'port',
- 'protocol_version', 'putline',
- 'query', 'query_formatted',
+ 'prepare', 'protocol_version', 'putline',
+ 'query', 'query_formatted', 'query_prepared',
'release', 'reopen', 'reset', 'rollback',
'savepoint', 'server_version',
'set_cast_hook', 'set_notice_receiver',
@@ -968,6 +969,58 @@
r = f(q, {}).getresult()[0][0]
self.assertEqual(r, 42)
+    def testQueryPreparedWithoutParams(self):
+        db = self.db
+        # prepare both statements first, then run them in the same order
+        for name, sql in (('q1', "select 17"), ('q2', "select 42")):
+            db.prepare(name, sql)
+        for name, expected in (('q1', 17), ('q2', 42)):
+            value = db.query_prepared(name).getresult()[0][0]
+            self.assertEqual(value, expected)
+
+    def testQueryPreparedWithParams(self):
+        db = self.db
+        db.prepare('sum', "select 1 + $1 + $2 + $3")
+        db.prepare('cat', "select initcap($1) || ', ' || $2 || '!'")
+        # parameters are passed as individual positional arguments
+        total = db.query_prepared('sum', 2, 3, 5).getresult()[0][0]
+        self.assertEqual(total, 11)
+        greeting = db.query_prepared('cat', 'hello', 'world').getresult()[0][0]
+        self.assertEqual(greeting, 'Hello, world!')
+
+    def testPrepare(self):
+        prepare = self.db.prepare
+        for name, sql in (('', "select null"),
+                          ('myquery', "select 'hello'"),
+                          ('myquery2', "select 'world'")):
+            self.assertIsNone(prepare(name, sql))
+        # a named statement cannot be defined twice in the same session
+        with self.assertRaises(pg.ProgrammingError):
+            prepare('myquery', "select 'hello, too'")
+
+    def testDescribePrepared(self):
+        db = self.db
+        db.prepare('count', 'select 1 as first, 2 as second')
+        fields = db.describe_prepared('count').listfields()
+        self.assertEqual(fields, ('first', 'second'))
+
+    def testDeletePrepared(self):
+        delete, prepare = self.db.delete_prepared, self.db.prepare
+        error = pg.OperationalError
+        # deallocating all statements must not fail
+        delete()
+        self.assertRaises(error, delete, 'myquery')
+        prepare('q1', "select 1")
+        prepare('q2', "select 2")
+        # delete the statements one by one; deleting again must fail
+        for name in ('q1', 'q2'):
+            delete(name)
+        for name in ('q1', 'q2'):
+            self.assertRaises(error, delete, name)
+        prepare('q1', "select 1")
+        prepare('q2', "select 2")
+        # now delete all of them at once
+        delete()
+        for name in ('q1', 'q2'):
+            self.assertRaises(error, delete, name)
+
def testPkey(self):
query = self.db.query
pkey = self.db.pkey