Author: Brian Kearns <bdkea...@gmail.com>
Branch: py3k
Changeset: r62126:878a364bd84f
Date: 2013-03-06 03:45 -0500
http://bitbucket.org/pypy/pypy/changeset/878a364bd84f/

Log:    merge default

diff --git a/lib_pypy/_sqlite3.py b/lib_pypy/_sqlite3.py
--- a/lib_pypy/_sqlite3.py
+++ b/lib_pypy/_sqlite3.py
@@ -291,7 +291,7 @@
     return str(x, 'utf-8')
 
 
-class StatementCache(object):
+class _StatementCache(object):
     def __init__(self, connection, maxcount):
         self.connection = connection
         self.maxcount = maxcount
@@ -305,7 +305,7 @@
             self.cache[sql] = stat
             if len(self.cache) > self.maxcount:
                 self.cache.popitem(0)
-        #
+
         if stat.in_use:
             stat = Statement(self.connection, sql)
         stat.set_row_factory(row_factory)
@@ -337,7 +337,7 @@
         self._cursors = []
         self.__statements = []
         self.__statement_counter = 0
-        self._statement_cache = StatementCache(self, cached_statements)
+        self._statement_cache = _StatementCache(self, cached_statements)
 
         self.__func_cache = {}
         self.__aggregates = {}
@@ -383,10 +383,10 @@
 
     def _check_closed_wrap(func):
         @wraps(func)
-        def _check_closed_func(self, *args, **kwargs):
+        def wrapper(self, *args, **kwargs):
             self._check_closed()
             return func(self, *args, **kwargs)
-        return _check_closed_func
+        return wrapper
 
     def _check_thread(self):
         try:
@@ -402,10 +402,10 @@
 
     def _check_thread_wrap(func):
         @wraps(func)
-        def _check_thread_func(self, *args, **kwargs):
+        def wrapper(self, *args, **kwargs):
             self._check_thread()
             return func(self, *args, **kwargs)
-        return _check_thread_func
+        return wrapper
 
     def _get_exception(self, error_code=None):
         if error_code is None:
@@ -449,8 +449,7 @@
     def __call__(self, sql):
         if not isinstance(sql, str):
             raise Warning("SQL is of wrong type. Must be string or unicode.")
-        statement = self._statement_cache.get(sql, self.row_factory)
-        return statement
+        return self._statement_cache.get(sql, self.row_factory)
 
     def cursor(self, factory=None):
         self._check_thread()
@@ -749,20 +748,6 @@
                 raise OperationalError("Error enabling load extension")
 
 
-class _CursorLock(object):
-    def __init__(self, cursor):
-        self.cursor = cursor
-
-    def __enter__(self):
-        self.cursor._check_closed()
-        if self.cursor._locked:
-            raise ProgrammingError("Recursive use of cursors not allowed.")
-        self.cursor._locked = True
-
-    def __exit__(self, *args):
-        self.cursor._locked = False
-
-
 class Cursor(object):
     __initialized = False
     __connection = None
@@ -780,8 +765,8 @@
 
         self.arraysize = 1
         self.row_factory = None
-        self._locked = False
         self._reset = False
+        self.__locked = False
         self.__closed = False
         self.__description = None
         self.__rowcount = -1
@@ -803,16 +788,20 @@
             self.__statement = None
         self.__closed = True
 
-    def _check_closed(self):
+    def __check_cursor(self):
         if not self.__initialized:
             raise ProgrammingError("Base Cursor.__init__ not called.")
         if self.__closed:
             raise ProgrammingError("Cannot operate on a closed cursor.")
+        if self.__locked:
+            raise ProgrammingError("Recursive use of cursors not allowed.")
         self.__connection._check_thread()
         self.__connection._check_closed()
 
     def execute(self, sql, params=None):
-        with _CursorLock(self):
+        self.__check_cursor()
+        self.__locked = True
+        try:
             self.__description = None
             self._reset = False
             self.__statement = self.__connection._statement_cache.get(
@@ -849,11 +838,15 @@
             self.__rowcount = -1
             if self.__statement.kind == _DML:
                 self.__rowcount = sqlite.sqlite3_changes(self.__connection._db)
+        finally:
+            self.__locked = False
 
         return self
 
     def executemany(self, sql, many_params):
-        with _CursorLock(self):
+        self.__check_cursor()
+        self.__locked = True
+        try:
             self.__description = None
             self._reset = False
             self.__statement = self.__connection._statement_cache.get(
@@ -877,6 +870,8 @@
                     raise self.__connection._get_exception(ret)
                 self.__rowcount += sqlite.sqlite3_changes(self.__connection._db)
             self.__statement.reset()
+        finally:
+            self.__locked = False
 
         return self
 
@@ -885,7 +880,7 @@
         self._reset = False
         if type(sql) is str:
             sql = sql.encode("utf-8")
-        self._check_closed()
+        self.__check_cursor()
         statement = c_void_p()
         c_sql = c_char_p(sql)
 
@@ -916,16 +911,16 @@
                 break
         return self
 
-    def _check_reset(self):
+    def __check_reset(self):
         if self._reset:
-            raise self.__connection.InterfaceError("Cursor needed to be reset because "
-                                                 "of commit/rollback and can "
-                                                 "no longer be fetched from.")
+            raise self.__connection.InterfaceError(
+                    "Cursor needed to be reset because of commit/rollback "
+                    "and can no longer be fetched from.")
 
     # do all statements
     def fetchone(self):
-        self._check_closed()
-        self._check_reset()
+        self.__check_cursor()
+        self.__check_reset()
 
         if self.__statement is None:
             return None
@@ -936,8 +931,8 @@
             return None
 
     def fetchmany(self, size=None):
-        self._check_closed()
-        self._check_reset()
+        self.__check_cursor()
+        self.__check_reset()
         if self.__statement is None:
             return []
         if size is None:
@@ -950,8 +945,8 @@
         return lst
 
     def fetchall(self):
-        self._check_closed()
-        self._check_reset()
+        self.__check_cursor()
+        self.__check_reset()
         if self.__statement is None:
             return []
         return list(self)
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to