Author: Antonio Cuni <[email protected]>
Branch: 
Changeset: r46988:72321d3c8b69
Date: 2011-09-01 18:04 +0200
http://bitbucket.org/pypy/pypy/changeset/72321d3c8b69/

Log:    implement a cache of statements similar to the one which is in the
        CPython C extension. It makes some trivial benchmarks twice as fast

diff --git a/lib_pypy/_sqlite3.py b/lib_pypy/_sqlite3.py
--- a/lib_pypy/_sqlite3.py
+++ b/lib_pypy/_sqlite3.py
@@ -24,6 +24,7 @@
 from ctypes import c_void_p, c_int, c_double, c_int64, c_char_p, cdll
 from ctypes import POINTER, byref, string_at, CFUNCTYPE, cast
 from ctypes import sizeof, c_ssize_t
+from collections import OrderedDict
 import datetime
 import sys
 import time
@@ -274,6 +275,28 @@
 def unicode_text_factory(x):
     return unicode(x, 'utf-8')
 
+
+class StatementCache(object):
+    def __init__(self, connection, maxcount):
+        self.connection = connection
+        self.maxcount = maxcount
+        self.cache = OrderedDict()
+
+    def get(self, sql, cursor, row_factory):
+        try:
+            stat = self.cache[sql]
+        except KeyError:
+            stat = Statement(self.connection, sql)
+            self.cache[sql] = stat
+            if len(self.cache) > self.maxcount:
+                self.cache.popitem(0)
+        #
+        if stat.in_use:
+            stat = Statement(self.connection, sql)
+        stat.set_cursor_and_factory(cursor, row_factory)
+        return stat
+
+
 class Connection(object):
     def __init__(self, database, timeout=5.0, detect_types=0, isolation_level="",
                  check_same_thread=True, factory=None, cached_statements=100):
@@ -291,6 +314,7 @@
         self.row_factory = None
         self._isolation_level = isolation_level
         self.detect_types = detect_types
+        self.statement_cache = StatementCache(self, cached_statements)
 
         self.cursors = []
 
@@ -399,7 +423,7 @@
         cur = Cursor(self)
         if not isinstance(sql, (str, unicode)):
             raise Warning("SQL is of wrong type. Must be string or unicode.")
-        statement = Statement(cur, sql, self.row_factory)
+        statement = self.statement_cache.get(sql, cur, self.row_factory)
         return statement
 
     def _get_isolation_level(self):
@@ -708,7 +732,7 @@
         if type(sql) is unicode:
             sql = sql.encode("utf-8")
         self._check_closed()
-        self.statement = Statement(self, sql, self.row_factory)
+        self.statement = self.connection.statement_cache.get(sql, self, self.row_factory)
 
         if self.connection._isolation_level is not None:
             if self.statement.kind == "DDL":
@@ -746,7 +770,8 @@
         if type(sql) is unicode:
             sql = sql.encode("utf-8")
         self._check_closed()
-        self.statement = Statement(self, sql, self.row_factory)
+        self.statement = self.connection.statement_cache.get(sql, self, self.row_factory)
+        
         if self.statement.kind == "DML":
             self.connection._begin()
         else:
@@ -871,14 +896,12 @@
     lastrowid = property(_getlastrowid)
 
 class Statement(object):
-    def __init__(self, cur, sql, row_factory):
+    def __init__(self, connection, sql):
         self.statement = None
         if not isinstance(sql, str):
             raise ValueError, "sql must be a string"
-        self.con = cur.connection
-        self.cur = weakref.ref(cur)
+        self.con = connection
         self.sql = sql # DEBUG ONLY
-        self.row_factory = row_factory
         first_word = self._statement_kind = sql.lstrip().split(" ")[0].upper()
         if first_word in ("INSERT", "UPDATE", "DELETE", "REPLACE"):
             self.kind = "DML"
@@ -887,6 +910,11 @@
         else:
             self.kind = "DDL"
         self.exhausted = False
+        self.in_use = False
+        #
+        # set by set_cursor_and_factory
+        self.cur = None
+        self.row_factory = None
 
         self.statement = c_void_p()
         next_char = c_char_p()
@@ -907,6 +935,10 @@
 
         self._build_row_cast_map()
 
+    def set_cursor_and_factory(self, cur, row_factory):
+        self.cur = weakref.ref(cur)
+        self.row_factory = row_factory
+
     def _build_row_cast_map(self):
         self.row_cast_map = []
         for i in xrange(sqlite.sqlite3_column_count(self.statement)):
@@ -976,6 +1008,7 @@
         ret = sqlite.sqlite3_reset(self.statement)
         if ret != SQLITE_OK:
             raise self.con._get_exception(ret)
+        self.mark_dirty()
 
         if params is None:
             if sqlite.sqlite3_bind_parameter_count(self.statement) != 0:
@@ -1068,11 +1101,17 @@
 
     def reset(self):
         self.row_cast_map = None
-        return sqlite.sqlite3_reset(self.statement)
+        ret = sqlite.sqlite3_reset(self.statement)
+        self.in_use = False
+        return ret
 
     def finalize(self):
         sqlite.sqlite3_finalize(self.statement)
         self.statement = None
+        self.in_use = False
+
+    def mark_dirty(self):
+        self.in_use = True
 
     def __del__(self):
         sqlite.sqlite3_finalize(self.statement)
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to