Author: Brian Kearns <bdkea...@gmail.com>
Branch: py3k
Changeset: r62161:254c727895bb
Date: 2013-03-07 03:22 -0500
http://bitbucket.org/pypy/pypy/changeset/254c727895bb/

Log:    merge default

diff --git a/lib_pypy/_sqlite3.py b/lib_pypy/_sqlite3.py
--- a/lib_pypy/_sqlite3.py
+++ b/lib_pypy/_sqlite3.py
@@ -29,6 +29,7 @@
 from collections import OrderedDict
 from functools import wraps
 import datetime
+import string
 import sys
 import weakref
 from threading import _get_ident as _thread_get_ident
@@ -226,7 +227,7 @@
 sqlite.sqlite3_total_changes.argtypes = [c_void_p]
 sqlite.sqlite3_total_changes.restype = c_int
 
-sqlite.sqlite3_result_blob.argtypes = [c_void_p, c_char_p, c_int, c_void_p]
+sqlite.sqlite3_result_blob.argtypes = [c_void_p, c_void_p, c_int, c_void_p]
 sqlite.sqlite3_result_blob.restype = None
 sqlite.sqlite3_result_int64.argtypes = [c_void_p, c_int64]
 sqlite.sqlite3_result_int64.restype = None
@@ -319,6 +320,7 @@
         self.__initialized = True
         self._db = c_void_p()
 
+        database = database.encode('utf-8')
         if sqlite.sqlite3_open(database, byref(self._db)) != SQLITE_OK:
             raise OperationalError("Could not open database")
         if timeout is not None:
@@ -408,8 +410,7 @@
     def _get_exception(self, error_code=None):
         if error_code is None:
             error_code = sqlite.sqlite3_errcode(self._db)
-        error_message = sqlite.sqlite3_errmsg(self._db)
-        error_message = error_message.decode('utf-8')
+        error_message = sqlite.sqlite3_errmsg(self._db).decode('utf-8')
 
         if error_code == SQLITE_OK:
             raise ValueError("error signalled but got SQLITE_OK")
@@ -503,7 +504,7 @@
 
         statement = c_void_p()
         next_char = c_char_p()
-        ret = sqlite.sqlite3_prepare_v2(self._db, "COMMIT", -1,
+        ret = sqlite.sqlite3_prepare_v2(self._db, b"COMMIT", -1,
                                         byref(statement), next_char)
         try:
             if ret != SQLITE_OK:
@@ -533,7 +534,7 @@
 
         statement = c_void_p()
         next_char = c_char_p()
-        ret = sqlite.sqlite3_prepare_v2(self._db, "ROLLBACK", -1,
+        ret = sqlite.sqlite3_prepare_v2(self._db, b"ROLLBACK", -1,
                                         byref(statement), next_char)
         try:
             if ret != SQLITE_OK:
@@ -564,6 +565,8 @@
                 function_callback(callback, context, nargs, c_params)
             c_closure = _FUNC(closure)
             self.__func_cache[callback] = c_closure, closure
+
+        name = name.encode('utf-8')
         ret = sqlite.sqlite3_create_function(self._db, name, num_args,
                                              SQLITE_UTF8, None,
                                              c_closure,
@@ -579,7 +582,6 @@
             c_step_callback, c_final_callback, _, _ = self.__aggregates[cls]
         except KeyError:
             def step_callback(context, argc, c_params):
-
                 aggregate_ptr = cast(
                     sqlite.sqlite3_aggregate_context(
                         context, sizeof(c_ssize_t)),
@@ -589,8 +591,8 @@
                     try:
                         aggregate = cls()
                     except Exception:
-                        msg = ("user-defined aggregate's '__init__' "
-                               "method raised error")
+                        msg = (b"user-defined aggregate's '__init__' "
+                               b"method raised error")
                         sqlite.sqlite3_result_error(context, msg, len(msg))
                         return
                     aggregate_id = id(aggregate)
@@ -603,12 +605,11 @@
                 try:
                     aggregate.step(*params)
                 except Exception:
-                    msg = ("user-defined aggregate's 'step' "
-                           "method raised error")
+                    msg = (b"user-defined aggregate's 'step' "
+                           b"method raised error")
                     sqlite.sqlite3_result_error(context, msg, len(msg))
 
             def final_callback(context):
-
                 aggregate_ptr = cast(
                     sqlite.sqlite3_aggregate_context(
                         context, sizeof(c_ssize_t)),
@@ -619,8 +620,8 @@
                     try:
                         val = aggregate.finalize()
                     except Exception:
-                        msg = ("user-defined aggregate's 'finalize' "
-                               "method raised error")
+                        msg = (b"user-defined aggregate's 'finalize' "
+                               b"method raised error")
                         sqlite.sqlite3_result_error(context, msg, len(msg))
                     else:
                         _convert_result(context, val)
@@ -633,6 +634,7 @@
             self.__aggregates[cls] = (c_step_callback, c_final_callback,
                                      step_callback, final_callback)
 
+        name = name.encode('utf-8')
         ret = sqlite.sqlite3_create_function(self._db, name, num_args,
                                              SQLITE_UTF8, None,
                                              cast(None, _FUNC),
@@ -645,7 +647,7 @@
     @_check_closed_wrap
     def create_collation(self, name, callback):
         name = name.upper()
-        if not name.replace('_', '').isalnum():
+        if not all(c in string.ascii_uppercase + string.digits + '_' for c in name):
             raise ProgrammingError("invalid character in collation name")
 
         if callback is None:
@@ -656,14 +658,15 @@
                 raise TypeError("parameter must be callable")
 
             def collation_callback(context, len1, str1, len2, str2):
-                text1 = string_at(str1, len1)
-                text2 = string_at(str2, len2)
+                text1 = string_at(str1, len1).decode('utf-8')
+                text2 = string_at(str2, len2).decode('utf-8')
 
                 return callback(text1, text2)
 
             c_collation_callback = _COLLATION(collation_callback)
             self.__collations[name] = c_collation_callback
 
+        name = name.encode('utf-8')
         ret = sqlite.sqlite3_create_collation(self._db, name,
                                               SQLITE_UTF8,
                                               None,
@@ -733,7 +736,7 @@
         if val is None:
             self.commit()
         else:
-            self.__begin_statement = 'BEGIN ' + val
+            self.__begin_statement = b"BEGIN " + val.encode('utf-8')
         self._isolation_level = val
     isolation_level = property(__get_isolation_level, __set_isolation_level)
 
@@ -748,7 +751,6 @@
 
 class Cursor(object):
     __initialized = False
-    __connection = None
     __statement = None
 
     def __init__(self, con):
@@ -770,11 +772,10 @@
         self.__rowcount = -1
 
     def __del__(self):
-        if self.__connection:
-            try:
-                self.__connection._cursors.remove(weakref.ref(self))
-            except ValueError:
-                pass
+        try:
+            self.__connection._cursors.remove(weakref.ref(self))
+        except (AttributeError, ValueError):
+            pass
         if self.__statement:
             self.__statement._reset()
 
@@ -873,8 +874,8 @@
                     self.__connection._in_transaction = \
                             not sqlite.sqlite3_get_autocommit(self.__connection._db)
                     raise self.__connection._get_exception(ret)
+                self.__statement._reset()
                 self.__rowcount += sqlite.sqlite3_changes(self.__connection._db)
-            self.__statement._reset()
         finally:
             self.__locked = False
 
@@ -883,10 +884,9 @@
     def executescript(self, sql):
         self.__description = None
         self._reset = False
-        if type(sql) is str:
-            sql = sql.encode("utf-8")
         self.__check_cursor()
         statement = c_void_p()
+        sql = sql.encode('utf-8')
         c_sql = c_char_p(sql)
 
         self.__connection.commit()
@@ -1008,11 +1008,12 @@
 
         self._statement = c_void_p()
         next_char = c_char_p()
-        sql_char = sql
-        ret = sqlite.sqlite3_prepare_v2(self.__con._db, sql_char, -1, byref(self._statement), byref(next_char))
+        sql = sql.encode('utf-8')
+
+        ret = sqlite.sqlite3_prepare_v2(self.__con._db, sql, -1, byref(self._statement), byref(next_char))
         if ret == SQLITE_OK and self._statement.value is None:
             # an empty statement, we work around that, as it's the least trouble
-            ret = sqlite.sqlite3_prepare_v2(self.__con._db, "select 42", -1, byref(self._statement), byref(next_char))
+            ret = sqlite.sqlite3_prepare_v2(self.__con._db, b"select 42", -1, byref(self._statement), byref(next_char))
             self._kind = Statement._DQL
 
         if ret != SQLITE_OK:
@@ -1021,22 +1022,23 @@
         next_char = next_char.value.decode('utf-8')
         if _check_remaining_sql(next_char):
             raise Warning("One and only one statement required: %r" %
-                          (next_char,))
+                          next_char)
 
     def __del__(self):
         if self._statement:
             sqlite.sqlite3_finalize(self._statement)
 
     def _finalize(self):
-        sqlite.sqlite3_finalize(self._statement)
-        self._statement = None
+        if self._statement:
+            sqlite.sqlite3_finalize(self._statement)
+            self._statement = None
         self._in_use = False
 
     def _reset(self):
-        ret = sqlite.sqlite3_reset(self._statement)
-        self._in_use = False
+        if self._in_use and self._statement:
+            ret = sqlite.sqlite3_reset(self._statement)
+            self._in_use = False
         self._exhausted = False
-        return ret
 
     def _build_row_cast_map(self):
         self.__row_cast_map = []
@@ -1059,8 +1061,8 @@
             if converter is None and self.__con._detect_types & PARSE_DECLTYPES:
                 decltype = sqlite.sqlite3_column_decltype(self._statement, i)
                 if decltype is not None:
+                    decltype = decltype.decode('utf-8')
                     decltype = decltype.split()[0]      # if multiple words, use first, eg. "INTEGER NOT NULL" => "INTEGER"
-                    decltype = decltype.decode('utf-8')
                     if '(' in decltype:
                         decltype = decltype[:decltype.index('(')]
                     converter = converters.get(decltype.upper(), None)
@@ -1070,37 +1072,36 @@
     def __set_param(self, idx, param):
         cvt = converters.get(type(param))
         if cvt is not None:
-            cvt = param = cvt(param)
+            param = cvt(param)
 
         param = adapt(param)
 
         if param is None:
-            sqlite.sqlite3_bind_null(self._statement, idx)
+            rc = sqlite.sqlite3_bind_null(self._statement, idx)
         elif type(param) in (bool, int):
             if -2147483648 <= param <= 2147483647:
-                sqlite.sqlite3_bind_int(self._statement, idx, param)
+                rc = sqlite.sqlite3_bind_int(self._statement, idx, param)
             else:
-                sqlite.sqlite3_bind_int64(self._statement, idx, param)
+                rc = sqlite.sqlite3_bind_int64(self._statement, idx, param)
         elif type(param) is float:
-            sqlite.sqlite3_bind_double(self._statement, idx, param)
+            rc = sqlite.sqlite3_bind_double(self._statement, idx, param)
         elif isinstance(param, str):
-            param = param.encode('utf-8')
-            sqlite.sqlite3_bind_text(self._statement, idx, param, len(param), SQLITE_TRANSIENT)
+            param = param.encode("utf-8")
+            rc = sqlite.sqlite3_bind_text(self._statement, idx, param, len(param), SQLITE_TRANSIENT)
         elif type(param) in (bytes, memoryview):
             param = bytes(param)
-            sqlite.sqlite3_bind_blob(self._statement, idx, param, len(param), SQLITE_TRANSIENT)
+            rc = sqlite.sqlite3_bind_blob(self._statement, idx, param, len(param), SQLITE_TRANSIENT)
         else:
-            raise InterfaceError("parameter type %s is not supported" %
-                                 type(param))
+            rc = -1
+        return rc
 
     def _set_params(self, params):
-        ret = sqlite.sqlite3_reset(self._statement)
-        if ret != SQLITE_OK:
-            raise self.__con._get_exception(ret)
         self._in_use = True
 
         num_params_needed = sqlite.sqlite3_bind_parameter_count(self._statement)
-        if not isinstance(params, dict):
+        if isinstance(params, (tuple, list)) or \
+                not isinstance(params, dict) and \
+                hasattr(params, '__len__') and hasattr(params, '__getitem__'):
             num_params = len(params)
             if num_params != num_params_needed:
                 raise ProgrammingError("Incorrect number of bindings supplied. "
@@ -1108,25 +1109,32 @@
                                        "there are %d supplied." %
                                        (num_params_needed, num_params))
             for i in range(num_params):
-                self.__set_param(i + 1, params[i])
-        else:
+                rc = self.__set_param(i + 1, params[i])
+                if rc != SQLITE_OK:
+                    raise InterfaceError("Error binding parameter %d - "
+                                         "probably unsupported type." % i)
+        elif isinstance(params, dict):
             for i in range(1, num_params_needed + 1):
                 param_name = sqlite.sqlite3_bind_parameter_name(self._statement, i)
                 if param_name is None:
                     raise ProgrammingError("Binding %d has no name, but you "
                                            "supplied a dictionary (which has "
                                            "only names)." % i)
-                param_name = param_name[1:].decode('utf-8')
+                param_name = param_name.decode('utf-8')[1:]
                 try:
                     param = params[param_name]
                 except KeyError:
                     raise ProgrammingError("You did not supply a value for "
                                            "binding %d." % i)
-                self.__set_param(i, param)
+                rc = self.__set_param(i, param)
+                if rc != SQLITE_OK:
+                    raise InterfaceError("Error binding parameter :%s - "
+                                         "probably unsupported type." %
+                                         param_name)
+        else:
+            raise ValueError("parameters are of unsupported type")
 
     def _next(self, cursor):
-        self.__con._check_closed()
-        self.__con._check_thread()
         if self._exhausted:
             raise StopIteration
         item = self._item
@@ -1158,14 +1166,14 @@
                 elif typ == SQLITE_FLOAT:
                     val = sqlite.sqlite3_column_double(self._statement, i)
                 elif typ == SQLITE_BLOB:
+                    blob = sqlite.sqlite3_column_blob(self._statement, i)
                     blob_len = sqlite.sqlite3_column_bytes(self._statement, i)
-                    blob = sqlite.sqlite3_column_blob(self._statement, i)
                     val = bytes(string_at(blob, blob_len))
                 elif typ == SQLITE_NULL:
                     val = None
                 elif typ == SQLITE_TEXT:
+                    text = sqlite.sqlite3_column_text(self._statement, i)
                     text_len = sqlite.sqlite3_column_bytes(self._statement, i)
-                    text = sqlite.sqlite3_column_text(self._statement, i)
                     val = string_at(text, text_len)
                     val = self.__con.text_factory(val)
             else:
@@ -1174,7 +1182,7 @@
                     val = None
                 else:
                     blob_len = sqlite.sqlite3_column_bytes(self._statement, i)
-                    val = string_at(blob, blob_len)
+                    val = bytes(string_at(blob, blob_len))
                     val = converter(val)
             row.append(val)
 
@@ -1188,8 +1196,8 @@
             return None
         desc = []
         for i in range(sqlite.sqlite3_column_count(self._statement)):
-            col_name = sqlite.sqlite3_column_name(self._statement, i)
-            name = col_name.decode('utf-8').split("[")[0].strip()
+            name = sqlite.sqlite3_column_name(self._statement, i)
+            name = name.decode('utf-8').split("[")[0].strip()
             desc.append((name, None, None, None, None, None, None))
         return desc
 
@@ -1282,15 +1290,14 @@
         elif typ == SQLITE_FLOAT:
             val = sqlite.sqlite3_value_double(params[i])
         elif typ == SQLITE_BLOB:
+            blob = sqlite.sqlite3_value_blob(params[i])
             blob_len = sqlite.sqlite3_value_bytes(params[i])
-            blob = sqlite.sqlite3_value_blob(params[i])
             val = bytes(string_at(blob, blob_len))
         elif typ == SQLITE_NULL:
             val = None
         elif typ == SQLITE_TEXT:
             val = sqlite.sqlite3_value_text(params[i])
-            # XXX changed from con.text_factory
-            val = str(val, 'utf-8')
+            val = val.decode('utf-8')
         else:
             raise NotImplementedError
         _params.append(val)
@@ -1303,13 +1310,12 @@
     elif isinstance(val, (bool, int)):
         sqlite.sqlite3_result_int64(con, int(val))
     elif isinstance(val, str):
-        sqlite.sqlite3_result_text(con, val, len(val), SQLITE_TRANSIENT)
-    elif isinstance(val, bytes):
+        val = val.encode('utf-8')
         sqlite.sqlite3_result_text(con, val, len(val), SQLITE_TRANSIENT)
     elif isinstance(val, float):
         sqlite.sqlite3_result_double(con, val)
-    elif isinstance(val, buffer):
-        sqlite.sqlite3_result_blob(con, str(val), len(val), SQLITE_TRANSIENT)
+    elif isinstance(val, (bytes, memoryview)):
+        sqlite.sqlite3_result_blob(con, bytes(val), len(val), SQLITE_TRANSIENT)
     else:
         raise NotImplementedError
 
@@ -1319,7 +1325,7 @@
     try:
         val = real_cb(*params)
     except Exception:
-        msg = "user-defined function raised exception"
+        msg = b"user-defined function raised exception"
         sqlite.sqlite3_result_error(context, msg, len(msg))
     else:
         _convert_result(context, val)
diff --git a/pypy/module/test_lib_pypy/test_sqlite3.py b/pypy/module/test_lib_pypy/test_sqlite3.py
--- a/pypy/module/test_lib_pypy/test_sqlite3.py
+++ b/pypy/module/test_lib_pypy/test_sqlite3.py
@@ -126,3 +126,20 @@
         con.commit()
     except _sqlite3.OperationalError:
         pytest.fail("_sqlite3 knew nothing about the implicit ROLLBACK")
+
+def test_statement_param_checking():
+    con = _sqlite3.connect(':memory:')
+    con.execute('create table foo(x)')
+    con.execute('insert into foo(x) values (?)', [2])
+    con.execute('insert into foo(x) values (?)', (2,))
+    class seq(object):
+        def __len__(self):
+            return 1
+        def __getitem__(self, key):
+            return 2
+    con.execute('insert into foo(x) values (?)', seq())
+    with pytest.raises(_sqlite3.ProgrammingError):
+        con.execute('insert into foo(x) values (?)', {2:2})
+    with pytest.raises(ValueError) as e:
+        con.execute('insert into foo(x) values (?)', 2)
+    assert str(e.value) == 'parameters are of unsupported type'
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to