Author: cito
Date: Tue Nov  6 10:31:03 2012
New Revision: 463

Log:
Connection as context manager for transactions in pgdb.

Modified:
   trunk/module/TEST_PyGreSQL_dbapi20.py
   trunk/module/dbapi20.py
   trunk/module/pgdb.py

Modified: trunk/module/TEST_PyGreSQL_dbapi20.py
==============================================================================
--- trunk/module/TEST_PyGreSQL_dbapi20.py       Thu Nov  1 07:44:07 2012        (r462)
+++ trunk/module/TEST_PyGreSQL_dbapi20.py       Tue Nov  6 10:31:03 2012        (r463)
@@ -1,6 +1,8 @@
 #!/usr/bin/env python
 # $Id$
 
+from __future__ import with_statement
+
 import unittest
 import dbapi20
 import pgdb
@@ -47,16 +49,16 @@
                 return d
 
         con = self._connect()
-        curs = myCursor(con)
-        ret = curs.execute("select 1 as a, 2 as b")
-        self.assert_(ret is curs, 'execute() should return cursor')
-        self.assertEqual(curs.fetchone(), {'a': 1, 'b': 2})
+        cur = myCursor(con)
+        ret = cur.execute("select 1 as a, 2 as b")
+        self.assert_(ret is cur, 'execute() should return cursor')
+        self.assertEqual(cur.fetchone(), {'a': 1, 'b': 2})
 
     def test_cursor_iteration(self):
         con = self._connect()
-        curs = con.cursor()
-        curs.execute("select 1 union select 2 union select 3")
-        self.assertEqual([r[0] for r in curs], [1, 2, 3])
+        cur = con.cursor()
+        cur.execute("select 1 union select 2 union select 3")
+        self.assertEqual([r[0] for r in cur], [1, 2, 3])
 
     def test_fetch_2_rows(self):
         Decimal = pgdb.decimal_type()
@@ -84,7 +86,7 @@
                 "rowidtest oid)" % table)
             for s in ('numeric', 'monetary', 'time'):
                 cur.execute("set lc_%s to 'C'" % s)
-            for i in range(2):
+            for _i in range(2):
                 cur.execute("insert into %s values ("
                     "%%s,%%s,%%s,%%s,%%s,%%s,%%s,"
                     "'%%s'::money,%%s,%%s,%%s,%%s,%%s)" % table, values)
@@ -128,30 +130,31 @@
         self.assert_(isnan(nan) and not isinf(nan))
         self.assert_(isinf(inf) and not isnan(inf))
         values = [0, 1, 0.03125, -42.53125, nan, inf, -inf]
-        table = self.table_prefix + 'float'
+        table = self.table_prefix + 'booze'
         con = self._connect()
         try:
             cur = con.cursor()
-            cur.execute("create table %s (floattest float)" % table)
-            params = [(val,) for val in values]
-            cur.executemany("insert into %s values(%%s)" % table, params)
-            cur.execute("select * from %s" % table)
+            cur.execute(
+                "create table %s (n smallint, floattest float)" % table)
+            params = enumerate(values)
+            cur.executemany("insert into %s values(%%s,%%s)" % table, params)
+            cur.execute("select * from %s order by 1" % table)
             rows = cur.fetchall()
-            self.assertEqual(len(rows), len(values))
-            rows = [row[0] for row in rows]
-            for inval, outval in zip(values, rows):
-                if isinf(inval):
-                    self.assert_(isinf(outval))
-                    if inval < 0:
-                        self.assert_(outval < 0)
-                    else:
-                        self.assert_(outval > 0)
-                elif isnan(inval):
-                    self.assert_(isnan(outval))
-                else:
-                    self.assertEqual(inval, outval)
         finally:
             con.close()
+        self.assertEqual(len(rows), len(values))
+        rows = [row[1] for row in rows]
+        for inval, outval in zip(values, rows):
+            if isinf(inval):
+                self.assert_(isinf(outval))
+                if inval < 0:
+                    self.assert_(outval < 0)
+                else:
+                    self.assert_(outval > 0)
+            elif isnan(inval):
+                self.assert_(isnan(outval))
+            else:
+                self.assertEqual(inval, outval)
 
     def test_set_decimal_type(self):
         decimal_type = pgdb.decimal_type()
@@ -175,10 +178,12 @@
         self.assert_(pgdb.decimal_type() is decimal_type)
 
     def test_nextset(self):
-        pass  # not implemented
+        con = self._connect()
+        cur = con.cursor()
+        self.assertRaises(con.NotSupportedError, cur.nextset)
 
     def test_setoutputsize(self):
-        pass  # not implemented
+        pass  # not supported
 
     def test_connection_errors(self):
         con = self._connect()
@@ -193,10 +198,50 @@
         self.assertEqual(con.DataError, pgdb.DataError)
         self.assertEqual(con.NotSupportedError, pgdb.NotSupportedError)
 
+    def test_connection_as_contextmanager(self):
+        table = self.table_prefix + 'booze'
+        con = self._connect()
+        try:
+            cur = con.cursor()
+            cur.execute("create table %s (n smallint check(n!=4))" % table)
+            with con:
+                cur.execute("insert into %s values (1)" % table)
+                cur.execute("insert into %s values (2)" % table)
+            try:
+                with con:
+                    cur.execute("insert into %s values (3)" % table)
+                    cur.execute("insert into %s values (4)" % table)
+            except con.ProgrammingError, error:
+                self.assertTrue('check' in str(error).lower())
+            with con:
+                cur.execute("insert into %s values (5)" % table)
+                cur.execute("insert into %s values (6)" % table)
+            try:
+                with con:
+                    cur.execute("insert into %s values (7)" % table)
+                    cur.execute("insert into %s values (8)" % table)
+                    raise ValueError('transaction should rollback')
+            except ValueError, error:
+                self.assertEqual(str(error), 'transaction should rollback')
+            with con:
+                cur.execute("insert into %s values (9)" % table)
+            cur.execute("select * from %s order by 1" % table)
+            rows = cur.fetchall()
+            rows = [row[0] for row in rows]
+        finally:
+            con.close()
+        self.assertEqual(rows, [1, 2, 5, 6, 9])
+
     def test_cursor_connection(self):
         con = self._connect()
-        curs = con.cursor()
-        self.assertEqual(curs.connection, con)
+        cur = con.cursor()
+        self.assertEqual(cur.connection, con)
+        cur.close()
+
+    def test_cursor_as_contextmanager(self):
+        con = self._connect()
+        with con.cursor() as cur:
+            self.assertEqual(cur.connection, con)
 
     def test_pgdb_type(self):
         self.assertEqual(pgdb.STRING, pgdb.STRING)

Modified: trunk/module/dbapi20.py
==============================================================================
--- trunk/module/dbapi20.py     Thu Nov  1 07:44:07 2012        (r462)
+++ trunk/module/dbapi20.py     Tue Nov  6 10:31:03 2012        (r463)
@@ -703,25 +703,30 @@
         finally:
             con.close()
 
-    def help_nextset_setUp(self,cur):
+    def help_nextset_setUp(self, cur):
         """Should create a procedure called deleteme
             that returns two result sets, first the
             number of rows in booze then "name from booze"
         """
-        raise NotImplementedError('Helper not implemented')
-        #sql="""
-        #    create procedure deleteme as
-        #    begin
-        #        select count(*) from booze
-        #        select name from booze
-        #    end
-        #"""
-        #cur.execute(sql)
+        if False:
+            sql = """
+                create procedure deleteme as
+                begin
+                    select count(*) from booze
+                    select name from booze
+                end
+            """
+            cur.execute(sql)
+        else:
+            raise NotImplementedError('Helper not implemented')
 
-    def help_nextset_tearDown(self,cur):
+    def help_nextset_tearDown(self, cur):
         """If cleaning up is needed after nextSetTest"""
-        raise NotImplementedError('Helper not implemented')
-        #cur.execute("drop procedure deleteme")
+        if False:
+            cur.execute("drop procedure deleteme")
+        else:
+
+            raise NotImplementedError('Helper not implemented')
 
     def test_nextset(self):
         con = self._connect()
@@ -752,9 +757,6 @@
         finally:
             con.close()
 
-    def test_nextset(self):
-        raise NotImplementedError('Drivers need to override this test')
-
     def test_arraysize(self):
         """Not much here - rest of the tests for this are in test_fetchmany"""
         con = self._connect()

Modified: trunk/module/pgdb.py
==============================================================================
--- trunk/module/pgdb.py        Thu Nov  1 07:44:07 2012        (r462)
+++ trunk/module/pgdb.py        Tue Nov  6 10:31:03 2012        (r463)
@@ -90,7 +90,7 @@
 
 ### Module Constants
 
-# compliant with DB SIG 2.0
+# compliant with DB API 2.0
 apilevel = '2.0'
 
 # module may be shared, but not connections
@@ -99,6 +99,12 @@
 # this module use extended python format codes
 paramstyle = 'pyformat'
 
+# shortcut methods are not supported by default
+# since they have been excluded from DB API 2
+# and are not recommended by the DB SIG;
+
+shortcutmethods = 0
+
 
 ### Internal Types Handling
 
@@ -208,7 +214,7 @@
 ### Cursor Object
 
 class pgdbCursor(object):
-    """Cursor Object."""
+    """Cursor object."""
 
     def __init__(self, dbcnx):
         """Create a cursor object for the database connection."""
@@ -418,19 +424,19 @@
 
     def setinputsizes(sizes):
         """Not supported."""
-        pass
+        pass  # unsupported, but silently passed
     setinputsizes = staticmethod(setinputsizes)
 
     def setoutputsize(size, column=0):
         """Not supported."""
-        pass
+        pass  # unsupported, but silently passed
     setoutputsize = staticmethod(setoutputsize)
 
 
 ### Connection Objects
 
 class pgdbCnx(object):
-    """Connection Object."""
+    """Connection object."""
 
     # expose the exceptions as attributes on the connection object
     Error = Error
@@ -455,12 +461,23 @@
             raise _op_error("invalid connection")
 
     def __enter__(self):
-        """Enter the runtime context for the connection object."""
+        """Enter the runtime context for the connection object.
+
+        The runtime context can be used for running transactions.
+
+        """
         return self
 
     def __exit__(self, et, ev, tb):
-        """Exit the runtime context for the connection object."""
-        self.close()
+        """Exit the runtime context for the connection object.
+
+        This does not close the connection, but it ends a transaction.
+
+        """
+        if et is None and ev is None and tb is None:
+            self.commit()
+        else:
+            self.rollback()
 
     def close(self):
         """Close the connection object."""
@@ -504,7 +521,7 @@
             raise _op_error("connection has been closed")
 
     def cursor(self):
-        """Return a new Cursor Object using the connection."""
+        """Return a new cursor object using the connection."""
         if self._cnx:
             try:
                 return pgdbCursor(self)
@@ -513,6 +530,20 @@
         else:
             raise _op_error("connection has been closed")
 
+    if shortcutmethods:  # otherwise do not implement and document this
+
+        def execute(self, operation, params=None):
+            """Shortcut method to run an operation on an implicit cursor."""
+            cursor = self.cursor()
+            cursor.execute(operation, params)
+            return cursor
+
+        def executemany(self, operation, param_seq):
+            """Shortcut method to run an operation against a sequence."""
+            cursor = self.cursor()
+            cursor.executemany(operation, param_seq)
+            return cursor
+
 
 ### Module Interface
 
_______________________________________________
PyGreSQL mailing list
[email protected]
https://mail.vex.net/mailman/listinfo.cgi/pygresql

Reply via email to