Title: [953] trunk: Add special methods for using prepared statements
Revision: 953
Author: cito
Date: 2019-01-03 15:38:57 -0500 (Thu, 03 Jan 2019)

Log Message

Add special methods for using prepared statements

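Modified Paths

- trunk/pg.py
- trunk/pgmodule.c
- trunk/tests/test_classic_connection.py
- trunk/tests/test_classic_dbwrapper.py
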

Diff

Modified: trunk/pg.py (952 => 953)


--- trunk/pg.py	2019-01-03 15:20:00 UTC (rev 952)
+++ trunk/pg.py	2019-01-03 20:38:57 UTC (rev 953)
@@ -1869,6 +1869,57 @@
         return self.query(*self.adapter.format_query(
             command, parameters, types, inline))
 
+    def query_prepared(self, name, *args):
+        """Execute a prepared SQL statement.
+
+        This works like the query() method, but you need to pass the name of
+        a prepared statement that you have already created with prepare().
+        """
+        if not self.db:
+            raise _int_error('Connection is not valid')
+        if args:
+            self._do_debug('EXECUTE', name, args)
+            return self.db.query_prepared(name, args)
+        self._do_debug('EXECUTE', name)
+        return self.db.query_prepared(name)
+
+    def prepare(self, name, command):
+        """Create a prepared SQL statement.
+
+        This creates a prepared statement with the given name for the given
+        command for later execution with the query_prepared() method.
+        The name can be "" to create an unnamed statement, in which case any
+        pre-existing unnamed statement is automatically replaced; otherwise
+        it is an error if the statement name is already defined in the current
+        database session.
+
+        If any parameters are used, they can be referred to in the query as
+        numbered parameters of the form $1.
+        """
+        if not self.db:
+            raise _int_error('Connection is not valid')
+        self._do_debug('prepare', name, command)
+        return self.db.prepare(name, command)
+
+    def describe_prepared(self, name):
+        """Describe a prepared SQL statement.
+
+        This method returns a Query object describing the result columns of
+        the prepared statement with the given name.
+        """
+        return self.db.describe_prepared(name)
+
+    def delete_prepared(self, name=None):
+        """Delete a prepared SQL statement
+
+        This deallocates a previously prepared SQL statement with the given
+        name, or deallocates all prepared statements. Prepared statements are
+        also deallocated automatically when the current session ends.
+        """
+        q = "DEALLOCATE %s" % (name or 'ALL',)
+        self._do_debug(q)
+        return self.db.query(q)
+
     def pkey(self, table, composite=False, flush=False):
         """Get or set the primary key of a table.
 

Modified: trunk/pgmodule.c (952 => 953)


--- trunk/pgmodule.c	2019-01-03 15:20:00 UTC (rev 952)
+++ trunk/pgmodule.c	2019-01-03 20:38:57 UTC (rev 953)
@@ -2141,13 +2141,10 @@
 }
 
 /* database query */
-static char connQuery__doc__[] =
-"query(sql, [arg]) -- create a new query object for this connection\n\n"
-"You must pass the SQL (string) request and you can optionally pass\n"
-"a tuple with positional parameters.\n";
 
+/* base method for execution of both unprepared and prepared queries */
 static PyObject *
-connQuery(connObject *self, PyObject *args)
+_connQuery(connObject *self, PyObject *args, int prepared)
 {
 	PyObject	*query_obj;
 	PyObject	*param_obj = NULL;
@@ -2287,8 +2284,11 @@
 		}
 
 		Py_BEGIN_ALLOW_THREADS
-		result = PQexecParams(self->cnx, query, nparms,
-			NULL, parms, NULL, NULL, 0);
+		result = prepared ?
+			PQexecPrepared(self->cnx, query, nparms,
+				parms, NULL, NULL, 0) :
+			PQexecParams(self->cnx, query, nparms,
+				NULL, parms, NULL, NULL, 0);
 		Py_END_ALLOW_THREADS
 
 		PyMem_Free((void *)parms);
@@ -2298,7 +2298,10 @@
 	else
 	{
 		Py_BEGIN_ALLOW_THREADS
-		result = PQexec(self->cnx, query);
+		result = prepared ?
+			PQexecPrepared(self->cnx, query, 0,
+				NULL, NULL, NULL, 0) :
+			PQexec(self->cnx, query);
 		Py_END_ALLOW_THREADS
 	}
 
@@ -2376,6 +2379,123 @@
 	return (PyObject *) npgobj;
 }
 
+/* database query (plain SQL text) */
+static char connQuery__doc__[] =
+"query(sql, [arg]) -- create a new query object for this connection\n\n"
+"You must pass the SQL (string) request and you can optionally pass\n"
+"a tuple with positional parameters.\n";
+
+static PyObject *
+connQuery(connObject *self, PyObject *args)
+{
+	const int	prepared = 0;	/* first argument is SQL, not a name */
+
+	return _connQuery(self, args, prepared);
+}
+
+/* execute a previously prepared statement by name */
+static char connQueryPrepared__doc__[] =
+"query_prepared(name, [arg]) -- execute a prepared statement\n\n"
+"You must pass the name (string) of the prepared statement and you can\n"
+"optionally pass a tuple with positional parameters.\n";
+
+static PyObject *
+connQueryPrepared(connObject *self, PyObject *args)
+{
+	const int	prepared = 1;	/* first argument is a statement name */
+
+	return _connQuery(self, args, prepared);
+}
+
+/* create prepared statement */
+static char connPrepare__doc__[] =
+"prepare(name, sql) -- create a prepared statement\n\n"
+"You must pass the name (string) of the prepared statement and the\n"
+"SQL (string) request for later execution.\n";
+
+static PyObject *
+connPrepare(connObject *self, PyObject *args)
+{
+	char 		*name, *query;
+	int 		name_length, query_length;
+	PGresult	*result;
+
+	if (!self->cnx)
+	{
+		PyErr_SetString(PyExc_TypeError, "Connection is not valid");
+		return NULL;
+	}
+
+	/* reads args */
+	if (!PyArg_ParseTuple(args, "s#s#",
+		&name, &name_length, &query, &query_length))
+	{
+		PyErr_SetString(PyExc_TypeError,
+			"Method prepare() takes two string arguments");
+		return NULL;
+	}
+
+	/* create prepared statement */
+	Py_BEGIN_ALLOW_THREADS
+	result = PQprepare(self->cnx, name, query, 0, NULL);
+	Py_END_ALLOW_THREADS
+	if (result && PQresultStatus(result) == PGRES_COMMAND_OK)
+	{
+		PQclear(result);
+		Py_INCREF(Py_None);
+		return Py_None; /* success */
+	}
+	set_error(ProgrammingError, "Cannot create prepared statement",
+		self->cnx, result);
+	if (result)
+		PQclear(result);
+	return NULL; /* error */
+}
+
+/* describe prepared statement */
+static char connDescribePrepared__doc__[] =
+"describe_prepared(name, sql) -- describe a prepared statement\n\n"
+"You must pass the name (string) of the prepared statement.\n";
+
+static PyObject *
+connDescribePrepared(connObject *self, PyObject *args)
+{
+	char 		*name;
+	int 		name_length;
+	PGresult	*result;
+
+	if (!self->cnx)
+	{
+		PyErr_SetString(PyExc_TypeError, "Connection is not valid");
+		return NULL;
+	}
+
+	/* reads args */
+	if (!PyArg_ParseTuple(args, "s#",
+		&name, &name_length))
+	{
+		PyErr_SetString(PyExc_TypeError,
+			"Method prepare() takes a string argument");
+		return NULL;
+	}
+
+	/* describe prepared statement */
+	Py_BEGIN_ALLOW_THREADS
+	result = PQdescribePrepared(self->cnx, name);
+	Py_END_ALLOW_THREADS
+	if (result && PQresultStatus(result) == PGRES_COMMAND_OK)
+	{
+		queryObject *npgobj = PyObject_NEW(queryObject, &queryType);
+		if (!npgobj)
+			return PyErr_NoMemory();
+		Py_XINCREF(self);
+		npgobj->pgcnx = self;
+		npgobj->result = result;
+		return (PyObject *) npgobj;
+	}
+	set_error(ProgrammingError, "Cannot describe prepared statement",
+		self->cnx, result);
+	if (result)
+		PQclear(result);
+	return NULL; /* error */
+}
+
 #ifdef DIRECT_ACCESS
 static char connPutLine__doc__[] =
 "putline(line) -- send a line directly to the backend";
@@ -3414,6 +3534,11 @@
 
 	{"source", (PyCFunction) connSource, METH_NOARGS, connSource__doc__},
 	{"query", (PyCFunction) connQuery, METH_VARARGS, connQuery__doc__},
+	{"query_prepared", (PyCFunction) connQueryPrepared, METH_VARARGS,
+			connQueryPrepared__doc__},
+	{"prepare", (PyCFunction) connPrepare, METH_VARARGS, connPrepare__doc__},
+	{"describe_prepared", (PyCFunction) connDescribePrepared, METH_VARARGS,
+			connDescribePrepared__doc__},
 	{"reset", (PyCFunction) connReset, METH_NOARGS, connReset__doc__},
 	{"cancel", (PyCFunction) connCancel, METH_NOARGS, connCancel__doc__},
 	{"close", (PyCFunction) connClose, METH_NOARGS, connClose__doc__},

Modified: trunk/tests/test_classic_connection.py (952 => 953)


--- trunk/tests/test_classic_connection.py	2019-01-03 15:20:00 UTC (rev 952)
+++ trunk/tests/test_classic_connection.py	2019-01-03 20:38:57 UTC (rev 953)
@@ -121,10 +121,11 @@
         self.assertEqual(attributes, connection_attributes)
 
     def testAllConnectMethods(self):
-        methods = '''cancel close date_format endcopy
+        methods = '''cancel close date_format describe_prepared endcopy
             escape_bytea escape_identifier escape_literal escape_string
             fileno get_cast_hook get_notice_receiver getline getlo getnotify
-            inserttable locreate loimport parameter putline query reset
+            inserttable locreate loimport parameter
+            prepare putline query query_prepared reset
             set_cast_hook set_notice_receiver source transaction'''.split()
         connection_methods = [a for a in dir(self.connection)
             if not a.startswith('__') and self.is_method(a)]
@@ -932,6 +933,82 @@
             ).dictresult(), [{'garbage': garbage}])
 
 
+class TestPreparedQueries(unittest.TestCase):
+    """Test prepared queries via a basic pg connection."""
+
+    def setUp(self):
+        self.c = connect()
+        self.c.query('set client_encoding=utf8')
+
+    def tearDown(self):
+        self.c.close()
+
+    def testEmptyPreparedStatement(self):
+        self.c.prepare('', '')
+        self.assertRaises(ValueError, self.c.query_prepared, '')
+
+    def testInvalidPreparedStatement(self):
+        self.assertRaises(pg.ProgrammingError, self.c.prepare, '', 'bad')
+
+    def testNonExistentPreparedStatement(self):
+        self.assertRaises(pg.OperationalError,
+            self.c.query_prepared, 'does-not-exist')
+
+    def testAnonymousQueryWithoutParams(self):
+        self.assertIsNone(self.c.prepare('', "select 'anon'"))
+        self.assertEqual(self.c.query_prepared('').getresult(), [('anon',)])
+
+    def testNamedQueryWithoutParams(self):
+        self.assertIsNone(self.c.prepare('hello', "select 'world'"))
+        self.assertEqual(self.c.query_prepared('hello').getresult(),
+            [('world',)])
+
+    def testMultipleNamedQueriesWithoutParams(self):
+        self.assertIsNone(self.c.prepare('query17', "select 17"))
+        self.assertIsNone(self.c.prepare('query42', "select 42"))
+        self.assertEqual(self.c.query_prepared('query17').getresult(), [(17,)])
+        self.assertEqual(self.c.query_prepared('query42').getresult(), [(42,)])
+
+    def testAnonymousQueryWithParams(self):
+        self.assertIsNone(self.c.prepare('', "select $1 || ', ' || $2"))
+        self.assertEqual(
+            self.c.query_prepared('', ['hello', 'world']).getresult(),
+            [('hello, world',)])
+        self.assertIsNone(self.c.prepare('', "select 1+ $1 + $2 + $3"))
+        self.assertEqual(
+            self.c.query_prepared('', [17, -5, 29]).getresult(), [(42,)])
+
+    def testMultipleNamedQueriesWithParams(self):
+        self.assertIsNone(self.c.prepare('q1', "select $1 || '!'"))
+        self.assertIsNone(self.c.prepare('q2', "select $1 || '-' || $2"))
+        self.assertEqual(self.c.query_prepared('q1', ['hello']).getresult(),
+            [('hello!',)])
+        self.assertEqual(self.c.query_prepared('q2', ['he', 'lo']).getresult(),
+            [('he-lo',)])
+
+    def testDescribeNonExistentQuery(self):
+        self.assertRaises(pg.OperationalError,
+            self.c.describe_prepared, 'does-not-exist')
+
+    def testDescribeAnonymousQuery(self):
+        self.c.prepare('', "select 1::int, 'a'::char")
+        r = self.c.describe_prepared('')
+        self.assertEqual(r.listfields(), ('int4', 'bpchar'))
+
+    def testDescribeNamedQuery(self):
+        self.c.prepare('myquery', "select 1 as first, 2 as second")
+        r = self.c.describe_prepared('myquery')
+        self.assertEqual(r.listfields(), ('first', 'second'))
+
+    def testDescribeMultipleNamedQueries(self):
+        self.c.prepare('query1', "select 1::int")
+        self.c.prepare('query2', "select 1::int, 2::int")
+        r = self.c.describe_prepared('query1')
+        self.assertEqual(r.listfields(), ('int4',))
+        r = self.c.describe_prepared('query2')
+        self.assertEqual(r.listfields(), ('int4', 'int4'))
+
+
 class TestQueryResultTypes(unittest.TestCase):
     """Test proper result types via a basic pg connection."""
 

Modified: trunk/tests/test_classic_dbwrapper.py (952 => 953)


--- trunk/tests/test_classic_dbwrapper.py	2019-01-03 15:20:00 UTC (rev 952)
+++ trunk/tests/test_classic_dbwrapper.py	2019-01-03 20:38:57 UTC (rev 953)
@@ -198,6 +198,7 @@
             'cancel', 'clear', 'close', 'commit',
             'date_format', 'db', 'dbname', 'dbtypes',
             'debug', 'decode_json', 'delete',
+            'delete_prepared', 'describe_prepared',
             'encode_json', 'end', 'endcopy', 'error',
             'escape_bytea', 'escape_identifier',
             'escape_literal', 'escape_string',
@@ -213,8 +214,8 @@
             'notification_handler',
             'options',
             'parameter', 'pkey', 'port',
-            'protocol_version', 'putline',
-            'query', 'query_formatted',
+            'prepare', 'protocol_version', 'putline',
+            'query', 'query_formatted', 'query_prepared',
             'release', 'reopen', 'reset', 'rollback',
             'savepoint', 'server_version',
             'set_cast_hook', 'set_notice_receiver',
@@ -968,6 +969,58 @@
         r = f(q, {}).getresult()[0][0]
         self.assertEqual(r, 42)
 
+    def testQueryPreparedWithoutParams(self):
+        p = self.db.prepare
+        p('q1', "select 17")
+        p('q2', "select 42")
+        f = self.db.query_prepared
+        r = f('q1').getresult()[0][0]
+        self.assertEqual(r, 17)
+        r = f('q2').getresult()[0][0]
+        self.assertEqual(r, 42)
+
+    def testQueryPreparedWithParams(self):
+        p = self.db.prepare
+        p('sum', "select 1 + $1 + $2 + $3")
+        p('cat', "select initcap($1) || ', ' || $2 || '!'")
+        f = self.db.query_prepared
+        r = f('sum', 2, 3, 5).getresult()[0][0]
+        self.assertEqual(r, 11)
+        r = f('cat', 'hello', 'world').getresult()[0][0]
+        self.assertEqual(r, 'Hello, world!')
+
+    def testPrepare(self):
+        p = self.db.prepare
+        self.assertIsNone(p('',  "select null"))
+        self.assertIsNone(p('myquery', "select 'hello'"))
+        self.assertIsNone(p('myquery2', "select 'world'"))
+        self.assertRaises(pg.ProgrammingError,
+            p, 'myquery', "select 'hello, too'")
+
+    def testDescribePrepared(self):
+        self.db.prepare('count', 'select 1 as first, 2 as second')
+        f = self.db.describe_prepared
+        r = f('count').listfields()
+        self.assertEqual(r, ('first', 'second'))
+
+    def testDeletePrepared(self):
+        f = self.db.delete_prepared
+        f()
+        e = pg.OperationalError
+        self.assertRaises(e, f, 'myquery')
+        p = self.db.prepare
+        p('q1', "select 1")
+        p('q2', "select 2")
+        f('q1')
+        f('q2')
+        self.assertRaises(e, f, 'q1')
+        self.assertRaises(e, f, 'q2')
+        p('q1', "select 1")
+        p('q2', "select 2")
+        f()
+        self.assertRaises(e, f, 'q1')
+        self.assertRaises(e, f, 'q2')
+
     def testPkey(self):
         query = self.db.query
         pkey = self.db.pkey
_______________________________________________
PyGreSQL mailing list
PyGreSQL@Vex.Net
https://mail.vex.net/mailman/listinfo/pygresql

Reply via email to