Author: Brian Kearns <bdkea...@gmail.com>
Branch: py3k
Changeset: r62161:254c727895bb
Date: 2013-03-07 03:22 -0500
http://bitbucket.org/pypy/pypy/changeset/254c727895bb/
Log: merge default diff --git a/lib_pypy/_sqlite3.py b/lib_pypy/_sqlite3.py --- a/lib_pypy/_sqlite3.py +++ b/lib_pypy/_sqlite3.py @@ -29,6 +29,7 @@ from collections import OrderedDict from functools import wraps import datetime +import string import sys import weakref from threading import _get_ident as _thread_get_ident @@ -226,7 +227,7 @@ sqlite.sqlite3_total_changes.argtypes = [c_void_p] sqlite.sqlite3_total_changes.restype = c_int -sqlite.sqlite3_result_blob.argtypes = [c_void_p, c_char_p, c_int, c_void_p] +sqlite.sqlite3_result_blob.argtypes = [c_void_p, c_void_p, c_int, c_void_p] sqlite.sqlite3_result_blob.restype = None sqlite.sqlite3_result_int64.argtypes = [c_void_p, c_int64] sqlite.sqlite3_result_int64.restype = None @@ -319,6 +320,7 @@ self.__initialized = True self._db = c_void_p() + database = database.encode('utf-8') if sqlite.sqlite3_open(database, byref(self._db)) != SQLITE_OK: raise OperationalError("Could not open database") if timeout is not None: @@ -408,8 +410,7 @@ def _get_exception(self, error_code=None): if error_code is None: error_code = sqlite.sqlite3_errcode(self._db) - error_message = sqlite.sqlite3_errmsg(self._db) - error_message = error_message.decode('utf-8') + error_message = sqlite.sqlite3_errmsg(self._db).decode('utf-8') if error_code == SQLITE_OK: raise ValueError("error signalled but got SQLITE_OK") @@ -503,7 +504,7 @@ statement = c_void_p() next_char = c_char_p() - ret = sqlite.sqlite3_prepare_v2(self._db, "COMMIT", -1, + ret = sqlite.sqlite3_prepare_v2(self._db, b"COMMIT", -1, byref(statement), next_char) try: if ret != SQLITE_OK: @@ -533,7 +534,7 @@ statement = c_void_p() next_char = c_char_p() - ret = sqlite.sqlite3_prepare_v2(self._db, "ROLLBACK", -1, + ret = sqlite.sqlite3_prepare_v2(self._db, b"ROLLBACK", -1, byref(statement), next_char) try: if ret != SQLITE_OK: @@ -564,6 +565,8 @@ function_callback(callback, context, nargs, c_params) c_closure = _FUNC(closure) self.__func_cache[callback] = c_closure, closure + + name = 
name.encode('utf-8') ret = sqlite.sqlite3_create_function(self._db, name, num_args, SQLITE_UTF8, None, c_closure, @@ -579,7 +582,6 @@ c_step_callback, c_final_callback, _, _ = self.__aggregates[cls] except KeyError: def step_callback(context, argc, c_params): - aggregate_ptr = cast( sqlite.sqlite3_aggregate_context( context, sizeof(c_ssize_t)), @@ -589,8 +591,8 @@ try: aggregate = cls() except Exception: - msg = ("user-defined aggregate's '__init__' " - "method raised error") + msg = (b"user-defined aggregate's '__init__' " + b"method raised error") sqlite.sqlite3_result_error(context, msg, len(msg)) return aggregate_id = id(aggregate) @@ -603,12 +605,11 @@ try: aggregate.step(*params) except Exception: - msg = ("user-defined aggregate's 'step' " - "method raised error") + msg = (b"user-defined aggregate's 'step' " + b"method raised error") sqlite.sqlite3_result_error(context, msg, len(msg)) def final_callback(context): - aggregate_ptr = cast( sqlite.sqlite3_aggregate_context( context, sizeof(c_ssize_t)), @@ -619,8 +620,8 @@ try: val = aggregate.finalize() except Exception: - msg = ("user-defined aggregate's 'finalize' " - "method raised error") + msg = (b"user-defined aggregate's 'finalize' " + b"method raised error") sqlite.sqlite3_result_error(context, msg, len(msg)) else: _convert_result(context, val) @@ -633,6 +634,7 @@ self.__aggregates[cls] = (c_step_callback, c_final_callback, step_callback, final_callback) + name = name.encode('utf-8') ret = sqlite.sqlite3_create_function(self._db, name, num_args, SQLITE_UTF8, None, cast(None, _FUNC), @@ -645,7 +647,7 @@ @_check_closed_wrap def create_collation(self, name, callback): name = name.upper() - if not name.replace('_', '').isalnum(): + if not all(c in string.ascii_uppercase + string.digits + '_' for c in name): raise ProgrammingError("invalid character in collation name") if callback is None: @@ -656,14 +658,15 @@ raise TypeError("parameter must be callable") def collation_callback(context, len1, str1, len2, 
str2): - text1 = string_at(str1, len1) - text2 = string_at(str2, len2) + text1 = string_at(str1, len1).decode('utf-8') + text2 = string_at(str2, len2).decode('utf-8') return callback(text1, text2) c_collation_callback = _COLLATION(collation_callback) self.__collations[name] = c_collation_callback + name = name.encode('utf-8') ret = sqlite.sqlite3_create_collation(self._db, name, SQLITE_UTF8, None, @@ -733,7 +736,7 @@ if val is None: self.commit() else: - self.__begin_statement = 'BEGIN ' + val + self.__begin_statement = b"BEGIN " + val.encode('utf-8') self._isolation_level = val isolation_level = property(__get_isolation_level, __set_isolation_level) @@ -748,7 +751,6 @@ class Cursor(object): __initialized = False - __connection = None __statement = None def __init__(self, con): @@ -770,11 +772,10 @@ self.__rowcount = -1 def __del__(self): - if self.__connection: - try: - self.__connection._cursors.remove(weakref.ref(self)) - except ValueError: - pass + try: + self.__connection._cursors.remove(weakref.ref(self)) + except (AttributeError, ValueError): + pass if self.__statement: self.__statement._reset() @@ -873,8 +874,8 @@ self.__connection._in_transaction = \ not sqlite.sqlite3_get_autocommit(self.__connection._db) raise self.__connection._get_exception(ret) + self.__statement._reset() self.__rowcount += sqlite.sqlite3_changes(self.__connection._db) - self.__statement._reset() finally: self.__locked = False @@ -883,10 +884,9 @@ def executescript(self, sql): self.__description = None self._reset = False - if type(sql) is str: - sql = sql.encode("utf-8") self.__check_cursor() statement = c_void_p() + sql = sql.encode('utf-8') c_sql = c_char_p(sql) self.__connection.commit() @@ -1008,11 +1008,12 @@ self._statement = c_void_p() next_char = c_char_p() - sql_char = sql - ret = sqlite.sqlite3_prepare_v2(self.__con._db, sql_char, -1, byref(self._statement), byref(next_char)) + sql = sql.encode('utf-8') + + ret = sqlite.sqlite3_prepare_v2(self.__con._db, sql, -1, 
byref(self._statement), byref(next_char)) if ret == SQLITE_OK and self._statement.value is None: # an empty statement, we work around that, as it's the least trouble - ret = sqlite.sqlite3_prepare_v2(self.__con._db, "select 42", -1, byref(self._statement), byref(next_char)) + ret = sqlite.sqlite3_prepare_v2(self.__con._db, b"select 42", -1, byref(self._statement), byref(next_char)) self._kind = Statement._DQL if ret != SQLITE_OK: @@ -1021,22 +1022,23 @@ next_char = next_char.value.decode('utf-8') if _check_remaining_sql(next_char): raise Warning("One and only one statement required: %r" % - (next_char,)) + next_char) def __del__(self): if self._statement: sqlite.sqlite3_finalize(self._statement) def _finalize(self): - sqlite.sqlite3_finalize(self._statement) - self._statement = None + if self._statement: + sqlite.sqlite3_finalize(self._statement) + self._statement = None self._in_use = False def _reset(self): - ret = sqlite.sqlite3_reset(self._statement) - self._in_use = False + if self._in_use and self._statement: + ret = sqlite.sqlite3_reset(self._statement) + self._in_use = False self._exhausted = False - return ret def _build_row_cast_map(self): self.__row_cast_map = [] @@ -1059,8 +1061,8 @@ if converter is None and self.__con._detect_types & PARSE_DECLTYPES: decltype = sqlite.sqlite3_column_decltype(self._statement, i) if decltype is not None: + decltype = decltype.decode('utf-8') decltype = decltype.split()[0] # if multiple words, use first, eg. 
"INTEGER NOT NULL" => "INTEGER" - decltype = decltype.decode('utf-8') if '(' in decltype: decltype = decltype[:decltype.index('(')] converter = converters.get(decltype.upper(), None) @@ -1070,37 +1072,36 @@ def __set_param(self, idx, param): cvt = converters.get(type(param)) if cvt is not None: - cvt = param = cvt(param) + param = cvt(param) param = adapt(param) if param is None: - sqlite.sqlite3_bind_null(self._statement, idx) + rc = sqlite.sqlite3_bind_null(self._statement, idx) elif type(param) in (bool, int): if -2147483648 <= param <= 2147483647: - sqlite.sqlite3_bind_int(self._statement, idx, param) + rc = sqlite.sqlite3_bind_int(self._statement, idx, param) else: - sqlite.sqlite3_bind_int64(self._statement, idx, param) + rc = sqlite.sqlite3_bind_int64(self._statement, idx, param) elif type(param) is float: - sqlite.sqlite3_bind_double(self._statement, idx, param) + rc = sqlite.sqlite3_bind_double(self._statement, idx, param) elif isinstance(param, str): - param = param.encode('utf-8') - sqlite.sqlite3_bind_text(self._statement, idx, param, len(param), SQLITE_TRANSIENT) + param = param.encode("utf-8") + rc = sqlite.sqlite3_bind_text(self._statement, idx, param, len(param), SQLITE_TRANSIENT) elif type(param) in (bytes, memoryview): param = bytes(param) - sqlite.sqlite3_bind_blob(self._statement, idx, param, len(param), SQLITE_TRANSIENT) + rc = sqlite.sqlite3_bind_blob(self._statement, idx, param, len(param), SQLITE_TRANSIENT) else: - raise InterfaceError("parameter type %s is not supported" % - type(param)) + rc = -1 + return rc def _set_params(self, params): - ret = sqlite.sqlite3_reset(self._statement) - if ret != SQLITE_OK: - raise self.__con._get_exception(ret) self._in_use = True num_params_needed = sqlite.sqlite3_bind_parameter_count(self._statement) - if not isinstance(params, dict): + if isinstance(params, (tuple, list)) or \ + not isinstance(params, dict) and \ + hasattr(params, '__len__') and hasattr(params, '__getitem__'): num_params = len(params) 
if num_params != num_params_needed: raise ProgrammingError("Incorrect number of bindings supplied. " @@ -1108,25 +1109,32 @@ "there are %d supplied." % (num_params_needed, num_params)) for i in range(num_params): - self.__set_param(i + 1, params[i]) - else: + rc = self.__set_param(i + 1, params[i]) + if rc != SQLITE_OK: + raise InterfaceError("Error binding parameter %d - " + "probably unsupported type." % i) + elif isinstance(params, dict): for i in range(1, num_params_needed + 1): param_name = sqlite.sqlite3_bind_parameter_name(self._statement, i) if param_name is None: raise ProgrammingError("Binding %d has no name, but you " "supplied a dictionary (which has " "only names)." % i) - param_name = param_name[1:].decode('utf-8') + param_name = param_name.decode('utf-8')[1:] try: param = params[param_name] except KeyError: raise ProgrammingError("You did not supply a value for " "binding %d." % i) - self.__set_param(i, param) + rc = self.__set_param(i, param) + if rc != SQLITE_OK: + raise InterfaceError("Error binding parameter :%s - " + "probably unsupported type." 
% + param_name) + else: + raise ValueError("parameters are of unsupported type") def _next(self, cursor): - self.__con._check_closed() - self.__con._check_thread() if self._exhausted: raise StopIteration item = self._item @@ -1158,14 +1166,14 @@ elif typ == SQLITE_FLOAT: val = sqlite.sqlite3_column_double(self._statement, i) elif typ == SQLITE_BLOB: + blob = sqlite.sqlite3_column_blob(self._statement, i) blob_len = sqlite.sqlite3_column_bytes(self._statement, i) - blob = sqlite.sqlite3_column_blob(self._statement, i) val = bytes(string_at(blob, blob_len)) elif typ == SQLITE_NULL: val = None elif typ == SQLITE_TEXT: + text = sqlite.sqlite3_column_text(self._statement, i) text_len = sqlite.sqlite3_column_bytes(self._statement, i) - text = sqlite.sqlite3_column_text(self._statement, i) val = string_at(text, text_len) val = self.__con.text_factory(val) else: @@ -1174,7 +1182,7 @@ val = None else: blob_len = sqlite.sqlite3_column_bytes(self._statement, i) - val = string_at(blob, blob_len) + val = bytes(string_at(blob, blob_len)) val = converter(val) row.append(val) @@ -1188,8 +1196,8 @@ return None desc = [] for i in range(sqlite.sqlite3_column_count(self._statement)): - col_name = sqlite.sqlite3_column_name(self._statement, i) - name = col_name.decode('utf-8').split("[")[0].strip() + name = sqlite.sqlite3_column_name(self._statement, i) + name = name.decode('utf-8').split("[")[0].strip() desc.append((name, None, None, None, None, None, None)) return desc @@ -1282,15 +1290,14 @@ elif typ == SQLITE_FLOAT: val = sqlite.sqlite3_value_double(params[i]) elif typ == SQLITE_BLOB: + blob = sqlite.sqlite3_value_blob(params[i]) blob_len = sqlite.sqlite3_value_bytes(params[i]) - blob = sqlite.sqlite3_value_blob(params[i]) val = bytes(string_at(blob, blob_len)) elif typ == SQLITE_NULL: val = None elif typ == SQLITE_TEXT: val = sqlite.sqlite3_value_text(params[i]) - # XXX changed from con.text_factory - val = str(val, 'utf-8') + val = val.decode('utf-8') else: raise 
NotImplementedError _params.append(val) @@ -1303,13 +1310,12 @@ elif isinstance(val, (bool, int)): sqlite.sqlite3_result_int64(con, int(val)) elif isinstance(val, str): - sqlite.sqlite3_result_text(con, val, len(val), SQLITE_TRANSIENT) - elif isinstance(val, bytes): + val = val.encode('utf-8') sqlite.sqlite3_result_text(con, val, len(val), SQLITE_TRANSIENT) elif isinstance(val, float): sqlite.sqlite3_result_double(con, val) - elif isinstance(val, buffer): - sqlite.sqlite3_result_blob(con, str(val), len(val), SQLITE_TRANSIENT) + elif isinstance(val, (bytes, memoryview)): + sqlite.sqlite3_result_blob(con, bytes(val), len(val), SQLITE_TRANSIENT) else: raise NotImplementedError @@ -1319,7 +1325,7 @@ try: val = real_cb(*params) except Exception: - msg = "user-defined function raised exception" + msg = b"user-defined function raised exception" sqlite.sqlite3_result_error(context, msg, len(msg)) else: _convert_result(context, val) diff --git a/pypy/module/test_lib_pypy/test_sqlite3.py b/pypy/module/test_lib_pypy/test_sqlite3.py --- a/pypy/module/test_lib_pypy/test_sqlite3.py +++ b/pypy/module/test_lib_pypy/test_sqlite3.py @@ -126,3 +126,20 @@ con.commit() except _sqlite3.OperationalError: pytest.fail("_sqlite3 knew nothing about the implicit ROLLBACK") + +def test_statement_param_checking(): + con = _sqlite3.connect(':memory:') + con.execute('create table foo(x)') + con.execute('insert into foo(x) values (?)', [2]) + con.execute('insert into foo(x) values (?)', (2,)) + class seq(object): + def __len__(self): + return 1 + def __getitem__(self, key): + return 2 + con.execute('insert into foo(x) values (?)', seq()) + with pytest.raises(_sqlite3.ProgrammingError): + con.execute('insert into foo(x) values (?)', {2:2}) + with pytest.raises(ValueError) as e: + con.execute('insert into foo(x) values (?)', 2) + assert str(e.value) == 'parameters are of unsupported type' _______________________________________________ pypy-commit mailing list pypy-commit@python.org 
http://mail.python.org/mailman/listinfo/pypy-commit