Author: Brian Kearns <bdkea...@gmail.com>
Branch: 
Changeset: r63020:1e485fc86788
Date: 2013-04-04 16:54 -0400
http://bitbucket.org/pypy/pypy/changeset/1e485fc86788/

Log:    Also store row_cast_map on the cursor instead of the statement, and
        don't call row_factory unless the row is actually fetched.

diff --git a/lib_pypy/_sqlite3.py b/lib_pypy/_sqlite3.py
--- a/lib_pypy/_sqlite3.py
+++ b/lib_pypy/_sqlite3.py
@@ -377,7 +377,7 @@
         self.maxcount = maxcount
         self.cache = OrderedDict()
 
-    def get(self, sql, row_factory):
+    def get(self, sql):
         try:
             stat = self.cache[sql]
         except KeyError:
@@ -389,7 +389,6 @@
             if stat._in_use:
                 stat = Statement(self.connection, sql)
                 self.cache[sql] = stat
-        stat._row_factory = row_factory
         return stat
 
 
@@ -552,7 +551,7 @@
     @_check_thread_wrap
     @_check_closed_wrap
     def __call__(self, sql):
-        return self._statement_cache.get(sql, self.row_factory)
+        return self._statement_cache.get(sql)
 
     def cursor(self, factory=None):
         self._check_thread()
@@ -881,20 +880,96 @@
             return func(self, *args, **kwargs)
         return wrapper
 
+    def __check_reset(self):
+        if self._reset:
+            raise InterfaceError(
+                    "Cursor needed to be reset because of commit/rollback "
+                    "and can no longer be fetched from.")
+
+    def __build_row_cast_map(self):
+        if not self.__connection._detect_types:
+            return
+        self.__row_cast_map = []
+        for i in 
xrange(_lib.sqlite3_column_count(self.__statement._statement)):
+            converter = None
+
+            if self.__connection._detect_types & PARSE_COLNAMES:
+                colname = 
_lib.sqlite3_column_name(self.__statement._statement, i)
+                if colname:
+                    colname = _ffi.string(colname).decode('utf-8')
+                    type_start = -1
+                    key = None
+                    for pos in range(len(colname)):
+                        if colname[pos] == '[':
+                            type_start = pos + 1
+                        elif colname[pos] == ']' and type_start != -1:
+                            key = colname[type_start:pos]
+                            converter = converters[key.upper()]
+
+            if converter is None and self.__connection._detect_types & 
PARSE_DECLTYPES:
+                decltype = 
_lib.sqlite3_column_decltype(self.__statement._statement, i)
+                if decltype:
+                    decltype = _ffi.string(decltype).decode('utf-8')
+                    # if multiple words, use first, eg.
+                    # "INTEGER NOT NULL" => "INTEGER"
+                    decltype = decltype.split()[0]
+                    if '(' in decltype:
+                        decltype = decltype[:decltype.index('(')]
+                    converter = converters.get(decltype.upper(), None)
+
+            self.__row_cast_map.append(converter)
+
+    def __fetch_one_row(self):
+        row = []
+        num_cols = _lib.sqlite3_data_count(self.__statement._statement)
+        for i in xrange(num_cols):
+            if self.__connection._detect_types:
+                converter = self.__row_cast_map[i]
+            else:
+                converter = None
+
+            if converter is not None:
+                blob = _lib.sqlite3_column_blob(self.__statement._statement, i)
+                if not blob:
+                    val = None
+                else:
+                    blob_len = 
_lib.sqlite3_column_bytes(self.__statement._statement, i)
+                    val = _ffi.buffer(blob, blob_len)[:]
+                    val = converter(val)
+            else:
+                typ = _lib.sqlite3_column_type(self.__statement._statement, i)
+                if typ == _lib.SQLITE_NULL:
+                    val = None
+                elif typ == _lib.SQLITE_INTEGER:
+                    val = 
_lib.sqlite3_column_int64(self.__statement._statement, i)
+                    val = int(val)
+                elif typ == _lib.SQLITE_FLOAT:
+                    val = 
_lib.sqlite3_column_double(self.__statement._statement, i)
+                elif typ == _lib.SQLITE_TEXT:
+                    text = 
_lib.sqlite3_column_text(self.__statement._statement, i)
+                    text_len = 
_lib.sqlite3_column_bytes(self.__statement._statement, i)
+                    val = _ffi.buffer(text, text_len)[:]
+                    val = self.__connection.text_factory(val)
+                elif typ == _lib.SQLITE_BLOB:
+                    blob = 
_lib.sqlite3_column_blob(self.__statement._statement, i)
+                    blob_len = 
_lib.sqlite3_column_bytes(self.__statement._statement, i)
+                    val = _BLOB_TYPE(_ffi.buffer(blob, blob_len))
+            row.append(val)
+        return tuple(row)
+
     def __execute(self, multiple, sql, many_params):
         self.__locked = True
+        self._reset = False
         try:
             del self.__next_row
         except AttributeError:
             pass
         try:
-            self._reset = False
             if not isinstance(sql, basestring):
                 raise ValueError("operation parameter must be str or unicode")
             self.__description = None
             self.__rowcount = -1
-            self.__statement = self.__connection._statement_cache.get(
-                sql, self.row_factory)
+            self.__statement = self.__connection._statement_cache.get(sql)
 
             if self.__connection._isolation_level is not None:
                 if self.__statement._kind == Statement._DDL:
@@ -920,8 +995,8 @@
                     self.__statement._reset()
 
                 if ret == _lib.SQLITE_ROW:
-                    self.__statement._build_row_cast_map()
-                    self.__next_row = self.__statement._readahead(self)
+                    self.__build_row_cast_map()
+                    self.__next_row = self.__fetch_one_row()
 
                 if self.__statement._kind == Statement._DML:
                     if self.__rowcount == -1:
@@ -982,12 +1057,6 @@
                 break
         return self
 
-    def __check_reset(self):
-        if self._reset:
-            raise self.__connection.InterfaceError(
-                    "Cursor needed to be reset because of commit/rollback "
-                    "and can no longer be fetched from.")
-
     def __iter__(self):
         return self
 
@@ -1005,12 +1074,15 @@
             raise StopIteration
         del self.__next_row
 
+        if self.row_factory is not None:
+            next_row = self.row_factory(self, next_row)
+
         ret = _lib.sqlite3_step(self.__statement._statement)
         if ret not in (_lib.SQLITE_DONE, _lib.SQLITE_ROW):
             self.__statement._reset()
             raise self.__connection._get_exception(ret)
         elif ret == _lib.SQLITE_ROW:
-            self.__next_row = self.__statement._readahead(self)
+            self.__next_row = self.__fetch_one_row()
         return next_row
 
     if sys.version_info[0] < 3:
@@ -1068,7 +1140,6 @@
         self.__con._remember_statement(self)
 
         self._in_use = False
-        self._row_factory = None
 
         if not isinstance(sql, basestring):
             raise Warning("SQL is of wrong type. Must be string or unicode.")
@@ -1205,81 +1276,6 @@
         else:
             raise ValueError("parameters are of unsupported type")
 
-    def _build_row_cast_map(self):
-        if not self.__con._detect_types:
-            return
-        self.__row_cast_map = []
-        for i in xrange(_lib.sqlite3_column_count(self._statement)):
-            converter = None
-
-            if self.__con._detect_types & PARSE_COLNAMES:
-                colname = _lib.sqlite3_column_name(self._statement, i)
-                if colname:
-                    colname = _ffi.string(colname).decode('utf-8')
-                    type_start = -1
-                    key = None
-                    for pos in range(len(colname)):
-                        if colname[pos] == '[':
-                            type_start = pos + 1
-                        elif colname[pos] == ']' and type_start != -1:
-                            key = colname[type_start:pos]
-                            converter = converters[key.upper()]
-
-            if converter is None and self.__con._detect_types & 
PARSE_DECLTYPES:
-                decltype = _lib.sqlite3_column_decltype(self._statement, i)
-                if decltype:
-                    decltype = _ffi.string(decltype).decode('utf-8')
-                    # if multiple words, use first, eg.
-                    # "INTEGER NOT NULL" => "INTEGER"
-                    decltype = decltype.split()[0]
-                    if '(' in decltype:
-                        decltype = decltype[:decltype.index('(')]
-                    converter = converters.get(decltype.upper(), None)
-
-            self.__row_cast_map.append(converter)
-
-    def _readahead(self, cursor):
-        row = []
-        num_cols = _lib.sqlite3_data_count(self._statement)
-        for i in xrange(num_cols):
-            if self.__con._detect_types:
-                converter = self.__row_cast_map[i]
-            else:
-                converter = None
-
-            if converter is not None:
-                blob = _lib.sqlite3_column_blob(self._statement, i)
-                if not blob:
-                    val = None
-                else:
-                    blob_len = _lib.sqlite3_column_bytes(self._statement, i)
-                    val = _ffi.buffer(blob, blob_len)[:]
-                    val = converter(val)
-            else:
-                typ = _lib.sqlite3_column_type(self._statement, i)
-                if typ == _lib.SQLITE_NULL:
-                    val = None
-                elif typ == _lib.SQLITE_INTEGER:
-                    val = _lib.sqlite3_column_int64(self._statement, i)
-                    val = int(val)
-                elif typ == _lib.SQLITE_FLOAT:
-                    val = _lib.sqlite3_column_double(self._statement, i)
-                elif typ == _lib.SQLITE_TEXT:
-                    text = _lib.sqlite3_column_text(self._statement, i)
-                    text_len = _lib.sqlite3_column_bytes(self._statement, i)
-                    val = _ffi.buffer(text, text_len)[:]
-                    val = self.__con.text_factory(val)
-                elif typ == _lib.SQLITE_BLOB:
-                    blob = _lib.sqlite3_column_blob(self._statement, i)
-                    blob_len = _lib.sqlite3_column_bytes(self._statement, i)
-                    val = _BLOB_TYPE(_ffi.buffer(blob, blob_len))
-            row.append(val)
-
-        row = tuple(row)
-        if self._row_factory is not None:
-            row = self._row_factory(cursor, row)
-        return row
-
     def _get_description(self):
         if self._kind == Statement._DML:
             return None
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to