Author: cito
Date: Wed Jan 20 17:44:20 2016
New Revision: 772

Log:
Improve test coverage for the pgdb module

Includes a simple patch that allows storing Python list or tuple values
in PostgreSQL array fields (they are not yet converted back when read, though).

Also re-activated the shortcut methods on the connection,
since they can sometimes be useful.

Test coverage is now around 95%; the remaining uncovered lines handle support
for old Python versions or obscure database errors that are hard to provoke in tests.

Modified:
   trunk/docs/contents/changelog.rst
   trunk/pgdb.py
   trunk/tests/test_dbapi20.py
   trunk/tests/test_dbapi20_copy.py

Modified: trunk/docs/contents/changelog.rst
==============================================================================
--- trunk/docs/contents/changelog.rst   Wed Jan 20 13:22:10 2016        (r771)
+++ trunk/docs/contents/changelog.rst   Wed Jan 20 17:44:20 2016        (r772)
@@ -25,6 +25,9 @@
   are now named tuples, i.e. their elements can be also accessed by name.
   The column names and types can now also be requested through the
   colnames and coltypes attributes, which are not part of DB-API 2 though.
+- Re-activated the shortcut methods of the DB-API connection since they
+  can be handy when doing experiments or writing quick scripts. We keep
+  them undocumented though and discourage using them in production.
 - The tty parameter and attribute of database connections has been
   removed since it is not supported any more since PostgreSQL 7.4.
 - The pkey() method of the classic interface now returns tuples instead

Modified: trunk/pgdb.py
==============================================================================
--- trunk/pgdb.py       Wed Jan 20 13:22:10 2016        (r771)
+++ trunk/pgdb.py       Wed Jan 20 17:44:20 2016        (r772)
@@ -112,11 +112,9 @@
 # this module use extended python format codes
 paramstyle = 'pyformat'
 
-# shortcut methods are not supported by default
-# since they have been excluded from DB API 2
-# and are not recommended by the DB SIG.
-
-shortcutmethods = 0
+# shortcut methods have been excluded from DB API 2 and
+# are not recommended by the DB SIG, but they can be handy
+shortcutmethods = 1
 
 
 ### Internal Types Handling
@@ -144,16 +142,7 @@
 
 
 def _cast_float(value):
-    try:
-        return float(value)
-    except ValueError:
-        if value == 'NaN':
-            return nan
-        elif value == 'Infinity':
-            return inf
-        elif value == '-Infinity':
-            return -inf
-        raise
+    return float(value)  # this also works with NaN and Infinity
 
 
 _cast = {'bool': _cast_bool, 'bytea': _cast_bytea,
@@ -280,7 +269,8 @@
         elif val is None:
             val = 'NULL'
         elif isinstance(val, (list, tuple)):
-            val = '(%s)' % ','.join(map(lambda v: str(self._quote(v)), val))
+            q = self._quote
+            val = 'ARRAY[%s]' % ','.join(str(q(v)) for v in val)
         elif Decimal is not float and isinstance(val, Decimal):
             pass
         elif hasattr(val, '__pg_repr__'):
@@ -339,8 +329,8 @@
                 try:
                     self._cnx.source().execute(sql)
                 except DatabaseError:
-                    raise
-                except Exception:
+                    raise  # database provides error message
+                except Exception as err:
                     raise _op_error("can't start transaction")
                 self._dbcnx._tnx = True
             for parameters in seq_of_parameters:
@@ -354,9 +344,10 @@
                 else:
                     self.rowcount = -1
         except DatabaseError:
-            raise
+            raise  # database provides error message
         except Error as err:
-            raise _db_error("error in '%s': '%s' " % (sql, err))
+            raise _db_error(
+                "error in '%s': '%s' " % (sql, err), InterfaceError)
         except Exception as err:
             raise _op_error("internal error in '%s': %s" % (sql, err))
         # then initialize result raw count and description
@@ -493,9 +484,9 @@
         else:
             if size is None:
                 size = 8192
+            elif not isinstance(size, int):
+                raise TypeError("The size option must be an integer")
             if size > 0:
-                if not isinstance(size, int):
-                    raise TypeError("The size option must be an integer")
 
                 def chunks():
                     while True:

Modified: trunk/tests/test_dbapi20.py
==============================================================================
--- trunk/tests/test_dbapi20.py Wed Jan 20 13:22:10 2016        (r771)
+++ trunk/tests/test_dbapi20.py Wed Jan 20 17:44:20 2016        (r772)
@@ -28,6 +28,8 @@
     except ImportError:
         pass
 
+from datetime import datetime
+
 try:
     long
 except NameError:  # Python >= 3.0
@@ -39,6 +41,16 @@
     OrderedDict = None
 
 
+class PgBitString:
+    """Test object with a PostgreSQL representation as Bit String."""
+
+    def __init__(self, value):
+        self.value = value
+
+    def __pg_repr__(self):
+         return "B'{0:b}'".format(self.value)
+
+
 class test_PyGreSQL(dbapi20.DatabaseAPI20Test):
 
     driver = pgdb
@@ -340,7 +352,8 @@
         from math import isnan, isinf
         self.assertTrue(isnan(nan) and not isinf(nan))
         self.assertTrue(isinf(inf) and not isnan(inf))
-        values = [0, 1, 0.03125, -42.53125, nan, inf, -inf]
+        values = [0, 1, 0.03125, -42.53125, nan, inf, -inf,
+            'nan', 'inf', '-inf', 'NaN', 'Infinity', '-Infinity']
         table = self.table_prefix + 'booze'
         con = self._connect()
         try:
@@ -356,6 +369,12 @@
         self.assertEqual(len(rows), len(values))
         rows = [row[1] for row in rows]
         for inval, outval in zip(values, rows):
+            if inval in ('inf', 'Infinity'):
+                inval = inf
+            elif inval in ('-inf', '-Infinity'):
+                inval = -inf
+            elif inval in ('nan', 'NaN'):
+                inval = nan
             if isinf(inval):
                 self.assertTrue(isinf(outval))
                 if inval < 0:
@@ -367,6 +386,70 @@
             else:
                 self.assertEqual(inval, outval)
 
+    def test_datetime(self):
+        values = ['2011-07-17 15:47:42', datetime(2016, 1, 20, 20, 15, 51)]
+        table = self.table_prefix + 'booze'
+        con = self._connect()
+        try:
+            cur = con.cursor()
+            cur.execute("set datestyle to 'iso'")
+            cur.execute(
+                "create table %s (n smallint, ts timestamp)" % table)
+            params = enumerate(values)
+            cur.executemany("insert into %s values (%%s,%%s)" % table, params)
+            cur.execute("select * from %s order by 1" % table)
+            rows = cur.fetchall()
+        finally:
+            con.close()
+        self.assertEqual(len(rows), len(values))
+        rows = [row[1] for row in rows]
+        for inval, outval in zip(values, rows):
+            if isinstance(inval, datetime):
+                inval = inval.strftime('%Y-%m-%d %H:%M:%S')
+            self.assertEqual(inval, outval)
+
+    def test_array(self):
+        values = ([20000, 25000, 25000, 30000],
+            [['breakfast', 'consulting'], ['meeting', 'lunch']])
+        output = ('{20000,25000,25000,30000}',
+            '{{breakfast,consulting},{meeting,lunch}}')
+        table = self.table_prefix + 'booze'
+        con = self._connect()
+        try:
+            cur = con.cursor()
+            cur.execute("create table %s (i int[], t text[][])" % table)
+            cur.execute("insert into %s values (%%s,%%s)" % table, values)
+            cur.execute("select * from %s" % table)
+            row = cur.fetchone()
+        finally:
+            con.close()
+        self.assertEqual(row, output)
+
+    def test_custom_type(self):
+        values = [3, 5, 65]
+        values = list(map(PgBitString, values))
+        table = self.table_prefix + 'booze'
+        con = self._connect()
+        try:
+            cur = con.cursor()
+            params = enumerate(values)  # params have __pg_repr__ method
+            cur.execute(
+                'create table "%s" (n smallint, b bit varying(7))' % table)
+            cur.executemany("insert into %s values (%%s,%%s)" % table, params)
+            cur.execute("select * from %s order by 1" % table)
+            rows = cur.fetchall()
+        finally:
+            con.close()
+        self.assertEqual(len(rows), len(values))
+        con = self._connect()
+        try:
+            cur = con.cursor()
+            params = (1, object())  # an object that cannot be handled
+            self.assertRaises(pgdb.InterfaceError, cur.execute,
+                "insert into %s values (%%s,%%s)" % table, params)
+        finally:
+            con.close()
+
     def test_set_decimal_type(self):
         decimal_type = pgdb.decimal_type()
         self.assertTrue(decimal_type is not None and callable(decimal_type))
@@ -473,6 +556,54 @@
         values[4] = values[6] = False
         self.assertEqual(rows, values)
 
+    def test_execute_edge_cases(self):
+        con = self._connect()
+        try:
+            cur = con.cursor()
+            sql = 'invalid'  # should be ignored with empty parameter list
+            cur.executemany(sql, [])
+            sql = 'select %d + 1'
+            cur.execute(sql, [(1,)])  # deprecated use of execute()
+            self.assertEqual(cur.fetchone()[0], 2)
+            sql = 'select 1/0'  # cannot be executed
+            self.assertRaises(pgdb.ProgrammingError, cur.execute, sql)
+            cur.close()
+            con.rollback()
+            if pgdb.shortcutmethods:
+                res = con.execute('select %d', (1,)).fetchone()
+                self.assertEqual(res, (1,))
+                res = con.executemany('select %d', [(1,), (2,)]).fetchone()
+                self.assertEqual(res, (2,))
+        finally:
+            con.close()
+        sql = 'select 1'  # cannot be executed after connection is closed
+        self.assertRaises(pgdb.OperationalError, cur.execute, sql)
+
+    def test_fetchmany_with_keep(self):
+        con = self._connect()
+        try:
+            cur = con.cursor()
+            self.assertEqual(cur.arraysize, 1)
+            cur.execute('select * from generate_series(1, 25)')
+            self.assertEqual(len(cur.fetchmany()), 1)
+            self.assertEqual(len(cur.fetchmany()), 1)
+            self.assertEqual(cur.arraysize, 1)
+            cur.arraysize = 3
+            self.assertEqual(len(cur.fetchmany()), 3)
+            self.assertEqual(len(cur.fetchmany()), 3)
+            self.assertEqual(cur.arraysize, 3)
+            self.assertEqual(len(cur.fetchmany(size=2)), 2)
+            self.assertEqual(cur.arraysize, 3)
+            self.assertEqual(len(cur.fetchmany()), 3)
+            self.assertEqual(len(cur.fetchmany()), 3)
+            self.assertEqual(len(cur.fetchmany(size=2, keep=True)), 2)
+            self.assertEqual(cur.arraysize, 2)
+            self.assertEqual(len(cur.fetchmany()), 2)
+            self.assertEqual(len(cur.fetchmany()), 2)
+            self.assertEqual(len(cur.fetchmany(25)), 3)
+        finally:
+            con.close()
+
     def test_nextset(self):
         con = self._connect()
         cur = con.cursor()

Modified: trunk/tests/test_dbapi20_copy.py
==============================================================================
--- trunk/tests/test_dbapi20_copy.py    Wed Jan 20 13:22:10 2016        (r771)
+++ trunk/tests/test_dbapi20_copy.py    Wed Jan 20 17:44:20 2016        (r772)
@@ -50,7 +50,7 @@
 
     def __str__(self):
         data = self.data
-        if str is unicode:
+        if str is unicode:  # Python >= 3.0
             data = data.decode('utf-8')
         return data
 
@@ -75,7 +75,7 @@
 
     def __str__(self):
         data = self.data
-        if str is unicode:
+        if str is unicode:  # Python >= 3.0
             data = data.decode('utf-8')
         return data
 
@@ -220,7 +220,10 @@
         call('1\t', 'copytest',
              format='text', sep='\t', null='', columns=['id', 'name'])
         self.assertRaises(TypeError, call)
+        self.assertRaises(TypeError, call, None)
+        self.assertRaises(TypeError, call, None, None)
         self.assertRaises(TypeError, call, '0\t')
+        self.assertRaises(TypeError, call, '0\t', None)
         self.assertRaises(TypeError, call, '0\t', 42)
         self.assertRaises(TypeError, call, '0\t', ['copytest'])
         self.assertRaises(TypeError, call, '0\t', 'copytest', format=42)
@@ -230,6 +233,8 @@
         self.assertRaises(TypeError, call, '0\t', 'copytest', null=42)
         self.assertRaises(ValueError, call, '0\t', 'copytest', size='bad')
         self.assertRaises(TypeError, call, '0\t', 'copytest', columns=42)
+        self.assertRaises(ValueError, call, b'', 'copytest',
+            format='binary', sep=',')
 
     def test_input_string(self):
         ret = self.copy_from('42\tHello, world!')
@@ -248,7 +253,7 @@
         self.check_table()
         self.check_rowcount()
 
-    if str is unicode:
+    if str is unicode:  # Python >= 3.0
 
         def test_input_bytes(self):
             self.copy_from(b'42\tHello, world!')
@@ -257,7 +262,7 @@
             self.copy_from(self.data_text.encode('utf-8'))
             self.check_table()
 
-    if str is not unicode:
+    else:  # Python < 3.0
 
         def test_input_unicode(self):
             self.copy_from(u'43\tWürstel, Käse!')
@@ -271,10 +276,20 @@
         self.check_table()
         self.check_rowcount()
 
+    def test_input_iterable_invalid(self):
+        self.assertRaises(IOError, self.copy_from, [None])
+
     def test_input_iterable_with_newlines(self):
         self.copy_from('%s\n' % row for row in self.data_text.splitlines())
         self.check_table()
 
+    if str is unicode:  # Python >= 3.0
+
+        def test_input_iterable_bytes(self):
+            self.copy_from(row.encode('utf-8')
+                for row in self.data_text.splitlines())
+            self.check_table()
+
     def test_sep(self):
         stream = ('%d-%s' % row for row in self.data)
         self.copy_from(stream, sep='-')
@@ -366,6 +381,10 @@
         self.assertEqual(stream.sizes, [None])
         self.check_rowcount()
 
+    def test_size_invalid(self):
+        self.assertRaises(TypeError,
+            self.copy_from, self.data_file, size='invalid')
+
 
 class TestCopyTo(TestCopy):
     """Test the copy_to method."""
@@ -415,7 +434,7 @@
         self.assertEqual(rows, self.data_text)
         self.check_rowcount()
 
-    if str is unicode:
+    if str is unicode:  # Python >= 3.0
 
         def test_generator_bytes(self):
             ret = self.copy_to(decode=False)
@@ -426,7 +445,7 @@
             self.assertIsInstance(rows, bytes)
             self.assertEqual(rows, self.data_text.encode('utf-8'))
 
-    if str is not unicode:
+    else:  # Python < 3.0
 
         def test_generator_unicode(self):
             ret = self.copy_to(decode=True)
@@ -516,6 +535,8 @@
             format='binary', decode=True)
 
     def test_query(self):
+        self.assertRaises(ValueError, self.cursor.copy_to, None,
+            "select name from copytest", columns='noname')
         ret = self.cursor.copy_to(None,
             "select name||'!' from copytest where id=1941")
         self.assertIsInstance(ret, Iterable)
@@ -531,7 +552,7 @@
         self.assertIs(ret, self.cursor)
         self.assertEqual(str(stream), self.data_text)
         data = self.data_text
-        if str is unicode:
+        if str is unicode:  # Python >= 3.0
             data = data.encode('utf-8')
         sizes = [len(row) + 1 for row in data.splitlines()]
         self.assertEqual(stream.sizes, sizes)
_______________________________________________
PyGreSQL mailing list
[email protected]
https://mail.vex.net/mailman/listinfo.cgi/pygresql

Reply via email to