Author: cito
Date: Fri Feb  5 16:35:34 2016
New Revision: 823

Log:
Raise the proper subclasses of DatabaseError

Particularly, we raise IntegrityError instead of ProgrammingError for
duplicate keys. This also makes PyGreSQL more useable with SQLAlchemy.

Modified:
   trunk/docs/contents/changelog.rst
   trunk/pgdb.py
   trunk/pgmodule.c
   trunk/tests/test_classic.py
   trunk/tests/test_classic_connection.py
   trunk/tests/test_classic_dbwrapper.py
   trunk/tests/test_dbapi20.py
   trunk/tests/test_dbapi20_copy.py

Modified: trunk/docs/contents/changelog.rst
==============================================================================
--- trunk/docs/contents/changelog.rst   Fri Feb  5 12:05:45 2016        (r822)
+++ trunk/docs/contents/changelog.rst   Fri Feb  5 16:35:34 2016        (r823)
@@ -93,6 +93,9 @@
       also supported, but yield only an ordinary tuple containing text strings.
     - A new type helper Interval() has been added.
 - Changes concerning both modules:
+    - PyGreSQL now tries to raise more specific and appropriate subclasses of
+      DatabaseError than just ProgrammingError. Particularly, when database
+      constraints are violated, it raises an IntegrityError now.
     - The modules now provide get_typecast() and set_typecast() methods
       allowing to control the typecasting on the global level.  The connection
       objects have got type caches with the same methods which give control

Modified: trunk/pgdb.py
==============================================================================
--- trunk/pgdb.py       Fri Feb  5 12:05:45 2016        (r822)
+++ trunk/pgdb.py       Fri Feb  5 16:35:34 2016        (r823)
@@ -864,7 +864,7 @@
                     self._cnx.source().execute(sql)
                 except DatabaseError:
                     raise  # database provides error message
-                except Exception as err:
+                except Exception:
                     raise _op_error("Can't start transaction")
                 self._dbcnx._tnx = True
             for parameters in seq_of_parameters:

Modified: trunk/pgmodule.c
==============================================================================
--- trunk/pgmodule.c    Fri Feb  5 12:05:45 2016        (r822)
+++ trunk/pgmodule.c    Fri Feb  5 16:35:34 2016        (r823)
@@ -1296,42 +1296,111 @@
        PyGILState_Release(gstate);
 }
 
+/* gets appropriate error type from sqlstate */
+static PyObject *
+get_error_type(const char *sqlstate)
+{
+       switch (sqlstate[0]) {
+               case '0':
+                       switch (sqlstate[1])
+                       {
+                               case 'A':
+                                       return NotSupportedError;
+                       }
+                       break;
+               case '2':
+                       switch (sqlstate[1])
+                       {
+                               case '0':
+                               case '1':
+                                       return ProgrammingError;
+                               case '2':
+                                       return DataError;
+                               case '3':
+                                       return IntegrityError;
+                               case '4':
+                               case '5':
+                                       return InternalError;
+                               case '6':
+                               case '7':
+                               case '8':
+                                       return OperationalError;
+                               case 'B':
+                               case 'D':
+                               case 'F':
+                                       return InternalError;
+                       }
+                       break;
+               case '3':
+                       switch (sqlstate[1])
+                       {
+                               case '4':
+                                       return OperationalError;
+                               case '8':
+                               case '9':
+                               case 'B':
+                                       return InternalError;
+                               case 'D':
+                               case 'F':
+                                       return ProgrammingError;
+                       }
+                       break;
+               case '4':
+                       switch (sqlstate[1])
+                       {
+                               case '0':
+                                       return OperationalError;
+                               case '2':
+                               case '4':
+                                       return ProgrammingError;
+                       }
+                       break;
+               case '5':
+               case 'H':
+                       return OperationalError;
+               case 'F':
+               case 'P':
+               case 'X':
+                       return InternalError;
+       }
+       return DatabaseError;
+}
+
 /* sets database error with sqlstate attribute */
 /* This should be used when raising a subclass of DatabaseError */
 static void
 set_dberror(PyObject *type, const char *msg, PGresult *result)
 {
-       PyObject *err = NULL;
-       PyObject *str;
+       PyObject   *err_obj, *msg_obj, *sql_obj = NULL;
 
-       if (!(str = PyStr_FromString(msg)))
-               err = NULL;
-       else
-       {
-               err = PyObject_CallFunctionObjArgs(type, str, NULL);
-               Py_DECREF(str);
-       }
-       if (err)
+       if (result)
        {
-               if (result)
-               {
-                       char *sqlstate = PQresultErrorField(result, PG_DIAG_SQLSTATE);
-                       str = sqlstate ? PyStr_FromStringAndSize(sqlstate, 5) : NULL;
-               }
-               else
-                       str = NULL;
-               if (!str)
+               char *sqlstate = PQresultErrorField(result, PG_DIAG_SQLSTATE);
+               if (sqlstate)
                {
-                       Py_INCREF(Py_None);
-                       str = Py_None;
+                       sql_obj = PyStr_FromStringAndSize(sqlstate, 5);
+                       type = get_error_type(sqlstate);
                }
-               PyObject_SetAttrString(err, "sqlstate", str);
-               Py_DECREF(str);
-               PyErr_SetObject(type, err);
-               Py_DECREF(err);
+       }
+       if (!sql_obj)
+       {
+               Py_INCREF(Py_None);
+               sql_obj = Py_None;
+       }
+       msg_obj = PyStr_FromString(msg);
+       err_obj = PyObject_CallFunctionObjArgs(type, msg_obj, NULL);
+       if (err_obj)
+       {
+               Py_DECREF(msg_obj);
+               PyObject_SetAttrString(err_obj, "sqlstate", sql_obj);
+               Py_DECREF(sql_obj);
+               PyErr_SetObject(type, err_obj);
+               Py_DECREF(err_obj);
        }
        else
+       {
                PyErr_SetString(type, msg);
+       }
 }
 
 /* checks connection validity */

Modified: trunk/tests/test_classic.py
==============================================================================
--- trunk/tests/test_classic.py Fri Feb  5 12:05:45 2016        (r822)
+++ trunk/tests/test_classic.py Fri Feb  5 16:35:34 2016        (r823)
@@ -36,6 +36,7 @@
     db.query("SET TIME ZONE 'EST5EDT'")
     db.query("SET DEFAULT_WITH_OIDS=FALSE")
     db.query("SET STANDARD_CONFORMING_STRINGS=FALSE")
+    db.query("SET CLIENT_MIN_MESSAGES=WARNING")
     return db
 
 db = opendb()
@@ -82,7 +83,7 @@
     def test_invalidname(self):
         """Make sure that invalid table names are caught"""
         db = opendb()
-        self.assertRaises(ProgrammingError, db.get_attnames, 'x.y.z')
+        self.assertRaises(NotSupportedError, db.get_attnames, 'x.y.z')
 
     def test_schema(self):
         """Does it differentiate the same table name in different schemas"""
@@ -148,7 +149,7 @@
                 d['_test'] += 1
                 db.insert(t, d)
                 db.insert(t, d)
-        except ProgrammingError:
+        except IntegrityError:
             pass
         with db:
             d['_test'] += 1
@@ -167,8 +168,7 @@
         try:
             db.query("INSERT INTO _test_schema VALUES (1234)")
         except DatabaseError as error:
-            # currently PyGreSQL does not support IntegrityError
-            self.assertTrue(isinstance(error, ProgrammingError))
+            self.assertTrue(isinstance(error, IntegrityError))
             # the SQLSTATE error code for unique violation is 23505
             self.assertEqual(error.sqlstate, '23505')
 
@@ -340,4 +340,4 @@
     failfast = '-l' in sys.argv[1:]
     runner = unittest.TextTestRunner(verbosity=verbosity, failfast=failfast)
     rc = runner.run(suite)
-    sys.exit(1 if rc.errors or rc.failures else 0)
\ No newline at end of file
+    sys.exit(1 if rc.errors or rc.failures else 0)

Modified: trunk/tests/test_classic_connection.py
==============================================================================
--- trunk/tests/test_classic_connection.py      Fri Feb  5 12:05:45 2016        (r822)
+++ trunk/tests/test_classic_connection.py      Fri Feb  5 16:35:34 2016        (r823)
@@ -231,7 +231,7 @@
         def sleep():
             try:
                 self.connection.query('select pg_sleep(5)').getresult()
-            except pg.ProgrammingError as error:
+            except pg.DatabaseError as error:
                 errors.append(str(error))
 
         thread = threading.Thread(target=sleep)
@@ -331,7 +331,7 @@
 
     def testSelectDotSemicolon(self):
         q = "select .;"
-        self.assertRaises(pg.ProgrammingError, self.c.query, q)
+        self.assertRaises(pg.DatabaseError, self.c.query, q)
 
     def testGetresult(self):
         q = "select 0"
@@ -603,7 +603,7 @@
         # pass the query as unicode
         try:
             v = self.c.query(q).getresult()[0][0]
-        except pg.ProgrammingError:
+        except(pg.DataError, pg.NotSupportedError):
             self.skipTest("database does not support utf8")
         self.assertIsInstance(v, str)
         self.assertEqual(v, result)
@@ -620,7 +620,7 @@
             result = result.encode('utf8')
         try:
             v = self.c.query(q).dictresult()[0]['greeting']
-        except pg.ProgrammingError:
+        except (pg.DataError, pg.NotSupportedError):
             self.skipTest("database does not support utf8")
         self.assertIsInstance(v, str)
         self.assertEqual(v, result)
@@ -632,7 +632,7 @@
     def testDictresultLatin1(self):
         try:
             self.c.query('set client_encoding=latin1')
-        except pg.ProgrammingError:
+        except (pg.DataError, pg.NotSupportedError):
             self.skipTest("database does not support latin1")
         result = u'Hello, wörld!'
         q = u"select '%s'" % result
@@ -649,7 +649,7 @@
     def testDictresultLatin1(self):
         try:
             self.c.query('set client_encoding=latin1')
-        except pg.ProgrammingError:
+        except (pg.DataError, pg.NotSupportedError):
             self.skipTest("database does not support latin1")
         result = u'Hello, wörld!'
         q = u"select '%s' as greeting" % result
@@ -666,7 +666,7 @@
     def testGetresultCyrillic(self):
         try:
             self.c.query('set client_encoding=iso_8859_5')
-        except pg.ProgrammingError:
+        except (pg.DataError, pg.NotSupportedError):
             self.skipTest("database does not support cyrillic")
         result = u'Hello, мир!'
         q = u"select '%s'" % result
@@ -683,7 +683,7 @@
     def testDictresultCyrillic(self):
         try:
             self.c.query('set client_encoding=iso_8859_5')
-        except pg.ProgrammingError:
+        except (pg.DataError, pg.NotSupportedError):
             self.skipTest("database does not support cyrillic")
         result = u'Hello, мир!'
         q = u"select '%s' as greeting" % result
@@ -700,7 +700,7 @@
     def testGetresultLatin9(self):
         try:
             self.c.query('set client_encoding=latin9')
-        except pg.ProgrammingError:
+        except (pg.DataError, pg.NotSupportedError):
             self.skipTest("database does not support latin9")
         result = u'smœrebrœd with pražská šunka (pay in ¢, £, €, or ¥)'
         q = u"select '%s'" % result
@@ -717,7 +717,7 @@
     def testDictresultLatin9(self):
         try:
             self.c.query('set client_encoding=latin9')
-        except pg.ProgrammingError:
+        except (pg.DataError, pg.NotSupportedError):
             self.skipTest("database does not support latin9")
         result = u'smœrebrœd with pražská šunka (pay in ¢, £, €, or ¥)'
         q = u"select '%s' as menu" % result
@@ -818,6 +818,10 @@
             ).getresult(), [('Hello', 'world')])
         self.assertEqual(query("select $1::text union select $2::text",
             ('Hello', 'world')).getresult(), [('Hello',), ('world',)])
+        try:
+            query("select 'wörld'")
+        except (pg.DataError, pg.NotSupportedError):
+            self.skipTest('database does not support utf8')
         self.assertEqual(query("select $1||', '||$2||'!'", ('Hello',
             'w\xc3\xb6rld')).getresult(), [('Hello, w\xc3\xb6rld!',)])
 
@@ -826,7 +830,7 @@
         try:
             query('set client_encoding=utf8')
             query("select 'wörld'").getresult()[0][0] == 'wörld'
-        except pg.ProgrammingError:
+        except (pg.DataError, pg.NotSupportedError):
             self.skipTest("database does not support utf8")
         self.assertEqual(query("select $1||', '||$2||'!'",
             ('Hello', u'wörld')).getresult(), [('Hello, wörld!',)])
@@ -836,7 +840,7 @@
         try:
             query('set client_encoding=latin1')
             query("select 'wörld'").getresult()[0][0] == 'wörld'
-        except pg.ProgrammingError:
+        except (pg.DataError, pg.NotSupportedError):
             self.skipTest("database does not support latin1")
         r = query("select $1||', '||$2||'!'", ('Hello', u'wörld')).getresult()
         if unicode_strings:
@@ -863,7 +867,7 @@
         try:
             query('set client_encoding=iso_8859_5')
             query("select 'мир'").getresult()[0][0] == 'мир'
-        except pg.ProgrammingError:
+        except (pg.DataError, pg.NotSupportedError):
             self.skipTest("database does not support cyrillic")
         self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
             ('Hello', u'wörld'))
@@ -912,7 +916,11 @@
 
     def assert_proper_cast(self, value, pgtype, pytype):
         q = 'select $1::%s' % (pgtype,)
-        r = self.c.query(q, (value,)).getresult()[0][0]
+        try:
+            r = self.c.query(q, (value,)).getresult()[0][0]
+        except pg.ProgrammingError:
+            if pgtype in ('json', 'jsonb'):
+                self.skipTest('database does not support json')
         self.assertIsInstance(r, pytype)
         if isinstance(value, str):
             if not value or ' ' in value or '{' in value:
@@ -994,8 +1002,13 @@
         # Check whether the test database uses SQL_ASCII - this means
         # that it does not consider encoding when calculating lengths.
         c.query("set client_encoding=utf8")
-        cls.has_encoding = c.query(
-            "select length('ä') - length('a')").getresult()[0][0] == 0
+        try:
+            c.query("select 'ä'")
+        except (pg.DataError, pg.NotSupportedError):
+            cls.has_encoding = False
+        else:
+            cls.has_encoding = c.query(
+                "select length('ä') - length('a')").getresult()[0][0] == 0
         c.close()
         cls.cls_set_up = True
 
@@ -1122,7 +1135,7 @@
     def testInserttableByteValues(self):
         try:
             self.c.query("select '€', 'käse', 'сыр', 'pont-l''évêque'")
-        except pg.ProgrammingError:
+        except pg.DataError:
             self.skipTest("database does not support utf8")
         # non-ascii chars do not fit in char(1) when there is no encoding
         c = u'€' if self.has_encoding else u'$'
@@ -1140,7 +1153,7 @@
     def testInserttableUnicodeUtf8(self):
         try:
             self.c.query("select '€', 'käse', 'сыр', 'pont-l''évêque'")
-        except pg.ProgrammingError:
+        except pg.DataError:
             self.skipTest("database does not support utf8")
         # non-ascii chars do not fit in char(1) when there is no encoding
         c = u'€' if self.has_encoding else u'$'
@@ -1160,7 +1173,7 @@
         try:
             self.c.query("set client_encoding=latin1")
             self.c.query("select '¥'")
-        except pg.ProgrammingError:
+        except (pg.DataError, pg.NotSupportedError):
             self.skipTest("database does not support latin1")
         # non-ascii chars do not fit in char(1) when there is no encoding
         c = u'€' if self.has_encoding else u'$'
@@ -1184,7 +1197,7 @@
         try:
             self.c.query("set client_encoding=latin9")
             self.c.query("select '€'")
-        except pg.ProgrammingError:
+        except (pg.DataError, pg.NotSupportedError):
             self.skipTest("database does not support latin9")
             return
         # non-ascii chars do not fit in char(1) when there is no encoding
@@ -1257,6 +1270,10 @@
     def testPutlineBytesAndUnicode(self):
         putline = self.c.putline
         query = self.c.query
+        try:
+            query("select 'käse+würstel'")
+        except (pg.DataError, pg.NotSupportedError):
+            self.skipTest('database does not support utf8')
         query("copy test from stdin")
         try:
             putline(u"47\tkäse\n".encode('utf8'))
@@ -1292,6 +1309,10 @@
     def testGetlineBytesAndUnicode(self):
         getline = self.c.getline
         query = self.c.query
+        try:
+            query("select 'käse+würstel'")
+        except (pg.DataError, pg.NotSupportedError):
+            self.skipTest('database does not support utf8')
         data = [(54, u'käse'.encode('utf8')), (73, u'würstel')]
         self.c.inserttable('test', data)
         query("copy test to stdout")
@@ -1474,7 +1495,7 @@
         for lc in en_locales:
             try:
                 query("set lc_monetary='%s'" % lc)
-            except pg.ProgrammingError:
+            except pg.DataError:
                 pass
             else:
                 break
@@ -1482,7 +1503,7 @@
             self.skipTest("cannot set English money locale")
         try:
             r = query(select_money)
-        except pg.ProgrammingError:
+        except pg.DataError:
             # this can happen if the currency signs cannot be
             # converted using the encoding of the test database
             self.skipTest("database does not support English money")
@@ -1529,7 +1550,7 @@
         for lc in de_locales:
             try:
                 query("set lc_monetary='%s'" % lc)
-            except pg.ProgrammingError:
+            except pg.DataError:
                 pass
             else:
                 break
@@ -1538,7 +1559,7 @@
         select_money = select_money.replace('.', ',')
         try:
             r = query(select_money)
-        except pg.ProgrammingError:
+        except pg.DataError:
             self.skipTest("database does not support English money")
         pg.set_decimal_point(None)
         try:
@@ -1599,7 +1620,7 @@
         query = self.c.query
         try:
             r = query("select 3425::numeric")
-        except pg.ProgrammingError:
+        except pg.DatabaseError:
             self.skipTest('database does not support numeric')
         r = r.getresult()[0][0]
         self.assertIsInstance(r, decimal_class)

Modified: trunk/tests/test_classic_dbwrapper.py
==============================================================================
--- trunk/tests/test_classic_dbwrapper.py       Fri Feb  5 12:05:45 2016        (r822)
+++ trunk/tests/test_classic_dbwrapper.py       Fri Feb  5 16:35:34 2016        (r823)
@@ -306,10 +306,10 @@
     def testMethodQueryEmpty(self):
         self.assertRaises(ValueError, self.db.query, '')
 
-    def testMethodQueryProgrammingError(self):
+    def testMethodQueryDataError(self):
         try:
             self.db.query("select 1/0")
-        except pg.ProgrammingError as error:
+        except pg.DataError as error:
             self.assertEqual(error.sqlstate, '22012')
 
     def testMethodEndcopy(self):
@@ -873,10 +873,10 @@
     def testEmptyQuery(self):
         self.assertRaises(ValueError, self.db.query, '')
 
-    def testQueryProgrammingError(self):
+    def testQueryDataError(self):
         try:
             self.db.query("select 1/0")
-        except pg.ProgrammingError as error:
+        except pg.DataError as error:
             self.assertEqual(error.sqlstate, '22012')
 
     def testQueryFormatted(self):
@@ -1212,7 +1212,7 @@
         self.assertEqual(can('test', 'insert'), True)
         self.assertEqual(can('test', 'update'), True)
         self.assertEqual(can('test', 'delete'), True)
-        self.assertRaises(pg.ProgrammingError, can, 'test', 'foobar')
+        self.assertRaises(pg.DataError, can, 'test', 'foobar')
         self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist')
         r = self.db.query('select rolsuper FROM pg_roles'
             ' where rolname=current_user').getresult()[0][0]
@@ -1583,9 +1583,9 @@
         r = insert('test_table', dict(n=7))
         self.assertIsInstance(r, dict)
         self.assertEqual(r['n'], 7)
-        self.assertRaises(pg.ProgrammingError, insert, 'test_table', r)
+        self.assertRaises(pg.IntegrityError, insert, 'test_table', r)
         r['n'] = 6
-        self.assertRaises(pg.ProgrammingError, insert, 'test_table', r, n=7)
+        self.assertRaises(pg.IntegrityError, insert, 'test_table', r, n=7)
         self.assertIsInstance(r, dict)
         self.assertEqual(r['n'], 7)
         r['n'] = 6
@@ -1633,7 +1633,7 @@
         r = dict(i4=5678, v4='efgh')
         try:
             insert('test_view', r)
-        except pg.ProgrammingError as error:
+        except pg.NotSupportedError as error:
             if self.db.server_version < 90300:
                 # must setup rules in older PostgreSQL versions
                 self.skipTest('database cannot insert into view')
@@ -2273,9 +2273,9 @@
         q = ("select (select count(*) from test_parent),"
              " (select count(*) from test_child)")
         self.assertEqual(query(q).getresult()[0], (3, 3))
-        self.assertRaises(pg.ProgrammingError,
+        self.assertRaises(pg.IntegrityError,
                           delete, 'test_parent', None, n=2)
-        self.assertRaises(pg.ProgrammingError,
+        self.assertRaises(pg.IntegrityError,
                           delete, 'test_parent *', None, n=2)
         r = delete('test_child', None, n=2)
         self.assertEqual(r, 1)
@@ -2283,9 +2283,9 @@
         r = delete('test_parent', None, n=2)
         self.assertEqual(r, 1)
         self.assertEqual(query(q).getresult()[0], (2, 2))
-        self.assertRaises(pg.ProgrammingError,
+        self.assertRaises(pg.IntegrityError,
                           delete, 'test_parent', dict(n=0))
-        self.assertRaises(pg.ProgrammingError,
+        self.assertRaises(pg.IntegrityError,
                           delete, 'test_parent *', dict(n=0))
         r = delete('test_child', dict(n=0))
         self.assertEqual(r, 1)
@@ -2372,7 +2372,7 @@
              " (select count(*) from test_child)")
         r = query(q).getresult()[0]
         self.assertEqual(r, (3, 3))
-        self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
+        self.assertRaises(pg.NotSupportedError, truncate, 'test_parent')
         truncate(['test_parent', 'test_child'])
         r = query(q).getresult()[0]
         self.assertEqual(r, (0, 0))
@@ -2392,7 +2392,7 @@
         truncate('test_child')
         r = query(q).getresult()[0]
         self.assertEqual(r, (3, 0))
-        self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
+        self.assertRaises(pg.NotSupportedError, truncate, 'test_parent')
         truncate('test_parent', cascade=True)
         r = query(q).getresult()[0]
         self.assertEqual(r, (0, 0))
@@ -2777,7 +2777,7 @@
         self.db.savepoint('before8')
         query("insert into test_table values (8)")
         self.db.release('before8')
-        self.assertRaises(pg.ProgrammingError, self.db.rollback, 'before8')
+        self.assertRaises(pg.InternalError, self.db.rollback, 'before8')
         self.db.commit()
         self.db.start()
         query("insert into test_table values (9)")
@@ -2786,11 +2786,11 @@
             "select * from test_table order by 1").getresult()]
         self.assertEqual(r, [1, 2, 5, 7, 9])
         self.db.begin(mode='read only')
-        self.assertRaises(pg.ProgrammingError,
+        self.assertRaises(pg.InternalError,
                           query, "insert into test_table values (0)")
         self.db.rollback()
         self.db.start(mode='Read Only')
-        self.assertRaises(pg.ProgrammingError,
+        self.assertRaises(pg.InternalError,
                           query, "insert into test_table values (0)")
         self.db.abort()
 
@@ -2818,7 +2818,7 @@
             with self.db:
                 query("insert into test_table values (6)")
                 query("insert into test_table values (-1)")
-        except pg.ProgrammingError as error:
+        except pg.IntegrityError as error:
             self.assertTrue('check' in str(error))
         with self.db:
             query("insert into test_table values (7)")
@@ -3166,7 +3166,7 @@
             self.assertEqual(r['i'], '{1,2,3}')
             self.assertEqual(r['t'], '{a,b,c}')
         r = dict(i="1, 2, 3", t="'a', 'b', 'c'")
-        self.assertRaises(pg.ProgrammingError, self.db.insert, 'arraytest', r)
+        self.assertRaises(pg.DataError, self.db.insert, 'arraytest', r)
 
     def testArrayOfIds(self):
         array_on = pg.get_array()
@@ -3689,10 +3689,10 @@
     def testHstore(self):
         try:
             self.db.query("select 'k=>v'::hstore")
-        except pg.ProgrammingEror:
+        except pg.DatabaseError:
             try:
                 self.db.query("create extension hstore")
-            except pg.ProgrammingError:
+            except pg.DatabaseError:
                 self.skipTest("hstore extension not enabled")
         d = {'k': 'v', 'foo': 'bar', 'baz': 'whatever',
             '1a': 'anything at all', '2=b': 'value = 2', '3>c': 'value > 3',

Modified: trunk/tests/test_dbapi20.py
==============================================================================
--- trunk/tests/test_dbapi20.py Fri Feb  5 12:05:45 2016        (r822)
+++ trunk/tests/test_dbapi20.py Fri Feb  5 16:35:34 2016        (r823)
@@ -463,13 +463,27 @@
         finally:
             con.close()
 
+    def test_integrity_error(self):
+        table = self.table_prefix + 'booze'
+        con = self._connect()
+        try:
+            cur = con.cursor()
+            cur.execute("set client_min_messages = warning")
+            cur.execute("create table %s (i int primary key)" % table)
+            cur.execute("insert into %s values (1)" % table)
+            cur.execute("insert into %s values (2)" % table)
+            self.assertRaises(pgdb.IntegrityError, cur.execute,
+                "insert into %s values (1)" % table)
+        finally:
+            con.close()
+
     def test_sqlstate(self):
         con = self._connect()
         cur = con.cursor()
         try:
             cur.execute("select 1/0")
         except pgdb.DatabaseError as error:
-            self.assertTrue(isinstance(error, pgdb.ProgrammingError))
+            self.assertTrue(isinstance(error, pgdb.DataError))
             # the SQLSTATE error code for division by zero is 22012
             self.assertEqual(error.sqlstate, '22012')
 
@@ -602,10 +616,10 @@
         try:
             cur = con.cursor()
             cur.execute("select 'k=>v'::hstore")
-        except pgdb.ProgrammingError:
+        except pgdb.DatabaseError:
             try:
                 cur.execute("create extension hstore")
-            except pgdb.ProgrammingError:
+            except pgdb.DatabaseError:
                 self.skipTest("hstore extension not enabled")
         finally:
             con.close()
@@ -997,7 +1011,7 @@
             cur.execute(sql, [(1,), (2,)])  # deprecated use of execute()
             self.assertEqual(cur.fetchone()[0], 3)
             sql = 'select 1/0'  # cannot be executed
-            self.assertRaises(pgdb.ProgrammingError, cur.execute, sql)
+            self.assertRaises(pgdb.DataError, cur.execute, sql)
             cur.close()
             con.rollback()
             if pgdb.shortcutmethods:
@@ -1069,7 +1083,7 @@
                 with con:
                     cur.execute("insert into %s values (3)" % table)
                     cur.execute("insert into %s values (4)" % table)
-            except con.ProgrammingError as error:
+            except con.IntegrityError as error:
                 self.assertTrue('check' in str(error).lower())
             with con:
                 cur.execute("insert into %s values (5)" % table)

Modified: trunk/tests/test_dbapi20_copy.py
==============================================================================
--- trunk/tests/test_dbapi20_copy.py    Fri Feb  5 12:05:45 2016        (r822)
+++ trunk/tests/test_dbapi20_copy.py    Fri Feb  5 16:35:34 2016        (r823)
@@ -139,6 +139,15 @@
             "id smallint primary key, name varchar(64))")
         cur.close()
         con.commit()
+        cur = con.cursor()
+        try:
+            cur.execute("set client_encoding=utf8")
+            cur.execute("select 'Plácido and José'").fetchone()
+        except (pgdb.DataError, pgdb.NotSupportedError):
+            cls.data[1] = (1941, 'Plaacido Domingo')
+            cls.data[2] = (1946, 'Josee Carreras')
+            cls.can_encode = False
+        cur.close()
         con.close()
         cls.cls_set_up = True
 
@@ -175,6 +184,8 @@
             (1941, 'Plácido Domingo'),
             (1946, 'José Carreras')]
 
+    can_encode = True
+
     @property
     def data_text(self):
         return ''.join('%d\t%s\n' % row for row in self.data)
@@ -265,6 +276,8 @@
     else:  # Python < 3.0
 
         def test_input_unicode(self):
+            if not self.can_encode:
+                self.skipTest('database does not support utf8')
             self.copy_from(u'43\tWürstel, Käse!')
             self.assertEqual(self.table_data, [(43, 'Würstel, Käse!')])
             self.truncate_table()
@@ -394,6 +407,7 @@
         super(TestCopyTo, cls).setUpClass()
         con = cls.connect()
         cur = con.cursor()
+        cur.execute("set client_encoding=utf8")
         cur.execute("insert into copytest values (%d, %s)", cls.data)
         cur.close()
         con.commit()
_______________________________________________
PyGreSQL mailing list
[email protected]
https://mail.vex.net/mailman/listinfo.cgi/pygresql

Reply via email to