Author: Brian Kearns <bdkea...@gmail.com>
Branch: py3k
Changeset: r62192:f7086f05b2e3
Date: 2013-03-07 17:43 -0500
http://bitbucket.org/pypy/pypy/changeset/f7086f05b2e3/

Log:    merge default

diff --git a/lib_pypy/_sqlite3.py b/lib_pypy/_sqlite3.py
--- a/lib_pypy/_sqlite3.py
+++ b/lib_pypy/_sqlite3.py
@@ -34,6 +34,16 @@
 import weakref
 from threading import _get_ident as _thread_get_ident
 
+if sys.version_info[0] >= 3:
+    StandardError = Exception
+    long = int
+    xrange = range
+    basestring = unicode = str
+    buffer = memoryview
+    BLOB_TYPE = bytes
+else:
+    BLOB_TYPE = buffer
+
 names = "sqlite3.dll libsqlite3.so.0 libsqlite3.so libsqlite3.dylib".split()
 for name in names:
     try:
@@ -243,12 +253,12 @@
 ##########################################
 
 # SQLite version information
-sqlite_version = sqlite.sqlite3_libversion().decode('ascii')
+sqlite_version = str(sqlite.sqlite3_libversion().decode('ascii'))
 
-class Error(Exception):
+class Error(StandardError):
     pass
 
-class Warning(Exception):
+class Warning(StandardError):
     pass
 
 class InterfaceError(Error):
@@ -280,7 +290,17 @@
     return factory(database, **kwargs)
 
 def unicode_text_factory(x):
-    return str(x, 'utf-8')
+    return unicode(x, 'utf-8')
+
+if sys.version_info[0] < 3:
+    def OptimizedUnicode(s):
+        try:
+            val = unicode(s, "ascii").encode("ascii")
+        except UnicodeDecodeError:
+            val = unicode(s, "utf-8")
+        return val
+else:
+    OptimizedUnicode = unicode_text_factory
 
 
 class _StatementCache(object):
@@ -313,7 +333,8 @@
         self.__initialized = True
         self._db = c_void_p()
 
-        database = database.encode('utf-8')
+        if isinstance(database, unicode):
+            database = database.encode('utf-8')
         if sqlite.sqlite3_open(database, byref(self._db)) != SQLITE_OK:
             raise OperationalError("Could not open database")
         if timeout is not None:
@@ -439,7 +460,7 @@
     @_check_thread_wrap
     @_check_closed_wrap
     def __call__(self, sql):
-        if not isinstance(sql, str):
+        if not isinstance(sql, basestring):
             raise Warning("SQL is of wrong type. Must be string or unicode.")
         return self._statement_cache.get(sql, self.row_factory)
 
@@ -556,7 +577,8 @@
             c_closure = _FUNC(closure)
             self.__func_cache[callback] = c_closure, closure
 
-        name = name.encode('utf-8')
+        if isinstance(name, unicode):
+            name = name.encode('utf-8')
         ret = sqlite.sqlite3_create_function(self._db, name, num_args,
                                              SQLITE_UTF8, None,
                                              c_closure,
@@ -624,7 +646,8 @@
             self.__aggregates[cls] = (c_step_callback, c_final_callback,
                                      step_callback, final_callback)
 
-        name = name.encode('utf-8')
+        if isinstance(name, unicode):
+            name = name.encode('utf-8')
         ret = sqlite.sqlite3_create_function(self._db, name, num_args,
                                              SQLITE_UTF8, None,
                                              cast(None, _FUNC),
@@ -656,7 +679,8 @@
             c_collation_callback = _COLLATION(collation_callback)
             self.__collations[name] = c_collation_callback
 
-        name = name.encode('utf-8')
+        if isinstance(name, unicode):
+            name = name.encode('utf-8')
         ret = sqlite.sqlite3_create_collation(self._db, name,
                                               SQLITE_UTF8,
                                               None,
@@ -710,9 +734,10 @@
         if ret != SQLITE_OK:
             raise self._get_exception(ret)
 
-    def __get_in_transaction(self):
-        return self._in_transaction
-    in_transaction = property(__get_in_transaction)
+    if sys.version_info[0] >= 3:
+        def __get_in_transaction(self):
+            return self._in_transaction
+        in_transaction = property(__get_in_transaction)
 
     def __get_total_changes(self):
         self._check_closed()
@@ -726,7 +751,7 @@
         if val is None:
             self.commit()
         else:
-            self.__begin_statement = b"BEGIN " + val.encode('utf-8')
+            self.__begin_statement = str("BEGIN " + val).encode('utf-8')
         self._isolation_level = val
     isolation_level = property(__get_isolation_level, __set_isolation_level)
 
@@ -800,7 +825,7 @@
         try:
             self.__description = None
             self._reset = False
-            if not isinstance(sql, str):
+            if not isinstance(sql, basestring):
                 raise ValueError("operation parameter must be str or unicode")
             self.__statement = self.__connection._statement_cache.get(
                 sql, self.row_factory)
@@ -847,7 +872,7 @@
         try:
             self.__description = None
             self._reset = False
-            if not isinstance(sql, str):
+            if not isinstance(sql, basestring):
                 raise ValueError("operation parameter must be str or unicode")
             self.__statement = self.__connection._statement_cache.get(
                 sql, self.row_factory)
@@ -880,9 +905,9 @@
         self._reset = False
         self.__check_cursor()
         statement = c_void_p()
-        if isinstance(sql, str):
+        if isinstance(sql, unicode):
             sql = sql.encode('utf-8')
-        else:
+        elif not isinstance(sql, str):
             raise ValueError("script argument must be unicode or string.")
         c_sql = c_char_p(sql)
 
@@ -982,7 +1007,7 @@
     def __init__(self, connection, sql):
         self.__con = connection
 
-        if not isinstance(sql, str):
+        if not isinstance(sql, basestring):
             raise ValueError("sql must be a string")
         first_word = self._statement_kind = sql.lstrip().split(" ")[0].upper()
         if first_word in ("INSERT", "UPDATE", "DELETE", "REPLACE"):
@@ -998,7 +1023,8 @@
 
         self._statement = c_void_p()
         next_char = c_char_p()
-        sql = sql.encode('utf-8')
+        if isinstance(sql, unicode):
+            sql = sql.encode('utf-8')
 
         ret = sqlite.sqlite3_prepare_v2(self.__con._db, sql, -1, byref(self._statement), byref(next_char))
         if ret == SQLITE_OK and self._statement.value is None:
@@ -1032,7 +1058,7 @@
 
     def _build_row_cast_map(self):
         self.__row_cast_map = []
-        for i in range(sqlite.sqlite3_column_count(self._statement)):
+        for i in xrange(sqlite.sqlite3_column_count(self._statement)):
             converter = None
 
             if self.__con._detect_types & PARSE_COLNAMES:
@@ -1059,6 +1085,19 @@
 
             self.__row_cast_map.append(converter)
 
+    if sys.version_info[0] < 3:
+        def __check_decodable(self, param):
+            if self.__con.text_factory in (unicode, OptimizedUnicode,
+                                           unicode_text_factory):
+                for c in param:
+                    if ord(c) & 0x80 != 0:
+                        raise self.__con.ProgrammingError(
+                            "You must not use 8-bit bytestrings unless "
+                            "you use a text_factory that can interpret "
+                            "8-bit bytestrings (like text_factory = str). "
+                            "It is highly recommended that you instead "
+                            "just switch your application to Unicode strings.")
+
     def __set_param(self, idx, param):
         cvt = converters.get(type(param))
         if cvt is not None:
@@ -1068,17 +1107,20 @@
 
         if param is None:
             rc = sqlite.sqlite3_bind_null(self._statement, idx)
-        elif type(param) in (bool, int):
+        elif isinstance(param, (bool, int, long)):
             if -2147483648 <= param <= 2147483647:
                 rc = sqlite.sqlite3_bind_int(self._statement, idx, param)
             else:
                 rc = sqlite.sqlite3_bind_int64(self._statement, idx, param)
-        elif type(param) is float:
+        elif isinstance(param, float):
             rc = sqlite.sqlite3_bind_double(self._statement, idx, param)
-        elif isinstance(param, str):
+        elif isinstance(param, unicode):
             param = param.encode("utf-8")
             rc = sqlite.sqlite3_bind_text(self._statement, idx, param, len(param), SQLITE_TRANSIENT)
-        elif type(param) in (bytes, memoryview):
+        elif isinstance(param, str):
+            self.__check_decodable(param)
+            rc = sqlite.sqlite3_bind_text(self._statement, idx, param, len(param), SQLITE_TRANSIENT)
+        elif isinstance(param, (buffer, bytes)):
             param = bytes(param)
             rc = sqlite.sqlite3_bind_blob(self._statement, idx, param, len(param), SQLITE_TRANSIENT)
         else:
@@ -1147,28 +1189,26 @@
     def _readahead(self, cursor):
         self.column_count = sqlite.sqlite3_column_count(self._statement)
         row = []
-        for i in range(self.column_count):
+        for i in xrange(self.column_count):
             typ = sqlite.sqlite3_column_type(self._statement, i)
 
             converter = self.__row_cast_map[i]
             if converter is None:
-                if typ == SQLITE_INTEGER:
+                if typ == SQLITE_NULL:
+                    val = None
+                elif typ == SQLITE_INTEGER:
                     val = sqlite.sqlite3_column_int64(self._statement, i)
-                    if -sys.maxsize-1 <= val <= sys.maxsize:
-                        val = int(val)
                 elif typ == SQLITE_FLOAT:
                     val = sqlite.sqlite3_column_double(self._statement, i)
-                elif typ == SQLITE_BLOB:
-                    blob = sqlite.sqlite3_column_blob(self._statement, i)
-                    blob_len = sqlite.sqlite3_column_bytes(self._statement, i)
-                    val = bytes(string_at(blob, blob_len))
-                elif typ == SQLITE_NULL:
-                    val = None
                 elif typ == SQLITE_TEXT:
                     text = sqlite.sqlite3_column_text(self._statement, i)
                     text_len = sqlite.sqlite3_column_bytes(self._statement, i)
                     val = string_at(text, text_len)
                     val = self.__con.text_factory(val)
+                elif typ == SQLITE_BLOB:
+                    blob = sqlite.sqlite3_column_blob(self._statement, i)
+                    blob_len = sqlite.sqlite3_column_bytes(self._statement, i)
+                    val = BLOB_TYPE(string_at(blob, blob_len))
             else:
                 blob = sqlite.sqlite3_column_blob(self._statement, i)
                 if not blob:
@@ -1188,7 +1228,7 @@
         if self._kind == Statement._DML:
             return None
         desc = []
-        for i in range(sqlite.sqlite3_column_count(self._statement)):
+        for i in xrange(sqlite.sqlite3_column_count(self._statement)):
             name = sqlite.sqlite3_column_name(self._statement, i)
             if name is not None:
                 name = name.decode('utf-8').split("[")[0].strip()
@@ -1277,21 +1317,19 @@
     _params = []
     for i in range(nargs):
         typ = sqlite.sqlite3_value_type(params[i])
-        if typ == SQLITE_INTEGER:
+        if typ == SQLITE_NULL:
+            val = None
+        elif typ == SQLITE_INTEGER:
             val = sqlite.sqlite3_value_int64(params[i])
-            if -sys.maxsize-1 <= val <= sys.maxsize:
-                val = int(val)
         elif typ == SQLITE_FLOAT:
             val = sqlite.sqlite3_value_double(params[i])
+        elif typ == SQLITE_TEXT:
+            val = sqlite.sqlite3_value_text(params[i])
+            val = val.decode('utf-8')
         elif typ == SQLITE_BLOB:
             blob = sqlite.sqlite3_value_blob(params[i])
             blob_len = sqlite.sqlite3_value_bytes(params[i])
-            val = bytes(string_at(blob, blob_len))
-        elif typ == SQLITE_NULL:
-            val = None
-        elif typ == SQLITE_TEXT:
-            val = sqlite.sqlite3_value_text(params[i])
-            val = val.decode('utf-8')
+            val = BLOB_TYPE(string_at(blob, blob_len))
         else:
             raise NotImplementedError
         _params.append(val)
@@ -1301,14 +1339,16 @@
 def _convert_result(con, val):
     if val is None:
         sqlite.sqlite3_result_null(con)
-    elif isinstance(val, (bool, int)):
+    elif isinstance(val, (bool, int, long)):
         sqlite.sqlite3_result_int64(con, int(val))
-    elif isinstance(val, str):
+    elif isinstance(val, float):
+        sqlite.sqlite3_result_double(con, val)
+    elif isinstance(val, unicode):
         val = val.encode('utf-8')
         sqlite.sqlite3_result_text(con, val, len(val), SQLITE_TRANSIENT)
-    elif isinstance(val, float):
-        sqlite.sqlite3_result_double(con, val)
-    elif isinstance(val, (bytes, memoryview)):
+    elif isinstance(val, str):
+        sqlite.sqlite3_result_text(con, val, len(val), SQLITE_TRANSIENT)
+    elif isinstance(val, (buffer, bytes)):
         sqlite.sqlite3_result_blob(con, bytes(val), len(val), SQLITE_TRANSIENT)
     else:
         raise NotImplementedError
@@ -1380,8 +1420,8 @@
             microseconds = int(timepart_full[1])
         else:
             microseconds = 0
-        return datetime.datetime(year, month, day,
-                                 hours, minutes, seconds, microseconds)
+        return datetime.datetime(year, month, day, hours, minutes, seconds,
+                                 microseconds)
 
     register_adapter(datetime.date, adapt_date)
     register_adapter(datetime.datetime, adapt_datetime)
@@ -1418,6 +1458,3 @@
     return val
 
 register_adapters_and_converters()
-
-
-OptimizedUnicode = unicode_text_factory
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to