Author: cito
Date: Wed Jan 27 18:09:18 2016
New Revision: 791

Log:
Add support for composite types

Added a fast parser for the composite type input/output syntax, which is
similar to the already existing parser for the array input/output syntax.

The pgdb module now converts in both directions between PostgreSQL records
(composite types) and Python (named) tuples, using this parser when reading
records from the database.

Modified:
   trunk/docs/contents/changelog.rst
   trunk/pgdb.py
   trunk/pgmodule.c
   trunk/tests/test_classic_functions.py
   trunk/tests/test_dbapi20.py

Modified: trunk/docs/contents/changelog.rst
==============================================================================
--- trunk/docs/contents/changelog.rst   Wed Jan 27 13:26:41 2016        (r790)
+++ trunk/docs/contents/changelog.rst   Wed Jan 27 18:09:18 2016        (r791)
@@ -25,9 +25,14 @@
   are now named tuples, i.e. their elements can be also accessed by name.
   The column names and types can now also be requested through the
   colnames and coltypes attributes, which are not part of DB-API 2 though.
-- If you pass a list as one of the parameters to a DB-API 2 cursor, it is
-  now automatically bound as PostgreSQL ARRAY. If you pass a tuple, then
-  it will be bound as a PostgreSQL ROW expression.
+- If you pass a Python list as one of the parameters to a DB-API 2 cursor,
+  it is now automatically bound as a PostgreSQL array. If you pass a Python
+  tuple, it is bound as a PostgreSQL composite type. Conversely, if a query
+  returns a PostgreSQL array, it is passed to Python as a list, and if it
+  returns a PostgreSQL composite type, it is passed to Python as a (named)
+  tuple. PyGreSQL uses the special input and output syntax for PostgreSQL
+  arrays and composite types in all of these cases. Anonymous composite
+  types are returned as ordinary (unnamed) tuples with string values.
 - Re-activated the shortcut methods of the DB-API connection since they
   can be handy when doing experiments or writing quick scripts. We keep
   them undocumented though and discourage using them in production.

Modified: trunk/pgdb.py
==============================================================================
--- trunk/pgdb.py       Wed Jan 27 13:26:41 2016        (r790)
+++ trunk/pgdb.py       Wed Jan 27 18:09:18 2016        (r791)
@@ -136,12 +136,15 @@
         lambda v: v in '0123456789.-', value)))
 
 
-_cast = {'bool': _cast_bool, 'bytea': unescape_bytea,
+_cast = {'char': str, 'bpchar': str, 'name': str,
+    'text': str, 'varchar': str,
+    'bool': _cast_bool, 'bytea': unescape_bytea,
     'int2': int, 'int4': int, 'serial': int,
     'int8': long, 'json': jsondecode, 'jsonb': jsondecode,
     'oid': long, 'oid8': long,
     'float4': float, 'float8': float,
-    'numeric': Decimal, 'money': _cast_money}
+    'numeric': Decimal, 'money': _cast_money,
+    'record': cast_record}
 
 
 def _db_error(msg, cls=DatabaseError):
@@ -183,10 +186,14 @@
             if not '.' in key and not '"' in key:
                 key = '"%s"' % key
             oid = "'%s'::regtype" % self._escape_string(key)
-        self._src.execute("SELECT oid, typname,"
-             " typlen, typtype, typcategory, typdelim, typrelid"
-            " FROM pg_type WHERE oid=%s" % oid)
-        res = self._src.fetch(1)
+        try:
+            self._src.execute("SELECT oid, typname,"
+                 " typlen, typtype, typcategory, typdelim, typrelid"
+                " FROM pg_type WHERE oid=%s" % oid)
+        except ProgrammingError:
+            res = None
+        else:
+            res = self._src.fetch(1)
         if not res:
             raise KeyError('Type %s could not be found' % key)
         res = list(res[0])
@@ -197,37 +204,69 @@
         self[res.oid] = self[res.name] = res
         return res
 
+    def get(self, key, default=None):
+        """Get the type even if it is not cached."""
+        try:
+            return self[key]
+        except KeyError:
+            return default
+
     def columns(self, key):
         """Get the names and types of the columns of composite types."""
-        typ = self[key]
+        try:
+            typ = self[key]
+        except KeyError:
+            return None  # this type is not known
         if typ.type != 'c' or not typ.relid:
-            return []  # this type is not composite
+            return None  # this type is not composite
         self._src.execute("SELECT attname, atttypid"
             " FROM pg_attribute WHERE attrelid=%s AND attnum>0"
             " AND NOT attisdropped ORDER BY attnum" % typ.relid)
         return [ColumnInfo(name, int(oid))
             for name, oid in self._src.fetch(-1)]
 
-    @staticmethod
-    def typecast(typ, value):
+    def typecast(self, typ, value):
         """Cast value according to database type."""
         if value is None:
             # for NULL values, no typecast is necessary
             return None
         cast = _cast.get(typ)
+        if cast is str:
+            return value  # no typecast necessary
         if cast is None:
             if typ.startswith('_'):
                 # cast as an array type
                 cast = _cast.get(typ[1:])
                 return cast_array(value, cast)
-            # no typecast available or necessary
-            return value
+            # check whether this is a composite type
+            cols = self.columns(typ)
+            if cols:
+                getcast = self.getcast
+                cast = [getcast(col.type) for col in cols]
+                value = cast_record(value, cast)
+                fields = [col.name for col in cols]
+                record = namedtuple(typ, fields)
+                return record(*value)
+            return value  # no typecast available or necessary
         else:
             return cast(value)
 
+    def getcast(self, key):
+        """Get a cast function for the given database type."""
+        if isinstance(key, int):
+            try:
+                typ = self[key].name
+            except KeyError:
+                return None
+        else:
+            typ = key
+        typecast = self.typecast
+        return lambda value: typecast(typ, value)
+
 
-_re_array_escape = regex(r'(["\\])')
 _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
+_re_record_quote = regex(r'[(,"\\]')
+_re_array_escape = _re_record_escape = regex(r'(["\\])')
 
 
 class _quotedict(dict):
@@ -299,8 +338,7 @@
         if isinstance(val, list):
             return "'%s'" % self._quote_array(val)
         if isinstance(val, tuple):
-            q = self._quote
-            return 'ROW(%s)' % ','.join(str(q(v)) for v in val)
+            return "'%s'" % self._quote_record(val)
         try:
             return val.__pg_repr__()
         except AttributeError:
@@ -309,27 +347,53 @@
 
     def _quote_array(self, val):
         """Quote value as a literal constant for an array."""
-        # We could also cast to an array constructor here, but that is more
-        # verbose and you need to know the base type to build emtpy arrays.
+        q = self._quote_array_element
+        return '{%s}' % ','.join(q(v) for v in val)
+
+    def _quote_array_element(self, val):
+        """Quote value using the output syntax for arrays."""
         if isinstance(val, list):
-            return '{%s}' % ','.join(self._quote_array(v) for v in val)
+            return self._quote_array(val)
         if val is None:
             return 'null'
         if isinstance(val, (int, long, float)):
             return str(val)
         if isinstance(val, bool):
             return 't' if val else 'f'
+        if isinstance(val, tuple):
+            val = self._quote_record(val)
         if isinstance(val, basestring):
             if not val:
                 return '""'
             if _re_array_quote.search(val):
                 return '"%s"' % _re_array_escape.sub(r'\\\1', val)
             return val
-        try:
-            return val.__pg_repr__()
-        except AttributeError:
-            raise InterfaceError(
-                'do not know how to handle type %s' % type(val))
+        raise InterfaceError(
+            'do not know how to handle base type %s' % type(val))
+
+    def _quote_record(self, val):
+        """Quote value as a literal constant for a record."""
+        q = self._quote_record_element
+        return '(%s)' % ','.join(q(v) for v in val)
+
+    def _quote_record_element(self, val):
+        """Quote value using the output syntax for records."""
+        if val is None:
+            return ''
+        if isinstance(val, (int, long, float)):
+            return str(val)
+        if isinstance(val, bool):
+            return 't' if val else 'f'
+        if isinstance(val, list):
+            val = self._quote_array(val)
+        if isinstance(val, basestring):
+            if not val:
+                return '""'
+            if _re_record_quote.search(val):
+                return '"%s"' % _re_record_escape.sub(r'\\\1', val)
+            return val
+        raise InterfaceError(
+            'do not know how to handle component type %s' % type(val))
 
     def _quoteparams(self, string, parameters):
         """Quote parameters.

Modified: trunk/pgmodule.c
==============================================================================
--- trunk/pgmodule.c    Wed Jan 27 13:26:41 2016        (r790)
+++ trunk/pgmodule.c    Wed Jan 27 18:09:18 2016        (r791)
@@ -555,7 +555,7 @@
                        }
                        buf[j] = '\0'; s = buf;
                        /* FALLTHROUGH */ /* no break here */
-       
+
                case PYGRES_DECIMAL:
                        if (decimal)
                        {
@@ -597,8 +597,9 @@
        && (s[2] == 'l' || s[2] == 'L') \
        && (s[3] == 'l' || s[3] == 'L'))
 
-/* Cast string s with size and encoding to a Python list.
-   Use cast function if specified or basetype to cast elements.
+/* Cast string s with size and encoding to a Python list,
+   using the input and output syntax for arrays.
+   Use internal type or cast function to cast elements.
    The parameter delim specifies the delimiter for the elements,
    since some types do not use the default delimiter of a comma. */
 static PyObject *
@@ -614,7 +615,13 @@
                type &= ~PYGRES_ARRAY; /* get the base type */
                if (!type) type = PYGRES_TEXT;
        }
-       if (!delim) delim = ',';
+       if (!delim)
+               delim = ',';
+       else if (delim == '{' || delim =='}' || delim=='\\')
+       {
+               PyErr_SetString(PyExc_ValueError, "Invalid array delimiter");
+               return NULL;
+       }
 
        /* strip blanks at the beginning */
        while (s != end && *s == ' ') ++s;
@@ -653,7 +660,7 @@
        if (!depth)
        {
                PyErr_SetString(PyExc_ValueError,
-                       "Array must start with an opening brace");
+                       "Array must start with a left brace");
                return NULL;
        }
        if (ranges && depth != ranges)
@@ -689,13 +696,16 @@
                                {
                                        PyErr_SetString(PyExc_ValueError,
                                                "Subarray expected but not 
found");
-                                       return NULL;
+                                       Py_DECREF(result); return NULL;
                                }
                        }
                        else if (*s != '}') break; /* error */
                        subresult = result;
                        result = stack[--level];
-                       if (PyList_Append(result, subresult)) return NULL;
+                       if (PyList_Append(result, subresult))
+                       {
+                               Py_DECREF(result); return NULL;
+                       }
                }
                else if (level == depth) /* we expect elements at this level */
                {
@@ -708,7 +718,7 @@
                        {
                                PyErr_SetString(PyExc_ValueError,
                                        "Subarray found where not expected");
-                               return NULL;
+                               Py_DECREF(result); return NULL;
                        }
                        if (*s == '"') /* quoted element */
                        {
@@ -752,13 +762,16 @@
                        {
                                if (escaped)
                                {
-                                       char   *r;
-                                       int             i;
+                                       char       *r;
+                                       Py_ssize_t      i;
 
                                        /* create unescaped string */
                                        t = estr;
                                        estr = (char *) PyMem_Malloc(esize);
-                                       if (!estr) return PyErr_NoMemory();
+                                       if (!estr)
+                                       {
+                                               Py_DECREF(result); return 
PyErr_NoMemory();
+                                       }
                                        for (i = 0, r = estr; i < esize; ++i)
                                        {
                                                if (*t == '\\') ++t, ++i;
@@ -788,14 +801,20 @@
                                        }
                                }
                                if (escaped) PyMem_Free(estr);
-                               if (!element) return NULL;
+                               if (!element)
+                               {
+                                       Py_DECREF(result); return NULL;
+                               }
                        }
                        else
                        {
-                               Py_INCREF(Py_None);
-                               element = Py_None;
+                               Py_INCREF(Py_None); element = Py_None;
+                       }
+                       if (PyList_Append(result, element))
+                       {
+                               Py_DECREF(element); Py_DECREF(result); return 
NULL;
                        }
-                       if (PyList_Append(result, element)) return NULL;
+                       Py_DECREF(element);
                        if (*s == delim)
                        {
                                do ++s; while (s != end && *s == ' ');
@@ -808,8 +827,8 @@
                        if (*s != '{')
                        {
                                PyErr_SetString(PyExc_ValueError,
-                                       "Subarray must start with an opening 
brace");
-                               return NULL;
+                                       "Subarray must start with a left 
brace");
+                               Py_DECREF(result); return NULL;
                        }
                        do ++s; while (s != end && *s == ' ');
                        if (s == end) break; /* error */
@@ -821,18 +840,202 @@
        {
                PyErr_SetString(PyExc_ValueError,
                        "Unexpected end of array");
-               return NULL;
+               Py_DECREF(result); return NULL;
        }
        do ++s; while (s != end && *s == ' ');
        if (s != end)
        {
                PyErr_SetString(PyExc_ValueError,
                        "Unexpected characters after end of array");
-               return NULL;
+               Py_DECREF(result); return NULL;
        }
        return result;
 }
 
+/* Cast string s with size and encoding to a Python tuple,
+   using the input and output syntax for composite types.
+   Use an array of internal types, a cast function, or a sequence of cast
+   functions to cast elements. The parameter len is the expected number of
+   columns, or 0 if it is not to be checked. The parameter delim can specify
+   a delimiter, although composite types always use a comma as delimiter. */
+
+static PyObject *
+cast_record(char *s, Py_ssize_t size, int encoding,
+        int *type, PyObject *cast, Py_ssize_t len, char delim)
+{
+       PyObject   *result, *ret;
+       char       *end = s + size, *t;
+       Py_ssize_t      i;
+
+       if (!delim)
+               delim = ',';
+       else if (delim == '(' || delim ==')' || delim=='\\')
+       {
+               PyErr_SetString(PyExc_ValueError, "Invalid record delimiter");
+               return NULL;
+       }
+
+       /* strip blanks at the beginning */
+       while (s != end && *s == ' ') ++s;
+       if (s == end || *s != '(')
+       {
+               PyErr_SetString(PyExc_ValueError,
+                       "Record must start with a left parenthesis");
+               return NULL;
+       }
+       result = PyList_New(0);
+       if (!result) return NULL;
+       i = 0;
+       /* everything is set up, start parsing the record */
+       while (++s != end)
+       {
+               PyObject   *element;
+
+               if (*s == ')' || *s == delim)
+               {
+                       Py_INCREF(Py_None); element = Py_None;
+               }
+               else
+               {
+                       char       *estr;
+                       Py_ssize_t      esize;
+                       int quoted = 0, escaped =0;
+
+                       estr = s;
+                       quoted = *s == '"';
+                       if (quoted) ++s;
+                       esize = 0;
+                       while (s != end)
+                       {
+                               if (!quoted && (*s == ')' || *s == delim))
+                                       break;
+                               if (*s == '"')
+                               {
+                                       ++s; if (s == end) break;
+                                       if (!(quoted && *s == '"'))
+                                       {
+                                               quoted = !quoted; continue;
+                                       }
+                               }
+                               if (*s == '\\')
+                               {
+                                       ++s; if (s == end) break;
+                               }
+                               ++s, ++esize;
+                       }
+                       if (s == end) break; /* error */
+                       if (estr + esize != s)
+                       {
+                               char       *r;
+
+                               escaped = 1;
+                               /* create unescaped string */
+                               t = estr;
+                               estr = (char *) PyMem_Malloc(esize);
+                               if (!estr)
+                               {
+                                       Py_DECREF(result); return 
PyErr_NoMemory();
+                               }
+                               quoted = 0;
+                               r = estr;
+                               while (t != s)
+                               {
+                                       if (*t == '"')
+                                       {
+                                               ++t;
+                                               if (!(quoted && *t == '"'))
+                                               {
+                                                       quoted = !quoted; 
continue;
+                                               }
+                                       }
+                                       if (*t == '\\') ++t;
+                                       *r++ = *t++;
+                               }
+                       }
+                       if (type) /* internal casting of element type */
+                       {
+                               int etype = type[i];
+
+                               if (etype & PYGRES_ARRAY)
+                                       element = cast_array(
+                                               estr, esize, encoding, etype, 
NULL, 0);
+                               else if (etype & PYGRES_TEXT)
+                                       element = cast_sized_text(estr, esize, 
encoding, etype);
+                               else
+                                       element = cast_sized_simple(estr, 
esize, etype);
+                       }
+                       else /* external casting of base type */
+                       {
+#if IS_PY3
+                               element = encoding == pg_encoding_ascii ? NULL :
+                                       get_decoded_string(estr, esize, 
encoding);
+                               if (!element) /* no decoding necessary or 
possible */
+#endif
+                               element = PyBytes_FromStringAndSize(estr, 
esize);
+                               if (element && cast)
+                               {
+                                       if (len)
+                                       {
+                                               PyObject *ecast = 
PySequence_GetItem(cast, i);
+
+                                               if (ecast)
+                                               {
+                                                       if (ecast != Py_None)
+                                                               element = 
PyObject_CallFunctionObjArgs(
+                                                                       ecast, 
element, NULL);
+                                               }
+                                               else
+                                               {
+                                                       Py_DECREF(element); 
element = NULL;
+                                               }
+                                       }
+                                       else
+                                               element = 
PyObject_CallFunctionObjArgs(
+                                                       cast, element, NULL);
+                               }
+                       }
+                       if (escaped) PyMem_Free(estr);
+                       if (!element)
+                       {
+                               Py_DECREF(result); return NULL;
+                       }
+               }
+               if (PyList_Append(result, element))
+               {
+                       Py_DECREF(element); Py_DECREF(result); return NULL;
+               }
+               Py_DECREF(element);
+               if (len) ++i;
+               if (*s != delim) break; /* no next record */
+               if (len && i >= len)
+               {
+                       PyErr_SetString(PyExc_ValueError, "Too many columns");
+                       Py_DECREF(result); return NULL;
+               }
+       }
+       if (s == end || *s != ')')
+       {
+               PyErr_SetString(PyExc_ValueError, "Unexpected end of record");
+               Py_DECREF(result); return NULL;
+       }
+       do ++s; while (s != end && *s == ' ');
+       if (s != end)
+       {
+               PyErr_SetString(PyExc_ValueError,
+                       "Unexpected characters after end of record");
+               Py_DECREF(result); return NULL;
+       }
+       if (len && i < len)
+       {
+               PyErr_SetString(PyExc_ValueError, "Too few columns");
+               Py_DECREF(result); return NULL;
+       }
+
+       ret = PyList_AsTuple(result);
+       Py_DECREF(result);
+       return ret;
+}
+
 /* internal wrapper for the notice receiver callback */
 static void
 notice_receiver(void *arg, const PGresult *res)
@@ -5177,11 +5380,10 @@
 pgCastArray(PyObject *self, PyObject *args, PyObject *dict)
 {
        static const char *kwlist[] = {"string", "cast", "delim", NULL};
-       PyObject   *string_obj, *cast_obj = NULL;
-       char       *string;
+       PyObject   *string_obj, *cast_obj = NULL, *ret;
+       char       *string, delim = ',';
        Py_ssize_t      size;
        int                     encoding;
-       char            delim = ',';
 
        if (!PyArg_ParseTupleAndKeywords(args, dict, "O|Oc",
                        (char **) kwlist, &string_obj, &cast_obj, &delim))
@@ -5207,14 +5409,89 @@
        }
 
        if (!cast_obj || cast_obj == Py_None)
-               cast_obj = NULL;
+       {
+               if (cast_obj)
+               {
+                       Py_DECREF(cast_obj); cast_obj = NULL;
+               }
+       }
        else if (!PyCallable_Check(cast_obj))
        {
                PyErr_SetString(PyExc_TypeError, "The cast argument must be 
callable");
                return NULL;
        }
 
-       return cast_array(string, size, encoding, 0, cast_obj, delim);
+       ret = cast_array(string, size, encoding, 0, cast_obj, delim);
+
+       Py_XDECREF(string_obj);
+
+       return ret;
+}
+
+/* cast a string with a text representation of a record to a tuple */
+static char pgCastRecord__doc__[] =
+"cast_record(string, cast=None, delim=',') -- cast a string as a record";
+
+PyObject *
+pgCastRecord(PyObject *self, PyObject *args, PyObject *dict)
+{
+       static const char *kwlist[] = {"string", "cast", "delim", NULL};
+       PyObject   *string_obj, *cast_obj = NULL, *ret;
+       char       *string, delim = ',';
+       Py_ssize_t      size, len;
+       int                     encoding;
+
+       if (!PyArg_ParseTupleAndKeywords(args, dict, "O|Oc",
+                       (char **) kwlist, &string_obj, &cast_obj, &delim))
+               return NULL;
+
+       if (PyBytes_Check(string_obj))
+       {
+               encoding = pg_encoding_ascii;
+               PyBytes_AsStringAndSize(string_obj, &string, &size);
+               string_obj = NULL;
+       }
+       else if (PyUnicode_Check(string_obj))
+       {
+               encoding = pg_encoding_utf8;
+               string_obj = get_encoded_string(string_obj, encoding);
+               if (!string_obj) return NULL; /* pass the UnicodeEncodeError */
+               PyBytes_AsStringAndSize(string_obj, &string, &size);
+       }
+       else
+       {
+               PyErr_SetString(PyExc_TypeError, "cast_record() expects a 
string");
+               return NULL;
+       }
+
+       if (!cast_obj || PyCallable_Check(cast_obj))
+       {
+               len = 0;
+       }
+       else if (cast_obj == Py_None)
+       {
+               Py_DECREF(cast_obj); cast_obj = NULL; len = 0;
+       }
+       else if (PyTuple_Check(cast_obj) || PyList_Check(cast_obj))
+       {
+               len = PySequence_Size(cast_obj);
+               if (!len)
+               {
+                       Py_DECREF(cast_obj); cast_obj = NULL;
+               }
+       }
+       else
+       {
+               PyErr_SetString(PyExc_TypeError,
+                       "The cast argument must be callable or a tuple or list 
of such");
+               return NULL;
+       }
+
+       ret = cast_record(string, size, encoding, 0, cast_obj, len, delim);
+
+       Py_XDECREF(string_obj);
+
+       return ret;
 }
 
 
@@ -5249,6 +5526,8 @@
                        pgSetJsondecode__doc__},
        {"cast_array", (PyCFunction) pgCastArray, METH_VARARGS|METH_KEYWORDS,
                        pgCastArray__doc__},
+       {"cast_record", (PyCFunction) pgCastRecord, METH_VARARGS|METH_KEYWORDS,
+                       pgCastRecord__doc__},
 
 #ifdef DEFAULT_VARS
        {"get_defhost", pgGetDefHost, METH_VARARGS, pgGetDefHost__doc__},

Modified: trunk/tests/test_classic_functions.py
==============================================================================
--- trunk/tests/test_classic_functions.py       Wed Jan 27 13:26:41 2016        
(r790)
+++ trunk/tests/test_classic_functions.py       Wed Jan 27 18:09:18 2016        
(r791)
@@ -116,7 +116,7 @@
 class TestParseArray(unittest.TestCase):
     """Test the array parser."""
 
-    array_expressions = [
+    test_strings = [
         ('', str, ValueError),
         ('{}', None, []),
         ('{}', str, []),
@@ -155,6 +155,9 @@
         (r'{"a\bc"}', str, ['abc']),
         (r'{\a\b\c}', str, ['abc']),
         (r'{"\a\b\c"}', str, ['abc']),
+        (r'{"a"b"}', str, ValueError),
+        (r'{"a""b"}', str, ValueError),
+        (r'{"a\"b"}', str, ['a"b']),
         ('{"{}"}', str, ['{}']),
         (r'{\{\}}', str, ['{}']),
         ('{"{a,b,c}"}', str, ['{a,b,c}']),
@@ -217,6 +220,9 @@
         self.assertRaises(TypeError, f, '{}', None, None)
         self.assertRaises(TypeError, f, '{}', None, 1)
         self.assertRaises(TypeError, f, '{}', None, '')
+        self.assertRaises(ValueError, f, '{}', None, '\\')
+        self.assertRaises(ValueError, f, '{}', None, '{')
+        self.assertRaises(ValueError, f, '{}', None, '}')
         self.assertRaises(TypeError, f, '{}', None, ',;')
         self.assertEqual(f('{}'), [])
         self.assertEqual(f('{}', None), [])
@@ -299,22 +305,22 @@
 
     def testParserWithData(self):
         f = pg.cast_array
-        for expression, cast, expected in self.array_expressions:
+        for string, cast, expected in self.test_strings:
             if expected is ValueError:
-                self.assertRaises(ValueError, f, expression, cast)
+                self.assertRaises(ValueError, f, string, cast)
             else:
-                self.assertEqual(f(expression, cast), expected)
+                self.assertEqual(f(string, cast), expected)
 
     def testParserWithoutCast(self):
         f = pg.cast_array
 
-        for expression, cast, expected in self.array_expressions:
+        for string, cast, expected in self.test_strings:
             if cast is not str:
                 continue
             if expected is ValueError:
-                self.assertRaises(ValueError, f, expression)
+                self.assertRaises(ValueError, f, string)
             else:
-                self.assertEqual(f(expression), expected)
+                self.assertEqual(f(string), expected)
 
     def testParserWithDifferentDelimiter(self):
         f = pg.cast_array
@@ -327,13 +333,285 @@
             else:
                 return value
 
-        for expression, cast, expected in self.array_expressions:
-            expression = replace_comma(expression)
+        for string, cast, expected in self.test_strings:
+            string = replace_comma(string)
             if expected is ValueError:
-                self.assertRaises(ValueError, f, expression, cast)
+                self.assertRaises(ValueError, f, string, cast)
             else:
                 expected = replace_comma(expected)
-                self.assertEqual(f(expression, cast, b';'), expected)
+                self.assertEqual(f(string, cast, b';'), expected)
+
+
+class TestParseRecord(unittest.TestCase):
+    """Test the record parser."""
+
+    test_strings = [
+        ('', None, ValueError),
+        ('', str, ValueError),
+        ('(', None, ValueError),
+        ('(', str, ValueError),
+        ('()', None, (None,)),
+        ('()', str, (None,)),
+        ('()', int, (None,)),
+        ('(,)', str, (None, None)),
+        ('( , )', str, (' ', ' ')),
+        ('(")', None, ValueError),
+        ('("")', None, ('',)),
+        ('("")', str, ('',)),
+        ('("")', int, ValueError),
+        ('("" )', None, (' ',)),
+        ('("" )', str, (' ',)),
+        ('("" )', int, ValueError),
+        ('    ()    ', None, (None,)),
+        ('   (   )   ', None, ('   ',)),
+        ('(', str, ValueError),
+        ('(()', str, ('(',)),
+        ('(())', str, ValueError),
+        ('()(', str, ValueError),
+        ('()()', str, ValueError),
+        ('[]', str, ValueError),
+        ('{}', str, ValueError),
+        ('([])', str, ('[]',)),
+        ('(hello)', int, ValueError),
+        ('(42)', int, (42,)),
+        ('( 42 )', int, (42,)),
+        ('(  42)', int, (42,)),
+        ('(42)', str, ('42',)),
+        ('( 42 )', str, (' 42 ',)),
+        ('(  42)', str, ('  42',)),
+        ('(42', int, ValueError),
+        ('( 42 ', int, ValueError),
+        ('(hello)', str, ('hello',)),
+        ('( hello )', str, (' hello ',)),
+        ('(hello))', str, ValueError),
+        ('   (hello)   ', str, ('hello',)),
+        ('   (hello)   )', str, ValueError),
+        ('(hello)?', str, ValueError),
+        ('(null)', str, ('null',)),
+        ('(null)', int, ValueError),
+        (' ( NULL ) ', str, (' NULL ',)),
+        ('   (   NULL   )   ', str, ('   NULL   ',)),
+        (' ( null null ) ', str, (' null null ',)),
+        (' ("null") ', str, ('null',)),
+        (' ("NULL") ', str, ('NULL',)),
+        ('(Hi!)', str, ('Hi!',)),
+        ('("Hi!")', str, ('Hi!',)),
+        ("('Hi!')", str, ("'Hi!'",)),
+        ('(" Hi! ")', str, (' Hi! ',)),
+        ('("Hi!" )', str, ('Hi! ',)),
+        ('( "Hi!")', str, (' Hi!',)),
+        ('( "Hi!" )', str, (' Hi! ',)),
+        ('( ""Hi!"" )', str, (' Hi! ',)),
+        ('( """Hi!""" )', str, (' "Hi!" ',)),
+        ('(a")', str, ValueError),
+        ('("b)', str, ValueError),
+        ('("a" "b)', str, ValueError),
+        ('("a" "b")', str, ('a b',)),
+        ('( "a" "b" "c" )', str, (' a b c ',)),
+        ('(  "a"  "b"  "c"  )', str, ('  a  b  c  ',)),
+        ('(  "a,b"  "c,d"  )', str, ('  a,b  c,d  ',)),
+        ('( "(a,b,c)" d, e, "f,g")', str, (' (a,b,c) d', ' e', ' f,g')),
+        ('(a",b,c",d,"e,f")', str, ('a,b,c', 'd', 'e,f')),
+        ('( """a,b""", ""c,d"", "e,f", "g", ""h"", """i""")', str,
+            (' "a,b"', ' c', 'd', ' e,f', ' g', ' h', ' "i"')),
+        ('(a",b)",c"),(d,e)",f,g)', str, ('a,b)', 'c),(d,e)', 'f', 'g')),
+        ('(a"b)', str, ValueError),
+        (r'(a\"b)', str, ('a"b',)),
+        ('(a""b)', str, ('ab',)),
+        ('("a""b")', str, ('a"b',)),
+        (r'(a\,b)', str, ('a,b',)),
+        (r'(a\bc)', str, ('abc',)),
+        (r'("a\bc")', str, ('abc',)),
+        (r'(\a\b\c)', str, ('abc',)),
+        (r'("\a\b\c")', str, ('abc',)),
+        ('("()")', str, ('()',)),
+        (r'(\,)', str, (',',)),
+        (r'(\(\))', str, ('()',)),
+        (r'(\)\()', str, (')(',)),
+        ('("(a,b,c)")', str, ('(a,b,c)',)),
+        ("('abc')", str, ("'abc'",)),
+        ('("abc")', str, ('abc',)),
+        (r'(\"abc\")', str, ('"abc"',)),
+        (r"(\'abc\')", str, ("'abc'",)),
+        ('(Hello World!)', str, ('Hello World!',)),
+        ('(Hello, World!)', str, ('Hello', ' World!',)),
+        (r'(Hello,\ World!)', str, ('Hello', ' World!',)),
+        (r'(Hello\, World!)', str, ('Hello, World!',)),
+        ('("Hello World!")', str, ('Hello World!',)),
+        ("(this,shouldn't,be,null)", str, ('this', "shouldn't", 'be', 'null')),
+        ('(null,should,be,)', str, ('null', 'should', 'be', None)),
+        ('(abcABC0123!?+-*/=&%$\\\\\'\\"{[]}"""":;\\,,)', str,
+            ('abcABC0123!?+-*/=&%$\\\'"{[]}":;,', None)),
+        ('(3, 2, 1,)', int, (3, 2, 1, None)),
+        ('(3, 2, 1, )', int, ValueError),
+        ('(, 1, 2, 3)', int, (None, 1, 2, 3)),
+        ('( , 1, 2, 3)', int, ValueError),
+        ('(,1,,2,,3,)', int, (None, 1, None, 2, None, 3, None)),
+        ('(3,17,51)', int, (3, 17, 51)),
+        (' ( 3 , 17 , 51 ) ', int, (3, 17, 51)),
+        ('(3,17,51)', str, ('3', '17', '51')),
+        (' ( 3 , 17 , 51 ) ', str, (' 3 ', ' 17 ', ' 51 ')),
+        ('(1,"2",abc,"def")', str, ('1', '2', 'abc', 'def')),
+        ('(())', str, ValueError),
+        ('()))', str, ValueError),
+        ('()()', str, ValueError),
+        ('((()', str, ('((',)),
+        ('(())', int, ValueError),
+        ('((),())', str, ValueError),
+        ('("()","()")', str, ('()', '()')),
+        ('( " () , () , () " )', str, ('  () , () , ()  ',)),
+        ('(20000, 25000, 25000, 25000)', int, (20000, 25000, 25000, 25000)),
+        ('("breakfast","consulting","meeting","lunch")', str,
+            ('breakfast', 'consulting', 'meeting', 'lunch')),
+        ('("breakfast","consulting","meeting","lunch")',
+            (str, str, str), ValueError),
+        ('("breakfast","consulting","meeting","lunch")', (str, str, str, str),
+            ('breakfast', 'consulting', 'meeting', 'lunch')),
+        ('("breakfast","consulting","meeting","lunch")',
+            (str, str, str, str, str), ValueError),
+        ('("fuzzy dice",42,1.9375)', None, ('fuzzy dice', '42', '1.9375')),
+        ('("fuzzy dice",42,1.9375)', str, ('fuzzy dice', '42', '1.9375')),
+        ('("fuzzy dice",42,1.9375)', int, ValueError),
+        ('("fuzzy dice",42,1.9375)', (str, int, float),
+            ('fuzzy dice', 42, 1.9375)),
+        ('("fuzzy dice",42,1.9375)', (str, int), ValueError),
+        ('("fuzzy dice",42,1.9375)', (str, int, float, str), ValueError),
+        ('("fuzzy dice",42,)', (str, int, float), ('fuzzy dice', 42, None)),
+        ('("fuzzy dice",42,)', (str, int), ValueError),
+        ('("",42,)', (str, int, float), ('', 42, None)),
+        ('("fuzzy dice","",1.9375)', (str, int, float), ValueError),
+        ('(fuzzy dice,"42","1.9375")', (str, int, float),
+            ('fuzzy dice', 42, 1.9375))]
+
+    def testParserParams(self):
+        f = pg.cast_record
+        self.assertRaises(TypeError, f)
+        self.assertRaises(TypeError, f, None)
+        self.assertRaises(TypeError, f, '()', 1)
+        self.assertRaises(TypeError, f, '()', ',',)
+        self.assertRaises(TypeError, f, '()', None, None)
+        self.assertRaises(TypeError, f, '()', None, 1)
+        self.assertRaises(TypeError, f, '()', None, '')
+        self.assertRaises(ValueError, f, '()', None, '\\')
+        self.assertRaises(ValueError, f, '()', None, '(')
+        self.assertRaises(ValueError, f, '()', None, ')')
+        self.assertRaises(TypeError, f, '()', None, ',;')
+        self.assertEqual(f('()'), (None,))
+        self.assertEqual(f('()', None), (None,))
+        self.assertEqual(f('()', None, b';'), (None,))
+        self.assertEqual(f('()', str), (None,))
+        self.assertEqual(f('()', str, b';'), (None,))
+
+    def testParserSimple(self):
+        r = pg.cast_record('(a,b,c)')
+        self.assertIsInstance(r, tuple)
+        self.assertEqual(len(r), 3)
+        self.assertEqual(r, ('a', 'b', 'c'))
+
+    def testParserNested(self):
+        f = pg.cast_record
+        self.assertRaises(ValueError, f, '((a,b,c))')
+        self.assertRaises(ValueError, f, '((a,b),(c,d))')
+        self.assertRaises(ValueError, f, '((a),(b),(c))')
+        self.assertRaises(ValueError, f, '(((((((abc)))))))')
+
+    def testParserManyElements(self):
+        f = pg.cast_record
+        for n in 3, 5, 9, 12, 16, 32, 64, 256:
+            r = '(%s)' % ','.join(map(str, range(n)))
+            r = f(r, int)
+            self.assertEqual(r, tuple(range(n)))
+
+    def testParserCastUniform(self):
+        f = pg.cast_record
+        self.assertEqual(f('(1)'), ('1',))
+        self.assertEqual(f('(1)', None), ('1',))
+        self.assertEqual(f('(1)', int), (1,))
+        self.assertEqual(f('(1)', str), ('1',))
+        self.assertEqual(f('(a)'), ('a',))
+        self.assertEqual(f('(a)', None), ('a',))
+        self.assertRaises(ValueError, f, '(a)', int)
+        self.assertEqual(f('(a)', str), ('a',))
+        cast = lambda s: '%s is ok' % s
+        self.assertEqual(f('(a)', cast), ('a is ok',))
+
+    def testParserCastNonUniform(self):
+        f = pg.cast_record
+        self.assertEqual(f('(1)', []), ('1',))
+        self.assertEqual(f('(1)', [None]), ('1',))
+        self.assertEqual(f('(1)', [str]), ('1',))
+        self.assertEqual(f('(1)', [int]), (1,))
+        self.assertRaises(ValueError, f, '(1)', [None, None])
+        self.assertRaises(ValueError, f, '(1)', [str, str])
+        self.assertRaises(ValueError, f, '(1)', [int, int])
+        self.assertEqual(f('(a)', [None]), ('a',))
+        self.assertEqual(f('(a)', [str]), ('a',))
+        self.assertRaises(ValueError, f, '(a)', [int])
+        self.assertEqual(f('(1,a)', [int, str]), (1, 'a'))
+        self.assertRaises(ValueError, f, '(1,a)', [str, int])
+        self.assertEqual(f('(a,1)', [str, int]), ('a', 1))
+        self.assertRaises(ValueError, f, '(a,1)', [int, str])
+        self.assertEqual(f('(1,a,2,b,3,c)',
+            [int, str, int, str, int, str]), (1, 'a', 2, 'b', 3, 'c'))
+        self.assertEqual(f('(1,a,2,b,3,c)',
+            (int, str, int, str, int, str)), (1, 'a', 2, 'b', 3, 'c'))
+        cast1 = lambda s: '%s is ok' % s
+        self.assertEqual(f('(a)', [cast1]), ('a is ok',))
+        cast2 = lambda s: 'and %s is ok, too' % s
+        self.assertEqual(f('(a,b)', [cast1, cast2]),
+            ('a is ok', 'and b is ok, too'))
+        self.assertRaises(ValueError, f, '(a)', [cast1, cast2])
+        self.assertRaises(ValueError, f, '(a,b,c)', [cast1, cast2])
+        self.assertEqual(f('(1,2,3,4,5,6)',
+            [int, float, str, None, cast1, cast2]),
+            (1, 2.0, '3', '4', '5 is ok', 'and 6 is ok, too'))
+
+    def testParserDelim(self):
+        f = pg.cast_record
+        self.assertEqual(f('(1,2)'), ('1', '2'))
+        self.assertEqual(f('(1,2)', delim=b','), ('1', '2'))
+        self.assertEqual(f('(1;2)'), ('1;2',))
+        self.assertEqual(f('(1;2)', delim=b';'), ('1', '2'))
+        self.assertEqual(f('(1,2)', delim=b';'), ('1,2',))
+
+    def testParserWithData(self):
+        f = pg.cast_record
+        for string, cast, expected in self.test_strings:
+            if expected is ValueError:
+                self.assertRaises(ValueError, f, string, cast)
+            else:
+                self.assertEqual(f(string, cast), expected)
+
+    def testParserWithoutCast(self):
+        f = pg.cast_record
+
+        for string, cast, expected in self.test_strings:
+            if cast is not str:
+                continue
+            if expected is ValueError:
+                self.assertRaises(ValueError, f, string)
+            else:
+                self.assertEqual(f(string), expected)
+
+    def testParserWithDifferentDelimiter(self):
+        f = pg.cast_record
+
+        def replace_comma(value):
+            if isinstance(value, str):
+                return value.replace(';', '@').replace(
+                    ',', ';').replace('@', ',')
+            elif isinstance(value, tuple):
+                return tuple(replace_comma(v) for v in value)
+            else:
+                return value
+
+        for string, cast, expected in self.test_strings:
+            string = replace_comma(string)
+            if expected is ValueError:
+                self.assertRaises(ValueError, f, string, cast)
+            else:
+                expected = replace_comma(expected)
+                self.assertEqual(f(string, cast, b';'), expected)
 
 
 class TestEscapeFunctions(unittest.TestCase):

Modified: trunk/tests/test_dbapi20.py
==============================================================================
--- trunk/tests/test_dbapi20.py Wed Jan 27 13:26:41 2016        (r790)
+++ trunk/tests/test_dbapi20.py Wed Jan 27 18:09:18 2016        (r791)
@@ -293,15 +293,19 @@
     def test_type_cache(self):
         con = self._connect()
         cur = con.cursor()
-        type_cache = cur.type_cache
+        type_cache = con.type_cache
+        self.assertNotIn('numeric', type_cache)
         type_info = type_cache['numeric']
+        self.assertIn('numeric', type_cache)
         self.assertEqual(type_info.oid, 1700)
         self.assertEqual(type_info.name, 'numeric')
         self.assertEqual(type_info.type, 'b')  # base
         self.assertEqual(type_info.category, 'N')  # numeric
         self.assertEqual(type_info.delim, ',')
-        self.assertIs(cur.type_cache[1700], type_info)
+        self.assertIs(con.type_cache[1700], type_info)
+        self.assertNotIn('pg_type', type_cache)
         type_info = type_cache['pg_type']
+        self.assertIn('pg_type', type_cache)
         self.assertEqual(type_info.type, 'c')  # composite
         self.assertEqual(type_info.category, 'C')  # composite
         cols = type_cache.columns('pg_type')
@@ -315,6 +319,22 @@
         self.assertEqual(typlen.name, 'int2')
         self.assertEqual(typlen.type, 'b')  # base
         self.assertEqual(typlen.category, 'N')  # numeric
+        cur.close()
+        cur = con.cursor()
+        type_cache = con.type_cache
+        self.assertIn('numeric', type_cache)
+        cur.close()
+        con.close()
+        con = self._connect()
+        cur = con.cursor()
+        type_cache = con.type_cache
+        self.assertNotIn('pg_type', type_cache)
+        self.assertEqual(type_cache.get('pg_type'), type_info)
+        self.assertIn('pg_type', type_cache)
+        self.assertIsNone(type_cache.get(
+            self.table_prefix + '_surely_does_not_exist'))
+        cur.close()
+        con.close()
 
     def test_cursor_iteration(self):
         con = self._connect()
@@ -450,7 +470,7 @@
                 inval = inval.strftime('%Y-%m-%d %H:%M:%S')
             self.assertEqual(inval, outval)
 
-    def test_roundtrip_with_list(self):
+    def test_insert_array(self):
         values = [(None, None), ([], []), ([None], [[None], ['null']]),
             ([1, 2, 3], [['a', 'b'], ['c', 'd']]),
             ([20000, 25000, 25000, 30000],
@@ -463,7 +483,8 @@
             cur.execute("create table %s"
                 " (n smallint, i int[], t text[][])" % table)
             params = [(n, v[0], v[1]) for n, v in enumerate(values)]
-            cur.execute("insert into %s values (%%d,%%s,%%s)" % table, params)
+            cur.executemany(
+                "insert into %s values (%%d,%%s,%%s)" % table, params)
             cur.execute("select i, t from %s order by n" % table)
             self.assertEqual(cur.description[0].type_code, pgdb.ARRAY)
             self.assertEqual(cur.description[0].type_code, pgdb.NUMBER)
@@ -475,17 +496,70 @@
             con.close()
         self.assertEqual(rows, values)
 
-    def test_tuple_binds_as_row(self):
-        values = [(1, 2.5, 'this is a test')]
-        output = '(1,2.5,"this is a test")'
+    def test_select_array(self):
+        values = ([1, 2, 3, None], ['a', 'b', 'c', None])
         con = self._connect()
         try:
             cur = con.cursor()
-            cur.execute("select %s", values)
-            outval = cur.fetchone()[0]
+            cur.execute("select %s::int[], %s::text[]", values)
+            row = cur.fetchone()
+        finally:
+            con.close()
+        self.assertEqual(row, values)
+
+    def test_insert_record(self):
+        values = [('John', 61), ('Jane', 63),
+                  ('Fred', None), ('Wilma', None),
+                  (None, 42), (None, None)]
+        table = self.table_prefix + 'booze'
+        record = self.table_prefix + 'munch'
+        con = self._connect()
+        try:
+            cur = con.cursor()
+            cur.execute("create type %s as (name varchar, age int)" % record)
+            cur.execute("create table %s (n smallint, r %s)" % (table, record))
+            params = enumerate(values)
+            cur.executemany("insert into %s values (%%d,%%s)" % table, params)
+            cur.execute("select r from %s order by n" % table)
+            type_code = cur.description[0].type_code
+            self.assertEqual(type_code, record)
+            columns = con.type_cache.columns(type_code)
+            self.assertEqual(columns[0].name, 'name')
+            self.assertEqual(columns[1].name, 'age')
+            self.assertEqual(con.type_cache[columns[0].type].name, 'varchar')
+            self.assertEqual(con.type_cache[columns[1].type].name, 'int4')
+            rows = cur.fetchall()
         finally:
+            cur.execute('drop table %s' % table)
+            cur.execute('drop type %s' % record)
             con.close()
-        self.assertEqual(outval, output)
+        self.assertEqual(len(rows), len(values))
+        rows = [row[0] for row in rows]
+        self.assertEqual(rows, values)
+        self.assertEqual(rows[0].name, 'John')
+        self.assertEqual(rows[0].age, 61)
+
+    def test_select_record(self):
+        values = (1, 25000, 2.5, 'hello', 'Hello World!', 'Hello, World!',
+            '(test)', '(x,y)', ' x y ', 'null', None)
+        con = self._connect()
+        try:
+            cur = con.cursor()
+            # Note that %s::record does not work on input unfortunately
+            # ("input of anonymous composite types is not implemented"),
+            # so we need to resort to a row constructor instead.
+            row = ','.join(["%s"] * len(values))
+            cur.execute("select ROW(%s) as test_record" % row, values)
+            self.assertEqual(cur.description[0].name, 'test_record')
+            self.assertEqual(cur.description[0].type_code, 'record')
+            row = cur.fetchone()[0]
+        finally:
+            con.close()
+        # Note that the element types get lost since we created an
+        # untyped record (an anonymous composite type). For the same
+        # reason this is also a normal tuple, not a named tuple.
+        text_values = tuple(None if v is None else str(v) for v in values)
+        self.assertEqual(row, text_values)
 
     def test_custom_type(self):
         values = [3, 5, 65]
_______________________________________________
PyGreSQL mailing list
[email protected]
https://mail.vex.net/mailman/listinfo.cgi/pygresql

Reply via email to