Title: [978] trunk: Make classic query work as iterator
Revision: 978
Author: cito
Date: 2019-04-21 15:07:48 -0400 (Sun, 21 Apr 2019)

Log Message

Make classic query work as iterator

Mostly following Justin's proposal on the mailing list.

Modified Paths

Diff

Modified: trunk/docs/contents/changelog.rst (977 => 978)


--- trunk/docs/contents/changelog.rst	2019-04-20 12:40:21 UTC (rev 977)
+++ trunk/docs/contents/changelog.rst	2019-04-21 19:07:48 UTC (rev 978)
@@ -6,6 +6,15 @@
 - Support for prepared statements has been added to the classic API.
 - DB wrapper objects based on existing connections can now be closed and
   reopened properly (but the underlying connection will not be affected).
+- The query objects in the classic API can now be used as iterators
+  and will then yield the rows as tuples, similar to query.getresult().
+  Thanks to Justin Pryzby for the proposal and most of the implementation.
+- Added methods query.dictiter() and query.namediter() to the classic API
+  which work like query.dictresult() and query.namedresult() except that
+  they return iterators instead of lists.
+- Deprecated query.ntuples() in the classic API, since len(query) can now
+  be used and returns the same number.
+- Added pg.get/set_namediter and deprecated pg.get/set_namedresult.
 
 Version 5.0.7 (2019-mm-dd)
 --------------------------
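
Code that has to run against both PyGreSQL 5.0 and 5.1 can use a small helper that picks whichever access style the query object offers. This is only a sketch: ``iter_rows`` and ``row_count`` are hypothetical helpers, and ``OldQuery``/``NewQuery`` are made-up stand-ins that mimic the relevant parts of a classic query object before and after this change. The sketch assumes that the 5.1 query type exposes ``__iter__`` and ``__len__``.

```python
def iter_rows(query):
    """Iterate over the rows (tuples) of a classic query object."""
    if hasattr(type(query), "__iter__"):  # PyGreSQL 5.1+: Query is iterable
        return iter(query)
    return iter(query.getresult())        # older versions: list of tuples


def row_count(query):
    """Return the number of rows, preferring len() to ntuples()."""
    try:
        return len(query)                 # PyGreSQL 5.1+
    except TypeError:
        return query.ntuples()            # deprecated since 5.1


class OldQuery:
    """Made-up stand-in for a pre-5.1 query object."""

    def __init__(self, rows):
        self._rows = list(rows)

    def getresult(self):
        return list(self._rows)

    def ntuples(self):
        return len(self._rows)


class NewQuery(OldQuery):
    """Made-up stand-in for a 5.1 query object."""

    def __iter__(self):
        return iter(self._rows)

    def __len__(self):
        return len(self._rows)


for q in (OldQuery([(1, "a"), (2, "b")]), NewQuery([(1, "a"), (2, "b")])):
    print(row_count(q), list(iter_rows(q)))  # 2 [(1, 'a'), (2, 'b')]
```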

Modified: trunk/docs/contents/pg/connection.rst (977 => 978)


--- trunk/docs/contents/pg/connection.rst	2019-04-20 12:40:21 UTC (rev 977)
+++ trunk/docs/contents/pg/connection.rst	2019-04-21 19:07:48 UTC (rev 978)
@@ -50,11 +50,16 @@
 of rows affected is returned as a string. If it is a statement that returns
 rows as a result (usually a select statement, but maybe also an
 ``"insert/update ... returning"`` statement), this method returns
-a :class:`Query` that can be accessed via the
-:meth:`Query.getresult`, :meth:`Query.dictresult` or
-:meth:`Query.namedresult` methods or simply printed.
-Otherwise, it returns ``None``.
+a :class:`Query`. Otherwise, it returns ``None``.
 
+You can use the :class:`Query` object as an iterator that yields all results
+as tuples, or call :meth:`Query.getresult` to get the result as a list
+of tuples. Alternatively, you can call :meth:`Query.dictresult` or
+:meth:`Query.dictiter` if you want to get the rows as dictionaries,
+or :meth:`Query.namedresult` or :meth:`Query.namediter` if you want to
+get the rows as named tuples. You can also simply print the :class:`Query`
+object to show the query results on the console.
+
 The SQL command may optionally contain positional parameters of the form
 ``$1``, ``$2``, etc instead of literal data, in which case the values
 must be supplied separately as a tuple.  The values are substituted by
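
Because ``Connection.query()`` may return a status string, ``None`` or a ``Query`` object, calling code typically branches on the type of the result. The following is a minimal sketch of such a branch. ``summarize`` is a hypothetical helper that treats any ``str`` or ``int`` as a scalar status value. A plain list of tuples stands in for a ``Query`` object, since both support iteration and ``len()`` as of 5.1.

```python
def summarize(result):
    """Describe a Connection.query() return value (hypothetical helper)."""
    if result is None:
        return "no result"
    if isinstance(result, (str, int)):
        # Scalar status value, e.g. the number of affected rows as a string.
        return "status %s" % result
    # Otherwise a Query: since 5.1 it can be iterated and measured with len().
    rows = list(result)
    first = rows[0] if rows else None
    return "%d row(s), first: %r" % (len(result), first)


print(summarize(None))               # no result
print(summarize("3"))                # status 3
print(summarize([(1, "apple")]))     # 1 row(s), first: (1, 'apple')
```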

Modified: trunk/docs/contents/pg/module.rst (977 => 978)


--- trunk/docs/contents/pg/module.rst	2019-04-20 12:40:21 UTC (rev 977)
+++ trunk/docs/contents/pg/module.rst	2019-04-21 19:07:48 UTC (rev 978)
@@ -308,30 +308,49 @@
 Note that there is also a :class:`DB` method with the same name
 which does exactly the same.
 
+get/set_namediter -- conversion to named tuples
+-----------------------------------------------
+
+.. function:: get_namediter()
+
+    Get the generator that converts to named tuples
+
+This returns the function used by PyGreSQL to construct the result of the
+:meth:`Query.namediter` and :meth:`Query.namedresult` methods.
+
+.. versionadded:: 5.1
+
+.. function:: set_namediter(func)
+
+    Set a generator that will convert to named tuples
+
+    :param func: the generator to be used to convert results to named tuples
+
+You can use this if you want to create different kinds of named tuples
+returned by the :meth:`Query.namediter` and :meth:`Query.namedresult` methods.
+If you set this function to *None*, then normal tuples will be used.
+
+.. versionadded:: 5.1
+
 get/set_namedresult -- conversion to named tuples
 -------------------------------------------------
 
 .. function:: get_namedresult()
 
-    Get the function that converts to named tuples
+    Get the generator that converts to named tuples
 
-This returns the function used by PyGreSQL to construct the result of the
-:meth:`Query.namedresult` method.
+.. deprecated:: 5.1
+   Use :func:`get_namediter` instead.
 
-.. versionadded:: 4.1
-
 .. function:: set_namedresult(func)
 
-    Set a function that will convert to named tuples
+    Set a generator that will convert to named tuples
 
-    :param func: the function to be used to convert results to named tuples
+    :param func: the generator to be used to convert results to named tuples
 
-You can use this if you want to create different kinds of named tuples
-returned by the :meth:`Query.namedresult` method.  If you set this function
-to *None*, then it will become equal to :meth:`Query.getresult`.
+.. deprecated:: 5.1
+   Use :func:`set_namediter` instead.
 
-.. versionadded:: 4.1
-
 get/set_decimal -- decimal type to be used for numeric values
 -------------------------------------------------------------
 

Modified: trunk/docs/contents/pg/query.rst (977 => 978)


--- trunk/docs/contents/pg/query.rst	2019-04-20 12:40:21 UTC (rev 977)
+++ trunk/docs/contents/pg/query.rst	2019-04-21 19:07:48 UTC (rev 978)
@@ -6,8 +6,11 @@
 .. class:: Query
 
 The :class:`Query` object returned by :meth:`Connection.query` and
-:meth:`DB.query` provides the following methods for accessing
-the results of the query:
+:meth:`DB.query` can be used as an iterator returning rows as tuples.
+You can also directly access row tuples using their index, and get
+the number of rows with the :func:`len` function. The :class:`Query`
+class also provides the following methods for accessing the results
+of the query:
 
 getresult -- get query values as list of tuples
 -----------------------------------------------
@@ -29,6 +32,13 @@
 Note that since PyGreSQL 5.0 this method will return the values of array
 type columns as Python lists.
 
+Since PyGreSQL 5.1 the :class:`Query` can also be used directly as
+an iterable sequence, i.e. you can iterate over the :class:`Query`
+object to get the same tuples as returned by :meth:`Query.getresult`.
+You can also call :func:`len` on a query to find the number of rows
+in the result, and access row tuples using their index directly on
+the :class:`Query` object.
+
 dictresult -- get query values as list of dictionaries
 ------------------------------------------------------
 
@@ -75,6 +85,49 @@
 
 .. versionadded:: 4.1
 
+dictiter -- get query values as iterator of dictionaries
+--------------------------------------------------------
+
+.. method:: Query.dictiter()
+
+    Get query values as iterator of dictionaries
+
+    :returns: result values as an iterator of dictionaries
+    :rtype: iterator
+    :raises TypeError: too many (any) parameters
+    :raises MemoryError: internal memory error
+
+This method returns query results as an iterator of dictionaries which have
+the field names as keys.
+
+If the query has duplicate field names, you will get the value for the
+field with the highest index in the query.
+
+.. versionadded:: 5.1
+
+namediter -- get query values as iterator of named tuples
+---------------------------------------------------------
+
+.. method:: Query.namediter()
+
+    Get query values as iterator of named tuples
+
+    :returns: result values as an iterator of named tuples
+    :rtype: iterator
+    :raises TypeError: too many (any) parameters
+    :raises TypeError: named tuples not supported
+    :raises MemoryError: internal memory error
+
+This method returns query results as an iterator of named tuples with
+proper field names.
+
+Column names in the database that are not valid as field names for
+named tuples (particularly, names starting with an underscore) are
+automatically renamed to valid positional names.
+
+.. versionadded:: 5.1
+
+
 listfields -- list fields names of previous query result
 --------------------------------------------------------
 
@@ -133,3 +186,6 @@
     :raises TypeError: Too many arguments.
 
 This method returns the number of tuples in the query result.
+
+.. deprecated:: 5.1
+   You can use the normal :func:`len` function instead.
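
The two edge cases documented for ``dictiter`` and ``namediter`` can be reproduced in plain Python. This sketch assumes that the conversions behave like ``dict(zip(fields, row))``, which is what ``pg._dictiter`` does in this revision, and like ``collections.namedtuple(..., rename=True)``. The second part is an assumption about ``pg._row_factory``, whose body is not part of this diff. The field names and values are made up.

```python
from collections import namedtuple

fields = ("id", "name", "name")       # result with a duplicate column name
row = (1, "first", "second")

# Like dictiter(): later columns overwrite earlier ones with the same key,
# so the value of the field with the highest index wins.
as_dict = dict(zip(fields, row))
print(as_dict)                        # {'id': 1, 'name': 'second'}

# Like namediter(): invalid field names (leading underscore, keywords,
# duplicates) are replaced by positional names of the form _<index>.
Row = namedtuple("Row", ("_id", "name", "class"), rename=True)
print(Row._fields)                    # ('_0', 'name', '_2')

Dup = namedtuple("Row", fields, rename=True)
print(Dup._make(row))                 # Row(id=1, name='first', _2='second')
```

Unlike the dictionary, the named tuple keeps both values of a duplicated column, because the second occurrence only gets a positional name.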

Modified: trunk/docs/contents/tutorial.rst (977 => 978)


--- trunk/docs/contents/tutorial.rst	2019-04-20 12:40:21 UTC (rev 977)
+++ trunk/docs/contents/tutorial.rst	2019-04-21 19:07:48 UTC (rev 978)
@@ -106,6 +106,11 @@
     >>> rows[3].name
     'durian'
 
+In PyGreSQL 5.1 and newer, you can also use the :class:`Query` instance
+directly as an iterator that yields the rows as tuples, and you can use
+the methods :meth:`Query.dictiter` or :meth:`Query.namediter` to get
+iterators yielding the rows as dictionaries or named tuples.
+
 Using the method :meth:`DB.get_as_dict`, you can easily import the whole table
 into a Python dictionary mapping the primary key *id* to the *name*::
 

Modified: trunk/pg.py (977 => 978)


--- trunk/pg.py	2019-04-20 12:40:21 UTC (rev 977)
+++ trunk/pg.py	2019-04-21 19:07:48 UTC (rev 978)
@@ -1311,10 +1311,18 @@
     _row_factory = lru_cache(maxsize)(_row_factory.__wrapped__)
 
 
-def _namedresult(q):
-    """Get query result as named tuples."""
+def _dictiter(q):
+    """Get query result as an iterator of dictionaries."""
+    fields = q.listfields()
+    for r in q:
+        yield dict(zip(fields, r))
+
+
+def _namediter(q):
+    """Get query result as an iterator of named tuples."""
     row = _row_factory(q.listfields())
-    return [row(r) for r in q.getresult()]
+    for r in q:
+        yield row(r)
 
 
 class _MemoryQuery:
@@ -1333,7 +1341,10 @@
         """Return the stored result of this query."""
         return self.result
 
+    def __iter__(self):
+        return iter(self.result)
 
+
 def _db_error(msg, cls=DatabaseError):
     """Return DatabaseError with empty sqlstate attribute."""
     error = cls(msg)
@@ -1353,7 +1364,8 @@
 
 # Initialize the C module
 
-set_namedresult(_namedresult)
+set_dictiter(_dictiter)
+set_namediter(_namediter)
 set_decimal(Decimal)
 set_jsondecode(jsondecode)
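
``_dictiter`` and ``_namediter`` now iterate over the query object instead of calling ``getresult()``. Anything passed to them must therefore be iterable, which is presumably why ``__iter__`` was added to ``_MemoryQuery`` above. The sketch below shows a hypothetical replacement generator for ``pg.set_namediter()`` that relies only on ``listfields()`` and iteration. ``MemoryResult`` is a made-up stand-in loosely modeled on ``_MemoryQuery``, whose constructor is not shown in this diff. ``record_type`` caches one namedtuple class per column list, similar to what ``_row_factory`` appears to do.

```python
from collections import namedtuple
from functools import lru_cache


@lru_cache(maxsize=128)
def record_type(fields):
    """Return a cached namedtuple class for the given column names."""
    return namedtuple("Record", fields, rename=True)


def record_iter(query):
    """Hypothetical generator usable with pg.set_namediter()."""
    make = record_type(tuple(query.listfields()))._make
    for values in query:              # requires query.__iter__
        yield make(values)


class MemoryResult:
    """Made-up stand-in for a stored result, loosely like pg._MemoryQuery."""

    def __init__(self, result, fields):
        self.result, self.fields = result, tuple(fields)

    def listfields(self):
        return self.fields

    def getresult(self):
        return self.result

    def __iter__(self):               # counterpart of the method added here
        return iter(self.result)


rows = list(record_iter(MemoryResult([(1, "x"), (2, "y")], ("id", "val"))))
print(rows)  # [Record(id=1, val='x'), Record(id=2, val='y')]
# With PyGreSQL 5.1+ installed one would register it: pg.set_namediter(record_iter)
```

``queryNamedresult()`` in the C module passes any non-list return value through ``PySequence_List()``. A single generator like this therefore serves both ``namediter()`` and ``namedresult()``.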
 

Modified: trunk/pgmodule.c (977 => 978)


--- trunk/pgmodule.c	2019-04-20 12:40:21 UTC (rev 977)
+++ trunk/pgmodule.c	2019-04-21 19:07:48 UTC (rev 978)
@@ -89,7 +89,8 @@
 #endif	/* DEFAULT_VARS */
 
 static PyObject *decimal = NULL, /* decimal type */
-				*namedresult = NULL, /* function for getting named results */
+				*dictiter = NULL, /* function for getting dict results */
+				*namediter = NULL, /* function for getting named results */
 				*jsondecode = NULL; /* function for decoding json strings */
 static const char *date_format = NULL; /* date format that is always assumed */
 static char decimal_point = '.'; /* decimal point used in money values */
@@ -156,7 +157,7 @@
 	int			encoding; 		/* client encoding */
 	int			result_type;	/* result type (DDL/DML/DQL) */
 	long		arraysize;		/* array size for fetch method */
-	int			current_row;	/* current selected row */
+	int			current_row;	/* currently selected row */
 	int			max_row;		/* number of rows in the result */
 	int			num_fields;		/* number of fields in each row */
 }	sourceObject;
@@ -176,6 +177,10 @@
 	connObject *pgcnx;			/* parent connection object */
 	PGresult   *result;			/* result content */
 	int			encoding; 		/* client encoding */
+	int			current_row;	/* currently selected row */
+	int			max_row;		/* number of rows in the result */
+	int			num_fields;		/* number of fields in each row */
+	int		   *col_types;		/* PyGreSQL column types */
 }	queryObject;
 #define is_queryObject(v) (PyType(v) == &queryType)
 
@@ -931,7 +936,7 @@
 		{
 			char	   *estr;
 			Py_ssize_t	esize;
-			int quoted = 0, escaped =0;
+			int quoted = 0, escaped = 0;
 
 			estr = s;
 			quoted = *s == '"';
@@ -1653,17 +1658,17 @@
 static largeObject *
 largeNew(connObject *pgcnx, Oid oid)
 {
-	largeObject *npglo;
+	largeObject *large_obj;
 
-	if (!(npglo = PyObject_NEW(largeObject, &largeType)))
+	if (!(large_obj = PyObject_NEW(largeObject, &largeType)))
 		return NULL;
 
 	Py_XINCREF(pgcnx);
-	npglo->pgcnx = pgcnx;
-	npglo->lo_fd = -1;
-	npglo->lo_oid = oid;
+	large_obj->pgcnx = pgcnx;
+	large_obj->lo_fd = -1;
+	large_obj->lo_oid = oid;
 
-	return npglo;
+	return large_obj;
 }
 
 /* destructor */
@@ -2120,7 +2125,7 @@
 static PyObject *
 connSource(connObject *self, PyObject *noargs)
 {
-	sourceObject *npgobj;
+	sourceObject *source_obj;
 
 	/* checks validity */
 	if (!check_cnx_obj(self))
@@ -2127,17 +2132,17 @@
 		return NULL;
 
 	/* allocates new query object */
-	if (!(npgobj = PyObject_NEW(sourceObject, &sourceType)))
+	if (!(source_obj = PyObject_NEW(sourceObject, &sourceType)))
 		return NULL;
 
 	/* initializes internal parameters */
 	Py_XINCREF(self);
-	npgobj->pgcnx = self;
-	npgobj->result = NULL;
-	npgobj->valid = 1;
-	npgobj->arraysize = PG_ARRAYSIZE;
+	source_obj->pgcnx = self;
+	source_obj->result = NULL;
+	source_obj->valid = 1;
+	source_obj->arraysize = PG_ARRAYSIZE;
 
-	return (PyObject *) npgobj;
+	return (PyObject *) source_obj;
 }
 
 /* database query */
@@ -2146,11 +2151,11 @@
 static PyObject *
 _connQuery(connObject *self, PyObject *args, int prepared)
 {
-	PyObject	*query_obj;
+	PyObject	*query_str_obj;
 	PyObject	*param_obj = NULL;
 	char		*query;
 	PGresult	*result;
-	queryObject *npgobj;
+	queryObject *query_obj;
 	int			encoding,
 				status,
 				nparms = 0;
@@ -2162,7 +2167,7 @@
 	}
 
 	/* get query args */
-	if (!PyArg_ParseTuple(args, "O|O", &query_obj, &param_obj))
+	if (!PyArg_ParseTuple(args, "O|O", &query_str_obj, &param_obj))
 	{
 		return NULL;
 	}
@@ -2169,16 +2174,16 @@
 
 	encoding = PQclientEncoding(self->cnx);
 
-	if (PyBytes_Check(query_obj))
+	if (PyBytes_Check(query_str_obj))
 	{
-		query = PyBytes_AsString(query_obj);
-		query_obj = NULL;
+		query = PyBytes_AsString(query_str_obj);
+		query_str_obj = NULL;
 	}
-	else if (PyUnicode_Check(query_obj))
+	else if (PyUnicode_Check(query_str_obj))
 	{
-		query_obj = get_encoded_string(query_obj, encoding);
-		if (!query_obj) return NULL; /* pass the UnicodeEncodeError */
-		query = PyBytes_AsString(query_obj);
+		query_str_obj = get_encoded_string(query_str_obj, encoding);
+		if (!query_str_obj) return NULL; /* pass the UnicodeEncodeError */
+		query = PyBytes_AsString(query_str_obj);
 	}
 	else
 	{
@@ -2197,7 +2202,7 @@
 			"Method query() expects a sequence as second argument");
 		if (!param_obj)
 		{
-			Py_XDECREF(query_obj);
+			Py_XDECREF(query_str_obj);
 			return NULL;
 		}
 		nparms = (int)PySequence_Fast_GET_SIZE(param_obj);
@@ -2229,7 +2234,7 @@
 		if (!str || !parms)
 		{
 			PyMem_Free((void *)parms); PyMem_Free(str);
-			Py_XDECREF(query_obj); Py_XDECREF(param_obj);
+			Py_XDECREF(query_str_obj); Py_XDECREF(param_obj);
 			return PyErr_NoMemory();
 		}
 
@@ -2256,7 +2261,7 @@
 					PyMem_Free((void *)parms);
 					while (s != str) { s--; Py_DECREF(*s); }
 					PyMem_Free(str);
-					Py_XDECREF(query_obj);
+					Py_XDECREF(query_str_obj);
 					Py_XDECREF(param_obj);
 					/* pass the UnicodeEncodeError */
 					return NULL;
@@ -2272,7 +2277,7 @@
 					PyMem_Free((void *)parms);
 					while (s != str) { s--; Py_DECREF(*s); }
 					PyMem_Free(str);
-					Py_XDECREF(query_obj);
+					Py_XDECREF(query_str_obj);
 					Py_XDECREF(param_obj);
 					PyErr_SetString(PyExc_TypeError,
 						"Query parameter has no string representation");
@@ -2306,7 +2311,7 @@
 	}
 
 	/* we don't need the query and its params any more */
-	Py_XDECREF(query_obj);
+	Py_XDECREF(query_str_obj);
 	Py_XDECREF(param_obj);
 
 	/* checks result validity */
@@ -2368,15 +2373,25 @@
 		return NULL;			/* error detected on query */
 	}
 
-	if (!(npgobj = PyObject_NEW(queryObject, &queryType)))
+	if (!(query_obj = PyObject_NEW(queryObject, &queryType)))
 		return PyErr_NoMemory();
 
 	/* stores result and returns object */
 	Py_XINCREF(self);
-	npgobj->pgcnx = self;
-	npgobj->result = result;
-	npgobj->encoding = encoding;
-	return (PyObject *) npgobj;
+	query_obj->pgcnx = self;
+	query_obj->result = result;
+	query_obj->encoding = encoding;
+	query_obj->current_row = 0;
+	query_obj->max_row = PQntuples(result);
+	query_obj->num_fields = PQnfields(result);
+	query_obj->col_types = get_col_types(result, query_obj->num_fields);
+	if (!query_obj->col_types) {
+		/* queryDealloc() also releases the reference to self */
+		Py_DECREF(query_obj);
+		return NULL;
+	}
+
+	return (PyObject *) query_obj;
 }
 
 /* database query */
@@ -2481,13 +2496,18 @@
 	Py_END_ALLOW_THREADS
 	if (result && PQresultStatus(result) == PGRES_COMMAND_OK)
 	{
-		queryObject *npgobj = PyObject_NEW(queryObject, &queryType);
-		if (!npgobj)
+		queryObject *query_obj = PyObject_NEW(queryObject, &queryType);
+		if (!query_obj)
 			return PyErr_NoMemory();
 		Py_XINCREF(self);
-		npgobj->pgcnx = self;
-		npgobj->result = result;
-		return (PyObject *) npgobj;
+		query_obj->pgcnx = self;
+		query_obj->result = result;
+		query_obj->encoding = PQclientEncoding(self->cnx);
+		query_obj->current_row = 0;
+		query_obj->max_row = PQntuples(result);
+		query_obj->num_fields = PQnfields(result);
+		query_obj->col_types = get_col_types(result, query_obj->num_fields);
+		if (!query_obj->col_types)
+		{
+			Py_DECREF(query_obj);
+			return NULL;
+		}
+		return (PyObject *) query_obj;
 	}
 	set_error(ProgrammingError, "Cannot describe prepared statement",
 		self->cnx, result);
@@ -4557,7 +4577,7 @@
 			   *pgpasswd;
 	int			pgport;
 	char		port_buffer[20];
-	connObject *npgobj;
+	connObject *conn_obj;
 
 	pghost = pgopt = pgdbname = pguser = pgpasswd = NULL;
 	pgport = -1;
@@ -4593,17 +4613,17 @@
 		pgpasswd = PyBytes_AsString(pg_default_passwd);
 #endif /* DEFAULT_VARS */
 
-	if (!(npgobj = PyObject_NEW(connObject, &connType)))
+	if (!(conn_obj = PyObject_NEW(connObject, &connType)))
 	{
 		set_error_msg(InternalError, "Can't create new connection object");
 		return NULL;
 	}
 
-	npgobj->valid = 1;
-	npgobj->cnx = NULL;
-	npgobj->date_format = date_format;
-	npgobj->cast_hook = NULL;
-	npgobj->notice_receiver = NULL;
+	conn_obj->valid = 1;
+	conn_obj->cnx = NULL;
+	conn_obj->date_format = date_format;
+	conn_obj->cast_hook = NULL;
+	conn_obj->notice_receiver = NULL;
 
 	if (pgport != -1)
 	{
@@ -4612,18 +4632,18 @@
 	}
 
 	Py_BEGIN_ALLOW_THREADS
-	npgobj->cnx = PQsetdbLogin(pghost, pgport == -1 ? NULL : port_buffer,
+	conn_obj->cnx = PQsetdbLogin(pghost, pgport == -1 ? NULL : port_buffer,
 		pgopt, NULL, pgdbname, pguser, pgpasswd);
 	Py_END_ALLOW_THREADS
 
-	if (PQstatus(npgobj->cnx) == CONNECTION_BAD)
+	if (PQstatus(conn_obj->cnx) == CONNECTION_BAD)
 	{
-		set_error(InternalError, "Cannot connect", npgobj->cnx, NULL);
-		Py_XDECREF(npgobj);
+		set_error(InternalError, "Cannot connect", conn_obj->cnx, NULL);
+		Py_XDECREF(conn_obj);
 		return NULL;
 	}
 
-	return (PyObject *) npgobj;
+	return (PyObject *) conn_obj;
 }
 
 static void
@@ -4630,6 +4650,8 @@
 queryDealloc(queryObject *self)
 {
 	Py_XDECREF(self->pgcnx);
+	if (self->col_types)
+		PyMem_Free(self->col_types);
 	if (self->result)
 		PQclear(self->result);
 
@@ -4637,48 +4659,46 @@
 }
 
 /* get number of rows */
-static char queryNTuples__doc__[] =
+static char queryNtuples__doc__[] =
 "ntuples() -- return number of tuples returned by query";
 
 static PyObject *
-queryNTuples(queryObject *self, PyObject *noargs)
+queryNtuples(queryObject *self, PyObject *noargs)
 {
-	return PyInt_FromLong((long) PQntuples(self->result));
+	return PyInt_FromLong(self->max_row);
 }
 
 /* list fields names from query result */
-static char queryListFields__doc__[] =
+static char queryListfields__doc__[] =
 "listfields() -- List field names from result";
 
 static PyObject *
-queryListFields(queryObject *self, PyObject *noargs)
+queryListfields(queryObject *self, PyObject *noargs)
 {
-	int			i,
-				n;
+	int			i;
 	char	   *name;
 	PyObject   *fieldstuple,
 			   *str;
 
 	/* builds tuple */
-	n = PQnfields(self->result);
-	fieldstuple = PyTuple_New(n);
-
-	for (i = 0; i < n; ++i)
-	{
-		name = PQfname(self->result, i);
-		str = PyStr_FromString(name);
-		PyTuple_SET_ITEM(fieldstuple, i, str);
+	fieldstuple = PyTuple_New(self->num_fields);
+	if (fieldstuple) {
+		for (i = 0; i < self->num_fields; ++i)
+		{
+			name = PQfname(self->result, i);
+			str = PyStr_FromString(name);
+			PyTuple_SET_ITEM(fieldstuple, i, str);
+		}
 	}
-
 	return fieldstuple;
 }
 
 /* get field name from last result */
-static char queryFieldName__doc__[] =
+static char queryFieldname__doc__[] =
 "fieldname(num) -- return name of field from result from its position";
 
 static PyObject *
-queryFieldName(queryObject *self, PyObject *args)
+queryFieldname(queryObject *self, PyObject *args)
 {
 	int		i;
 	char   *name;
@@ -4692,7 +4712,7 @@
 	}
 
 	/* checks number validity */
-	if (i >= PQnfields(self->result))
+	if (i >= self->num_fields)
 	{
 		PyErr_SetString(PyExc_ValueError, "Invalid field number");
 		return NULL;
@@ -4704,11 +4724,11 @@
 }
 
 /* gets fields number from name in last result */
-static char queryFieldNumber__doc__[] =
+static char queryFieldnum__doc__[] =
 "fieldnum(name) -- return position in query for field from its name";
 
 static PyObject *
-queryFieldNumber(queryObject *self, PyObject *args)
+queryFieldnum(queryObject *self, PyObject *args)
 {
 	int		num;
 	char   *name;
@@ -4731,206 +4751,269 @@
 	return PyInt_FromLong(num);
 }
 
-/* retrieves last result */
-static char queryGetResult__doc__[] =
+/* The __iter__() method of the queryObject.
+   This returns the default iterator yielding rows as tuples. */
+static PyObject* queryGetIter(queryObject *self)
+{
+	self->current_row = 0;
+	Py_INCREF(self);
+	return (PyObject*)self;
+}
+
+/* Return the value in the given column of the current row. */
+static PyObject *
+getValueInColumn(queryObject *self, int column)
+{
+	if (PQgetisnull(self->result, self->current_row, column))
+	{
+		Py_INCREF(Py_None);
+		return Py_None;
+	}
+
+	/* get the string representation of the value */
+	/* note: this is always null-terminated text format */
+	char   *s = PQgetvalue(self->result, self->current_row, column);
+	/* get the PyGreSQL type of the column */
+	int		type = self->col_types[column];
+	/* cast the string representation into a Python object */
+	if (type & PYGRES_ARRAY)
+		return cast_array(s,
+			PQgetlength(self->result, self->current_row, column),
+			self->encoding, type, NULL, 0);
+	if (type == PYGRES_BYTEA)
+		return cast_bytea_text(s);
+	if (type == PYGRES_OTHER)
+		return cast_other(s,
+			PQgetlength(self->result, self->current_row, column),
+			self->encoding,
+			PQftype(self->result, column), self->pgcnx->cast_hook);
+	if (type & PYGRES_TEXT)
+		return cast_sized_text(s,
+			PQgetlength(self->result, self->current_row, column),
+			self->encoding, type);
+	return cast_unsized_simple(s, type);
+}
+
+/* Return the current row as a tuple. */
+static PyObject *
+queryGetRowAsTuple(queryObject *self)
+{
+	PyObject   *row_tuple = NULL;
+	int			j;
+
+	if (!(row_tuple = PyTuple_New(self->num_fields))) return NULL;
+
+	for (j = 0; j < self->num_fields; ++j)
+	{
+		PyObject *val = getValueInColumn(self, j);
+		if (!val)
+		{
+			Py_DECREF(row_tuple); return NULL;
+		}
+		PyTuple_SET_ITEM(row_tuple, j, val);
+	}
+
+	return row_tuple;
+}
+
+/* The __next__() method of the queryObject.
+   Returns the current row as a tuple and moves to the next one. */
+static PyObject *
+queryNext(queryObject *self, PyObject *noargs)
+{
+	PyObject   *row_tuple = NULL;
+
+	if (self->current_row >= self->max_row) {
+		PyErr_SetNone(PyExc_StopIteration);
+		return NULL;
+	}
+
+	row_tuple = queryGetRowAsTuple(self);
+	if (row_tuple) ++self->current_row;
+	return row_tuple;
+}
+
+/* Retrieves the last query result as a list of tuples. */
+static char queryGetresult__doc__[] =
 "getresult() -- Get the result of a query\n\n"
 "The result is returned as a list of rows, each one a tuple of fields\n"
 "in the order returned by the server.\n";
 
 static PyObject *
-queryGetResult(queryObject *self, PyObject *noargs)
+queryGetresult(queryObject *self, PyObject *noargs)
 {
-	PyObject   *reslist;
-	int			i, m, n, *col_types;
-	int			encoding = self->encoding;
+	PyObject   *result_list;
+	int			i;
 
-	/* stores result in tuple */
-	m = PQntuples(self->result);
-	n = PQnfields(self->result);
-	if (!(reslist = PyList_New(m))) return NULL;
+	if (!(result_list = PyList_New(self->max_row))) return NULL;
 
-	if (!(col_types = get_col_types(self->result, n))) return NULL;
-
-	for (i = 0; i < m; ++i)
+	for (i = self->current_row = 0; i < self->max_row; ++i)
 	{
-		PyObject   *rowtuple;
-		int			j;
-
-		if (!(rowtuple = PyTuple_New(n)))
+		PyObject   *row_tuple = queryNext(self, noargs);
+		if (!row_tuple)
 		{
-			Py_DECREF(reslist);
-			reslist = NULL;
-			goto exit;
+			Py_DECREF(result_list); return NULL;
 		}
+		PyList_SET_ITEM(result_list, i, row_tuple);
+	}
 
-		for (j = 0; j < n; ++j)
-		{
-			PyObject * val;
+	return result_list;
+}
 
-			if (PQgetisnull(self->result, i, j))
-			{
-				Py_INCREF(Py_None);
-				val = Py_None;
-			}
-			else /* not null */
-			{
-				/* get the string representation of the value */
-				/* note: this is always null-terminated text format */
-				char   *s = PQgetvalue(self->result, i, j);
-				/* get the PyGreSQL type of the column */
-				int		type = col_types[j];
+/* Return the current row as a dict. */
+static PyObject *
+queryGetRowAsDict(queryObject *self)
+{
+	PyObject   *row_dict = NULL;
+	int			j;
 
-				if (type & PYGRES_ARRAY)
-					val = cast_array(s, PQgetlength(self->result, i, j),
-						encoding, type, NULL, 0);
-				else if (type == PYGRES_BYTEA)
-					val = cast_bytea_text(s);
-				else if (type == PYGRES_OTHER)
-					val = cast_other(s,
-						PQgetlength(self->result, i, j), encoding,
-						PQftype(self->result, j), self->pgcnx->cast_hook);
-				else if (type & PYGRES_TEXT)
-					val = cast_sized_text(s, PQgetlength(self->result, i, j),
-						encoding, type);
-				else
-					val = cast_unsized_simple(s, type);
-			}
+	if (!(row_dict = PyDict_New())) return NULL;
 
-			if (!val)
-			{
-				Py_DECREF(reslist);
-				Py_DECREF(rowtuple);
-				reslist = NULL;
-				goto exit;
-			}
-
-			PyTuple_SET_ITEM(rowtuple, j, val);
+	for (j = 0; j < self->num_fields; ++j)
+	{
+		PyObject *val = getValueInColumn(self, j);
+		if (!val)
+		{
+			Py_DECREF(row_dict); return NULL;
 		}
+		PyDict_SetItemString(row_dict, PQfname(self->result, j), val);
+		Py_DECREF(val);
+	}
 
-		PyList_SET_ITEM(reslist, i, rowtuple);
+	return row_dict;
+}
+
+/* Return the current row as a dict and move to the next one. */
+static PyObject *
+queryNextDict(queryObject *self, PyObject *noargs)
+{
+	PyObject   *row_dict = NULL;
+
+	if (self->current_row >= self->max_row) {
+		PyErr_SetNone(PyExc_StopIteration);
+		return NULL;
 	}
 
-exit:
-	PyMem_Free(col_types);
-
-	/* returns list */
-	return reslist;
+	row_dict = queryGetRowAsDict(self);
+	if (row_dict) ++self->current_row;
+	return row_dict;
 }
 
-/* retrieves last result as a list of dictionaries*/
-static char queryDictResult__doc__[] =
+/* Retrieve the last query result as a list of dictionaries. */
+static char queryDictresult__doc__[] =
 "dictresult() -- Get the result of a query\n\n"
 "The result is returned as a list of rows, each one a dictionary with\n"
-"the field names used as the labels.\n";
+"the field names used as the keys.\n";
 
 static PyObject *
-queryDictResult(queryObject *self, PyObject *noargs)
+queryDictresult(queryObject *self, PyObject *noargs)
 {
-	PyObject   *reslist;
-	int			i,
-				m,
-				n,
-			   *col_types;
-	int			encoding = self->encoding;
+	PyObject   *result_list;
+	int			i;
 
-	/* stores result in list */
-	m = PQntuples(self->result);
-	n = PQnfields(self->result);
-	if (!(reslist = PyList_New(m))) return NULL;
+	if (!(result_list = PyList_New(self->max_row))) return NULL;
 
-	if (!(col_types = get_col_types(self->result, n))) return NULL;
-
-	for (i = 0; i < m; ++i)
+	for (i = self->current_row = 0; i < self->max_row; ++i)
 	{
-		PyObject   *dict;
-		int			j;
-
-		if (!(dict = PyDict_New()))
+		PyObject   *row_dict = queryNextDict(self, noargs);
+		if (!row_dict)
 		{
-			Py_DECREF(reslist);
-			reslist = NULL;
-			goto exit;
+			Py_DECREF(result_list); return NULL;
 		}
+		PyList_SET_ITEM(result_list, i, row_dict);
+	}
 
-		for (j = 0; j < n; ++j)
-		{
-			PyObject * val;
+	return result_list;
+}
 
-			if (PQgetisnull(self->result, i, j))
-			{
-				Py_INCREF(Py_None);
-				val = Py_None;
-			}
-			else /* not null */
-			{
-				/* get the string representation of the value */
-				/* note: this is always null-terminated text format */
-				char   *s = PQgetvalue(self->result, i, j);
-				/* get the PyGreSQL type of the column */
-				int		type = col_types[j];
+/* retrieves last result as iterator of dictionaries */
+static char queryDictiter__doc__[] =
+"dictiter() -- Get the result of a query\n\n"
+"The result is returned as an iterator of rows, each one a dictionary\n"
+"with the field names used as the keys.\n";
 
-				if (type & PYGRES_ARRAY)
-					val = cast_array(s, PQgetlength(self->result, i, j),
-						encoding, type, NULL, 0);
-				else if (type == PYGRES_BYTEA)
-					val = cast_bytea_text(s);
-				else if (type == PYGRES_OTHER)
-					val = cast_other(s,
-						PQgetlength(self->result, i, j), encoding,
-						PQftype(self->result, j), self->pgcnx->cast_hook);
-				else if (type & PYGRES_TEXT)
-					val = cast_sized_text(s, PQgetlength(self->result, i, j),
-						encoding, type);
-				else
-					val = cast_unsized_simple(s, type);
-			}
+static PyObject *
+queryDictiter(queryObject *self, PyObject *noargs)
+{
+	if (dictiter) {
+		return PyObject_CallFunction(dictiter, "(O)", self);
+	}
+	return queryGetIter(self);
+}
 
-			if (!val)
-			{
-				Py_DECREF(dict);
-				Py_DECREF(reslist);
-				reslist = NULL;
-				goto exit;
-			}
 
-			PyDict_SetItemString(dict, PQfname(self->result, j), val);
-			Py_DECREF(val);
-		}
+/* retrieves last result as list of named tuples */
+static char queryNamedresult__doc__[] =
+"namedresult() -- Get the result of a query\n\n"
+"The result is returned as a list of rows, each one a named tuple of fields\n"
+"in the order returned by the server.\n";
 
-		PyList_SET_ITEM(reslist, i, dict);
+static PyObject *
+queryNamedresult(queryObject *self, PyObject *noargs)
+{
+	if (namediter) {
+		PyObject *res, *res_list;
+		res = PyObject_CallFunction(namediter, "(O)", self);
+		if (!res || PyList_Check(res))
+			return res; /* error, or already a list */
+		res_list = PySequence_List(res);
+		Py_DECREF(res);
+		return res_list;
 	}
+	return queryGetresult(self, noargs);
+}
 
-exit:
-	PyMem_Free(col_types);
+/* retrieves last result as iterator of named tuples */
+static char queryNamediter__doc__[] =
+"namediter() -- Get the result of a query\n\n"
+"The result is returned as an iterator of rows, each one a named tuple\n"
+"of fields in the order returned by the server.\n";
 
-	/* returns list */
-	return reslist;
+static PyObject *
+queryNamediter(queryObject *self, PyObject *noargs)
+{
+	if (namediter) {
+		PyObject *res, *res_iter;
+		res = PyObject_CallFunction(namediter, "(O)", self);
+		if (!res || !PyList_Check(res))
+			return res; /* error, or not a list: return as is */
+		/* the factory returned a list, so iterate over that list */
+		res_iter = PyObject_GetIter(res);
+		Py_DECREF(res);
+		return res_iter;
+	}
+	return queryGetIter(self);
 }
 
-/* retrieves last result as named tuples */
-static char queryNamedResult__doc__[] =
-"namedresult() -- Get the result of a query\n\n"
-"The result is returned as a list of rows, each one a tuple of fields\n"
-"in the order returned by the server.\n";
+/* Return length of a query object. */
+static Py_ssize_t
+queryLen(PyObject *self)
+{
+	PyObject   *tmp;
+	long		len;
 
+	tmp = PyLong_FromLong(((queryObject*)self)->max_row);
+	len = PyLong_AsSsize_t(tmp);
+	Py_DECREF(tmp);
+	return len;
+}
+
+/* Return given item from a query object. */
 static PyObject *
-queryNamedResult(queryObject *self, PyObject *noargs)
+queryGetItem(PyObject *self, Py_ssize_t i)
 {
-	PyObject   *ret;
+	queryObject	   *q = (queryObject *)self;
+	long			row;
 
-	if (namedresult)
-	{
-		ret = PyObject_CallFunction(namedresult, "(O)", self);
+	row = (long) i;
 
-		if (ret == NULL)
-			return NULL;
-		}
-	else
-	{
-		ret = queryGetResult(self, NULL);
+	if (row < 0 || row >= q->max_row) {
+		PyErr_SetNone(PyExc_IndexError);
+		return NULL;
 	}
 
-	return ret;
+	q->current_row = row;
+	return queryGetRowAsTuple(q);
 }
 
 /* gets notice object attributes */
@@ -5051,54 +5134,72 @@
 
 /* query object methods */
 static struct PyMethodDef queryMethods[] = {
-	{"getresult", (PyCFunction) queryGetResult, METH_NOARGS,
-			queryGetResult__doc__},
-	{"dictresult", (PyCFunction) queryDictResult, METH_NOARGS,
-			queryDictResult__doc__},
-	{"namedresult", (PyCFunction) queryNamedResult, METH_NOARGS,
-			queryNamedResult__doc__},
-	{"fieldname", (PyCFunction) queryFieldName, METH_VARARGS,
-			 queryFieldName__doc__},
-	{"fieldnum", (PyCFunction) queryFieldNumber, METH_VARARGS,
-			queryFieldNumber__doc__},
-	{"listfields", (PyCFunction) queryListFields, METH_NOARGS,
-			queryListFields__doc__},
-	{"ntuples", (PyCFunction) queryNTuples, METH_NOARGS,
-			queryNTuples__doc__},
+	{"getresult", (PyCFunction) queryGetresult, METH_NOARGS,
+			queryGetresult__doc__},
+	{"dictresult", (PyCFunction) queryDictresult, METH_NOARGS,
+			queryDictresult__doc__},
+	{"dictiter", (PyCFunction) queryDictiter, METH_NOARGS,
+			queryDictiter__doc__},
+	{"namedresult", (PyCFunction) queryNamedresult, METH_NOARGS,
+			queryNamedresult__doc__},
+	{"namediter", (PyCFunction) queryNamediter, METH_NOARGS,
+			queryNamediter__doc__},
+	{"fieldname", (PyCFunction) queryFieldname, METH_VARARGS,
+			 queryFieldname__doc__},
+	{"fieldnum", (PyCFunction) queryFieldnum, METH_VARARGS,
+			queryFieldnum__doc__},
+	{"listfields", (PyCFunction) queryListfields, METH_NOARGS,
+			queryListfields__doc__},
+	{"ntuples", (PyCFunction) queryNtuples, METH_NOARGS,
+			queryNtuples__doc__},
 	{NULL, NULL}
 };
 
+/* query sequence protocol methods */
+static PySequenceMethods querySequenceMethods = {
+	(lenfunc) queryLen,				/* sq_length */
+	0,								/* sq_concat */
+	0,								/* sq_repeat */
+	(ssizeargfunc) queryGetItem,	/* sq_item */
+	0,								/* sq_ass_item */
+	0,								/* sq_contains */
+	0,								/* sq_inplace_concat */
+	0,								/* sq_inplace_repeat */
+};
+
+
 /* query type definition */
 static PyTypeObject queryType = {
 	PyVarObject_HEAD_INIT(NULL, 0)
-	"pg.Query",						/* tp_name */
-	sizeof(queryObject),			/* tp_basicsize */
-	0,								/* tp_itemsize */
+	"pg.Query",					/* tp_name */
+	sizeof(queryObject),		/* tp_basicsize */
+	0,							/* tp_itemsize */
 	/* methods */
-	(destructor) queryDealloc,		/* tp_dealloc */
-	0,								/* tp_print */
-	0,								/* tp_getattr */
-	0,								/* tp_setattr */
-	0,								/* tp_compare */
-	0,								/* tp_repr */
-	0,								/* tp_as_number */
-	0,								/* tp_as_sequence */
-	0,								/* tp_as_mapping */
-	0,								/* tp_hash */
-	0,								/* tp_call */
-	(reprfunc) queryStr,			/* tp_str */
-	PyObject_GenericGetAttr,		/* tp_getattro */
-	0,								/* tp_setattro */
-	0,								/* tp_as_buffer */
-	Py_TPFLAGS_DEFAULT,				/* tp_flags */
-	0,								/* tp_doc */
-	0,								/* tp_traverse */
-	0,								/* tp_clear */
-	0,								/* tp_richcompare */
-	0,								/* tp_weaklistoffset */
-	0,								/* tp_iter */
-	0,								/* tp_iternext */
-	queryMethods,					/* tp_methods */
+	(destructor) queryDealloc,	/* tp_dealloc */
+	0,							/* tp_print */
+	0,							/* tp_getattr */
+	0,							/* tp_setattr */
+	0,							/* tp_compare */
+	0,							/* tp_repr */
+	0,							/* tp_as_number */
+	&querySequenceMethods,		/* tp_as_sequence */
+	0,							/* tp_as_mapping */
+	0,							/* tp_hash */
+	0,							/* tp_call */
+	(reprfunc) queryStr,		/* tp_str */
+	PyObject_GenericGetAttr,	/* tp_getattro */
+	0,							/* tp_setattro */
+	0,							/* tp_as_buffer */
+	Py_TPFLAGS_DEFAULT
+		|Py_TPFLAGS_HAVE_ITER,	/* tp_flags */
+	0,							/* tp_doc */
+	0,							/* tp_traverse */
+	0,							/* tp_clear */
+	0,							/* tp_richcompare */
+	0,							/* tp_weaklistoffset */
+	(getiterfunc)queryGetIter,	/* tp_iter */
+	(iternextfunc)queryNext,	/* tp_iternext */
+	queryMethods,				/* tp_methods */
 };
 
 /* --------------------------------------------------------------------- */
@@ -5497,16 +5598,58 @@
 	return ret;
 }
 
+/* get dict result factory */
+static char pgGetDictiter__doc__[] =
+"get_dictiter() -- get the generator used for getting dict results";
+
+static PyObject *
+pgGetDictiter(PyObject *self, PyObject *noargs)
+{
+	PyObject *ret;
+
+	ret = dictiter ? dictiter : Py_None;
+	Py_INCREF(ret);
+
+	return ret;
+}
+
+/* set dict result factory */
+static char pgSetDictiter__doc__[] =
+"set_dictiter(func) -- set a generator to be used for getting dict results";
+
+static PyObject *
+pgSetDictiter(PyObject *self, PyObject *func)
+{
+	PyObject *ret = NULL;
+
+	if (func == Py_None)
+	{
+		Py_XDECREF(dictiter); dictiter = NULL;
+		Py_INCREF(Py_None); ret = Py_None;
+	}
+	else if (PyCallable_Check(func))
+	{
+		Py_XINCREF(func); Py_XDECREF(dictiter); dictiter = func;
+		Py_INCREF(Py_None); ret = Py_None;
+	}
+	else
+		PyErr_SetString(PyExc_TypeError,
+			"Function set_dictiter() expects"
+			 " a callable or None as argument");
+
+	return ret;
+}
+
 /* get named result factory */
-static char pgGetNamedresult__doc__[] =
-"get_namedresult() -- get the function used for getting named results";
+static char pgGetNamediter__doc__[] =
+"get_namediter() -- get the generator used for getting named results";
 
 static PyObject *
-pgGetNamedresult(PyObject *self, PyObject *noargs)
+pgGetNamediter(PyObject *self, PyObject *noargs)
 {
 	PyObject *ret;
 
-	ret = namedresult ? namedresult : Py_None;
+	ret = namediter ? namediter : Py_None;
 	Py_INCREF(ret);
 
 	return ret;
@@ -5513,27 +5656,27 @@
 }
 
 /* set named result factory */
-static char pgSetNamedresult__doc__[] =
-"set_namedresult(func) -- set a function to be used for getting named results";
+static char pgSetNamediter__doc__[] =
+"set_namediter(func) -- set a generator to be used for getting named results";
 
 static PyObject *
-pgSetNamedresult(PyObject *self, PyObject *func)
+pgSetNamediter(PyObject *self, PyObject *func)
 {
 	PyObject *ret = NULL;
 
 	if (func == Py_None)
 	{
-		Py_XDECREF(namedresult); namedresult = NULL;
+		Py_XDECREF(namediter); namediter = NULL;
 		Py_INCREF(Py_None); ret = Py_None;
 	}
 	else if (PyCallable_Check(func))
 	{
-		Py_XINCREF(func); Py_XDECREF(namedresult); namedresult = func;
+		Py_XINCREF(func); Py_XDECREF(namediter); namediter = func;
 		Py_INCREF(Py_None); ret = Py_None;
 	}
 	else
 		PyErr_SetString(PyExc_TypeError,
-			"Function set_namedresult() expects"
+			"Function set_namediter() expects"
 			 " a callable or None as argument");
 
 	return ret;
@@ -6027,10 +6170,19 @@
 		pgGetByteaEscaped__doc__},
 	{"set_bytea_escaped", (PyCFunction) pgSetByteaEscaped, METH_VARARGS,
 		pgSetByteaEscaped__doc__},
-	{"get_namedresult", (PyCFunction) pgGetNamedresult, METH_NOARGS,
-			pgGetNamedresult__doc__},
-	{"set_namedresult", (PyCFunction) pgSetNamedresult, METH_O,
-			pgSetNamedresult__doc__},
+	{"get_dictiter", (PyCFunction) pgGetDictiter, METH_NOARGS,
+			pgGetDictiter__doc__},
+	{"set_dictiter", (PyCFunction) pgSetDictiter, METH_O,
+			pgSetDictiter__doc__},
+	{"get_namediter", (PyCFunction) pgGetNamediter, METH_NOARGS,
+			pgGetNamediter__doc__},
+	{"set_namediter", (PyCFunction) pgSetNamediter, METH_O,
+			pgSetNamediter__doc__},
+	/* get/set_namedresult is deprecated, use get/set_namediter */
+	{"get_namedresult", (PyCFunction) pgGetNamediter, METH_NOARGS,
+			pgGetNamediter__doc__},
+	{"set_namedresult", (PyCFunction) pgSetNamediter, METH_O,
+			pgSetNamediter__doc__},
 	{"get_jsondecode", (PyCFunction) pgGetJsondecode, METH_NOARGS,
 			pgGetJsondecode__doc__},
 	{"set_jsondecode", (PyCFunction) pgSetJsondecode, METH_O,

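The C changes above give the query object the sequence protocol (`queryLen`, `queryGetItem`) and the iterator protocol (`queryGetIter`, `queryNext`, whose bodies are not part of this excerpt). As a reading aid only, here is a pure-Python model under stated assumptions: `QuerySketch`, `namedresult_fallback` and `namediter_fallback` are hypothetical names, the rows are a plain list instead of a libpq result, and the rewinding in `__iter__` is inferred from `testIterateTwice`. It is not how PyGreSQL implements anything.

```python
class QuerySketch:
    """Hypothetical pure-Python model of the C query object in this commit.

    The rows are passed in as a list of tuples instead of coming from libpq.
    """

    def __init__(self, rows):
        self._rows = list(rows)
        self._current_row = 0  # plays the role of current_row in the C struct

    def __len__(self):  # role of queryLen (sq_length)
        return len(self._rows)

    def __getitem__(self, i):  # role of queryGetItem (sq_item)
        # CPython adds len() to negative indices before calling the C sq_item
        # slot; a Python __getitem__ has to do that itself.
        if i < 0:
            i += len(self._rows)
        if not 0 <= i < len(self._rows):
            raise IndexError(i)
        # The C version also moves current_row here; this sketch does not.
        return self._rows[i]

    def __iter__(self):  # role of queryGetIter (tp_iter)
        # Assumption: a new iteration rewinds to the first row, which is
        # what testIterateTwice requires.
        self._current_row = 0
        return self

    def __next__(self):  # role of queryNext (tp_iternext)
        if self._current_row >= len(self._rows):
            raise StopIteration
        row = self._rows[self._current_row]
        self._current_row += 1
        return row


def namedresult_fallback(query, factory):
    """Mirror queryNamedresult: always hand back a list."""
    if factory is None:
        return list(query)  # stands in for queryGetresult
    res = factory(query)
    return res if isinstance(res, list) else list(res)


def namediter_fallback(query, factory):
    """Mirror queryNamediter: always hand back an iterator."""
    if factory is None:
        return iter(query)  # stands in for queryGetIter
    res = factory(query)
    return iter(res) if isinstance(res, list) else res
```

The two fallback helpers show why a single registered generator is enough for both methods: whichever shape the factory returns, `namedresult()` converts it to a list and `namediter()` converts it to an iterator.
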
Modified: trunk/py3c.h (977 => 978)


--- trunk/py3c.h	2019-04-20 12:40:21 UTC (rev 977)
+++ trunk/py3c.h	2019-04-21 19:07:48 UTC (rev 978)
@@ -52,6 +52,10 @@
     PyMODINIT_FUNC PyInit_ ## name(void); \
     PyMODINIT_FUNC PyInit_ ## name(void)
 
+/* Other */
+
+#define Py_TPFLAGS_HAVE_ITER 0 /* not needed in Python 3 */
+
 #else
 
 /***** Python 2 *****/

Modified: trunk/tests/test_classic_connection.py (977 => 978)


--- trunk/tests/test_classic_connection.py	2019-04-20 12:40:21 UTC (rev 977)
+++ trunk/tests/test_classic_connection.py	2019-04-21 19:07:48 UTC (rev 978)
@@ -18,7 +18,7 @@
 import time
 import os
 
-from collections import namedtuple
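+from collections import namedtuple
+try:
+    from collections.abc import Iterable
+except ImportError:  # Python 2
+    from collections import Iterable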
 from decimal import Decimal
 
 import pg  # the module under test
@@ -549,7 +549,7 @@
         self.assertIsInstance(r, int)
         self.assertEqual(r, 3)
 
-    def testNtuples(self):
+    def testNtuples(self):  # deprecated
         q = "select 1 where false"
         r = self.c.query(q).ntuples()
         self.assertIsInstance(r, int)
@@ -565,6 +565,16 @@
         self.assertIsInstance(r, int)
         self.assertEqual(r, 6)
 
+    def testLen(self):
+        q = "select 1 where false"
+        self.assertEqual(len(self.c.query(q)), 0)
+        q = ("select 1 as a, 2 as b, 3 as c, 4 as d"
+            " union select 5 as a, 6 as b, 7 as c, 8 as d")
+        self.assertEqual(len(self.c.query(q)), 2)
+        q = ("select 1 union select 2 union select 3"
+            " union select 4 union select 5 union select 6")
+        self.assertEqual(len(self.c.query(q)), 6)
+
     def testQuery(self):
         query = self.c.query
         query("drop table if exists test_table")
@@ -1098,6 +1108,127 @@
         self.assert_proper_cast('{}', 'json', dict)
 
 
+class TestQueryIterator(unittest.TestCase):
+    """Test the query operating as an iterator."""
+
+    def setUp(self):
+        self.c = connect()
+
+    def tearDown(self):
+        self.c.close()
+
+    def testLen(self):
+        r = self.c.query("select generate_series(3,7)")
+        self.assertEqual(len(r), 5)
+
+    def testGetItem(self):
+        r = self.c.query("select generate_series(7,9)")
+        self.assertEqual(r[0], (7,))
+        self.assertEqual(r[1], (8,))
+        self.assertEqual(r[2], (9,))
+
+    def testGetItemWithNegativeIndex(self):
+        r = self.c.query("select generate_series(7,9)")
+        self.assertEqual(r[-1], (9,))
+        self.assertEqual(r[-2], (8,))
+        self.assertEqual(r[-3], (7,))
+
+    def testGetItemOutOfRange(self):
+        r = self.c.query("select generate_series(7,9)")
+        self.assertRaises(IndexError, r.__getitem__, 3)
+
+    def testIterate(self):
+        r = self.c.query("select generate_series(3,5)")
+        self.assertNotIsInstance(r, (list, tuple))
+        self.assertIsInstance(r, Iterable)
+        self.assertEqual(list(r), [(3,), (4,), (5,)])
+        self.assertIsInstance(r[1], tuple)
+
+    def testIterateTwice(self):
+        r = self.c.query("select generate_series(3,5)")
+        for i in range(2):
+            self.assertEqual(list(r), [(3,), (4,), (5,)])
+
+    def testIterateTwoColumns(self):
+        r = self.c.query("select 1,2 union select 3,4")
+        self.assertIsInstance(r, Iterable)
+        self.assertEqual(list(r), [(1, 2), (3, 4)])
+
+    def testNext(self):
+        r = self.c.query("select generate_series(7,9)")
+        self.assertEqual(next(r), (7,))
+        self.assertEqual(next(r), (8,))
+        self.assertEqual(next(r), (9,))
+        self.assertRaises(StopIteration, next, r)
+
+    def testContains(self):
+        r = self.c.query("select generate_series(7,9)")
+        self.assertIn((8,), r)
+        self.assertNotIn((5,), r)
+
+    def testNamedIterate(self):
+        r = self.c.query("select generate_series(3,5) as number").namediter()
+        self.assertNotIsInstance(r, (list, tuple))
+        self.assertIsInstance(r, Iterable)
+        r = list(r)
+        self.assertEqual(r, [(3,), (4,), (5,)])
+        self.assertIsInstance(r[1], tuple)
+        self.assertEqual(r[1]._fields, ('number',))
+        self.assertEqual(r[1].number, 4)
+
+    def testNamedIterateTwoColumns(self):
+        r = self.c.query("select 1 as one, 2 as two"
+            " union select 3 as one, 4 as two").namediter()
+        self.assertIsInstance(r, Iterable)
+        r = list(r)
+        self.assertEqual(r, [(1, 2), (3, 4)])
+        self.assertEqual(r[0]._fields, ('one', 'two'))
+        self.assertEqual(r[0].one, 1)
+        self.assertEqual(r[1]._fields, ('one', 'two'))
+        self.assertEqual(r[1].two, 4)
+
+    def testNamedNext(self):
+        r = self.c.query("select generate_series(7,9) as number").namediter()
+        self.assertEqual(next(r), (7,))
+        self.assertEqual(next(r), (8,))
+        n = next(r)
+        self.assertEqual(n._fields, ('number',))
+        self.assertEqual(n.number, 9)
+        self.assertRaises(StopIteration, next, r)
+
+    def testNamedContains(self):
+        r = self.c.query("select generate_series(7,9)").namediter()
+        self.assertIn((8,), r)
+        self.assertNotIn((5,), r)
+
+    def testDictIterate(self):
+        r = self.c.query("select generate_series(3,5) as n").dictiter()
+        self.assertNotIsInstance(r, (list, tuple))
+        self.assertIsInstance(r, Iterable)
+        r = list(r)
+        self.assertEqual(r, [dict(n=3), dict(n=4), dict(n=5)])
+        self.assertIsInstance(r[1], dict)
+
+    def testDictIterateTwoColumns(self):
+        r = self.c.query("select 1 as one, 2 as two"
+            " union select 3 as one, 4 as two").dictiter()
+        self.assertIsInstance(r, Iterable)
+        r = list(r)
+        self.assertEqual(r, [dict(one=1, two=2), dict(one=3, two=4)])
+
+    def testDictNext(self):
+        r = self.c.query("select generate_series(7,9) as n").dictiter()
+        self.assertEqual(next(r), dict(n=7))
+        self.assertEqual(next(r), dict(n=8))
+        self.assertEqual(next(r), dict(n=9))
+        self.assertRaises(StopIteration, next, r)
+
+    def testDictContains(self):
+        r = self.c.query("select generate_series(7,9) as n").dictiter()
+        self.assertIn(dict(n=8), r)
+        self.assertNotIn(dict(n=5), r)
+
+
 class TestInserttable(unittest.TestCase):
     """Test inserttable method."""
 
@@ -1614,7 +1745,7 @@
         else:
             self.skipTest("cannot set English money locale")
         try:
-            r = query(select_money)
+            query(select_money)
         except pg.DataError:
             # this can happen if the currency signs cannot be
             # converted using the encoding of the test database
@@ -1621,39 +1752,35 @@
             self.skipTest("database does not support English money")
         pg.set_decimal_point(None)
         try:
-            r = r.getresult()[0][0]
+            r = query(select_money).getresult()[0][0]
         finally:
             pg.set_decimal_point(point)
         self.assertIsInstance(r, str)
         self.assertIn(r, en_money)
-        r = query(select_money)
         pg.set_decimal_point('')
         try:
-            r = r.getresult()[0][0]
+            r = query(select_money).getresult()[0][0]
         finally:
             pg.set_decimal_point(point)
         self.assertIsInstance(r, str)
         self.assertIn(r, en_money)
-        r = query(select_money)
         pg.set_decimal_point('.')
         try:
-            r = r.getresult()[0][0]
+            r = query(select_money).getresult()[0][0]
         finally:
             pg.set_decimal_point(point)
         self.assertIsInstance(r, d)
         self.assertEqual(r, proper_money)
-        r = query(select_money)
         pg.set_decimal_point(',')
         try:
-            r = r.getresult()[0][0]
+            r = query(select_money).getresult()[0][0]
         finally:
             pg.set_decimal_point(point)
         self.assertIsInstance(r, d)
         self.assertEqual(r, bad_money)
-        r = query(select_money)
         pg.set_decimal_point("'")
         try:
-            r = r.getresult()[0][0]
+            r = query(select_money).getresult()[0][0]
         finally:
             pg.set_decimal_point(point)
         self.assertIsInstance(r, d)
@@ -1670,43 +1797,39 @@
             self.skipTest("cannot set German money locale")
         select_money = select_money.replace('.', ',')
         try:
-            r = query(select_money)
+            query(select_money)
         except pg.DataError:
             self.skipTest("database does not support English money")
         pg.set_decimal_point(None)
         try:
-            r = r.getresult()[0][0]
+            r = query(select_money).getresult()[0][0]
         finally:
             pg.set_decimal_point(point)
         self.assertIsInstance(r, str)
         self.assertIn(r, de_money)
-        r = query(select_money)
         pg.set_decimal_point('')
         try:
-            r = r.getresult()[0][0]
+            r = query(select_money).getresult()[0][0]
         finally:
             pg.set_decimal_point(point)
         self.assertIsInstance(r, str)
         self.assertIn(r, de_money)
-        r = query(select_money)
         pg.set_decimal_point(',')
         try:
-            r = r.getresult()[0][0]
+            r = query(select_money).getresult()[0][0]
         finally:
             pg.set_decimal_point(point)
         self.assertIsInstance(r, d)
         self.assertEqual(r, proper_money)
-        r = query(select_money)
         pg.set_decimal_point('.')
         try:
-            r = r.getresult()[0][0]
+            r = query(select_money).getresult()[0][0]
         finally:
             pg.set_decimal_point(point)
         self.assertEqual(r, bad_money)
-        r = query(select_money)
         pg.set_decimal_point("'")
         try:
-            r = r.getresult()[0][0]
+            r = query(select_money).getresult()[0][0]
         finally:
             pg.set_decimal_point(point)
         self.assertEqual(r, bad_money)
@@ -1794,18 +1917,16 @@
         r = r.getresult()[0][0]
         self.assertIsInstance(r, bool)
         self.assertEqual(r, True)
-        r = query("select true::bool")
         pg.set_bool(False)
         try:
-            r = r.getresult()[0][0]
+            r = query("select true::bool").getresult()[0][0]
         finally:
             pg.set_bool(use_bool)
         self.assertIsInstance(r, str)
         self.assertIs(r, 't')
-        r = query("select true::bool")
         pg.set_bool(True)
         try:
-            r = r.getresult()[0][0]
+            r = query("select true::bool").getresult()[0][0]
         finally:
             pg.set_bool(use_bool)
         self.assertIsInstance(r, bool)
@@ -1858,30 +1979,127 @@
         r = r.getresult()[0][0]
         self.assertIsInstance(r, bytes)
         self.assertEqual(r, b'data')
-        r = query("select 'data'::bytea")
         pg.set_bytea_escaped(True)
         try:
-            r = r.getresult()[0][0]
+            r = query("select 'data'::bytea").getresult()[0][0]
         finally:
             pg.set_bytea_escaped(bytea_escaped)
         self.assertIsInstance(r, str)
         self.assertEqual(r, '\\x64617461')
-        r = query("select 'data'::bytea")
         pg.set_bytea_escaped(False)
         try:
-            r = r.getresult()[0][0]
+            r = query("select 'data'::bytea").getresult()[0][0]
         finally:
             pg.set_bytea_escaped(bytea_escaped)
         self.assertIsInstance(r, bytes)
         self.assertEqual(r, b'data')
 
-    def testGetNamedresult(self):
+    def testGetDictiter(self):
+        dictiter = pg.get_dictiter()
+        # error if a parameter is passed
+        self.assertRaises(TypeError, pg.get_dictiter, dictiter)
+        self.assertIs(dictiter, pg._dictiter)  # the default setting
+
+    def testSetDictiter(self):
+        dictiter = pg.get_dictiter()
+        self.assertTrue(callable(dictiter))
+
+        query = self.c.query
+
+        r = query("select 1 as x, 2 as y").dictiter()
+        self.assertNotIsInstance(r, list)
+        r = next(r)
+        self.assertIsInstance(r, dict)
+        self.assertEqual(r, dict(x=1, y=2))
+
+        def listiter(q):
+            for row in q:
+                yield list(row)
+
+        pg.set_dictiter(listiter)
+        try:
+            r = pg.get_dictiter()
+            self.assertIs(r, listiter)
+            r = query("select 1 as x, 2 as y").dictiter()
+            self.assertNotIsInstance(r, list)
+            r = next(r)
+            self.assertIsInstance(r, list)
+            self.assertEqual(r, [1, 2])
+            self.assertNotIsInstance(r, dict)
+        finally:
+            pg.set_dictiter(dictiter)
+
+        r = pg.get_dictiter()
+        self.assertIs(r, dictiter)
+
+    def testGetNamediter(self):
+        namediter = pg.get_namediter()
+        # error if a parameter is passed
+        self.assertRaises(TypeError, pg.get_namediter, namediter)
+        self.assertIs(namediter, pg._namediter)  # the default setting
+
+    def testSetNamediter(self):
+        namediter = pg.get_namediter()
+        self.assertTrue(callable(namediter))
+
+        query = self.c.query
+
+        r = query("select 1 as x, 2 as y").namediter()
+        self.assertNotIsInstance(r, list)
+        r = next(r)
+        self.assertIsInstance(r, tuple)
+        self.assertEqual(r, (1, 2))
+        self.assertIsNot(type(r), tuple)
+        self.assertEqual(r._fields, ('x', 'y'))
+        self.assertEqual(r._asdict(), {'x': 1, 'y': 2})
+        self.assertEqual(r.__class__.__name__, 'Row')
+        r = query("select 1 as x, 2 as y").namedresult()
+        self.assertIsInstance(r, list)
+        r = r[0]
+        self.assertIsInstance(r, tuple)
+        self.assertEqual(r, (1, 2))
+        self.assertIsNot(type(r), tuple)
+        self.assertEqual(r._fields, ('x', 'y'))
+        self.assertEqual(r._asdict(), {'x': 1, 'y': 2})
+        self.assertEqual(r.__class__.__name__, 'Row')
+
+        def listiter(q):
+            for row in q:
+                yield list(row)
+
+        pg.set_namediter(listiter)
+        try:
+            r = pg.get_namediter()
+            self.assertIs(r, listiter)
+            r = query("select 1 as x, 2 as y").namediter()
+            self.assertNotIsInstance(r, list)
+            r = next(r)
+            self.assertIsInstance(r, list)
+            self.assertEqual(r, [1, 2])
+            self.assertIsNot(type(r), tuple)
+            self.assertFalse(hasattr(r, '_fields'))
+            self.assertNotEqual(r.__class__.__name__, 'Row')
+            r = query("select 1 as x, 2 as y").namedresult()
+            self.assertIsInstance(r, list)
+            r = r[0]
+            self.assertIsInstance(r, list)
+            self.assertEqual(r, [1, 2])
+            self.assertIsNot(type(r), tuple)
+            self.assertFalse(hasattr(r, '_fields'))
+            self.assertNotEqual(r.__class__.__name__, 'Row')
+        finally:
+            pg.set_namediter(namediter)
+
+        r = pg.get_namediter()
+        self.assertIs(r, namediter)
+
+    def testGetNamedresult(self):  # deprecated
         namedresult = pg.get_namedresult()
         # error if a parameter is passed
         self.assertRaises(TypeError, pg.get_namedresult, namedresult)
-        self.assertIs(namedresult, pg._namedresult)  # the default setting
+        self.assertIs(namedresult, pg._namediter)  # the default setting
 
-    def testSetNamedresult(self):
+    def testSetNamedresult(self):  # deprecated
         namedresult = pg.get_namedresult()
         self.assertTrue(callable(namedresult))
 

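The tests above pin down only the observable behaviour of the default factories: `namediter()` yields named tuples whose class is called `Row`, and `dictiter()` yields plain dicts keyed by column name. The bodies of `pg._namediter` and `pg._dictiter` are not part of this excerpt, so the following is merely a guess at generators that would satisfy those checks. `FakeQuery`, `namediter_sketch` and `dictiter_sketch` are hypothetical; `listiter` is copied from the tests.

```python
from collections import namedtuple


class FakeQuery:
    """Hypothetical stand-in for pg.Query: field names plus tuple rows."""

    def __init__(self, fields, rows):
        self._fields = tuple(fields)
        self._rows = list(rows)

    def listfields(self):
        return self._fields

    def __iter__(self):
        return iter(self._rows)


def namediter_sketch(q):
    """Yield each row as an instance of a named tuple class called Row."""
    # rename=True is a defensive assumption so that column names which are
    # not valid identifiers do not raise; the real factory may differ.
    Row = namedtuple('Row', q.listfields(), rename=True)
    for row in q:
        yield Row(*row)


def dictiter_sketch(q):
    """Yield each row as a dict keyed by field name."""
    fields = q.listfields()
    for row in q:
        yield dict(zip(fields, row))


def listiter(q):
    """The custom factory used in testSetDictiter and testSetNamediter."""
    for row in q:
        yield list(row)
```

Because each factory is a generator, `namediter()` and `dictiter()` can return its result directly, while `namedresult()` and `dictresult()` only need to wrap it in a list.
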
Modified: trunk/tests/test_classic_functions.py (977 => 978)


--- trunk/tests/test_classic_functions.py	2019-04-20 12:40:21 UTC (rev 977)
+++ trunk/tests/test_classic_functions.py	2019-04-21 19:07:48 UTC (rev 978)
@@ -1005,12 +1005,54 @@
         self.assertIsInstance(r, bool)
         self.assertIs(r, bytea_escaped)
 
-    def testGetNamedresult(self):
+    def testGetDictiter(self):
+        r = pg.get_dictiter()
+        self.assertTrue(callable(r))
+        self.assertIs(r, pg._dictiter)
+
+    def testSetDictiter(self):
+        dictiter = pg.get_dictiter()
+        try:
+            pg.set_dictiter(None)
+            r = pg.get_dictiter()
+            self.assertIsNone(r)
+            f = lambda q: q
+            pg.set_dictiter(f)
+            r = pg.get_dictiter()
+            self.assertIs(r, f)
+            self.assertRaises(TypeError, pg.set_dictiter, 'invalid')
+        finally:
+            pg.set_dictiter(dictiter)
+        r = pg.get_dictiter()
+        self.assertIs(r, dictiter)
+
+    def testGetNamediter(self):
+        r = pg.get_namediter()
+        self.assertTrue(callable(r))
+        self.assertIs(r, pg._namediter)
+
+    def testSetNamediter(self):
+        namediter = pg.get_namediter()
+        try:
+            pg.set_namediter(None)
+            r = pg.get_namediter()
+            self.assertIsNone(r)
+            f = lambda q: q
+            pg.set_namediter(f)
+            r = pg.get_namediter()
+            self.assertIs(r, f)
+            self.assertRaises(TypeError, pg.set_namediter, 'invalid')
+        finally:
+            pg.set_namediter(namediter)
+        r = pg.get_namediter()
+        self.assertIs(r, namediter)
+
+    def testGetNamedresult(self):  # deprecated
         r = pg.get_namedresult()
         self.assertTrue(callable(r))
-        self.assertIs(r, pg._namedresult)
+        self.assertIs(r, pg._namediter)
 
-    def testSetNamedresult(self):
+    def testSetNamedresult(self):  # deprecated
         namedresult = pg.get_namedresult()
         try:
             pg.set_namedresult(None)
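
`pgSetDictiter` and `pgSetNamediter` accept either a callable or `None` and raise `TypeError` otherwise, and the method table binds the deprecated `get_namedresult`/`set_namedresult` names to the very same C functions. A small Python model of that contract follows; the module-level `_current_namediter` slot is a hypothetical stand-in for the C static variable, not the real `pg._namediter` default.

```python
_current_namediter = None  # hypothetical stand-in for the C static namediter


def get_namediter():
    """Return the current named-row factory, or None if none is set."""
    return _current_namediter


def set_namediter(func):
    """Install a factory; only a callable or None is accepted."""
    global _current_namediter
    if func is not None and not callable(func):
        raise TypeError("Function set_namediter() expects"
                        " a callable or None as argument")
    _current_namediter = func


# Deprecated names bound to the same functions, as in the C method table.
get_namedresult = get_namediter
set_namedresult = set_namediter
```

Since the aliases are the same objects, a factory installed through either name is visible through both, which is what the deprecated `testGetNamedresult` relies on when it compares the result against `pg._namediter`.
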
_______________________________________________
PyGreSQL mailing list
PyGreSQL@Vex.Net
https://mail.vex.net/mailman/listinfo/pygresql
