Author: Brian Kearns <bdkea...@gmail.com> Branch: py3k Changeset: r62134:36a3c5c18e36 Date: 2013-03-06 15:35 -0500 http://bitbucket.org/pypy/pypy/changeset/36a3c5c18e36/
Log: merge default diff --git a/lib_pypy/_sqlite3.py b/lib_pypy/_sqlite3.py --- a/lib_pypy/_sqlite3.py +++ b/lib_pypy/_sqlite3.py @@ -244,8 +244,6 @@ sqlite.sqlite3_enable_load_extension.argtypes = [c_void_p, c_int] sqlite.sqlite3_enable_load_extension.restype = c_int -_DML, _DQL, _DDL = range(3) - ########################################## # END Wrapped SQLite C API and constants ########################################## @@ -306,9 +304,9 @@ if len(self.cache) > self.maxcount: self.cache.popitem(0) - if stat.in_use: + if stat._in_use: stat = Statement(self.connection, sql) - stat.set_row_factory(row_factory) + stat._row_factory = row_factory return stat @@ -367,7 +365,7 @@ for statement in self.__statements: obj = statement() if obj is not None: - obj.finalize() + obj._finalize() if self._db: ret = sqlite.sqlite3_close(self._db) @@ -501,7 +499,7 @@ for statement in self.__statements: obj = statement() if obj is not None: - obj.reset() + obj._reset() statement = c_void_p() next_char = c_char_p() @@ -526,7 +524,7 @@ for statement in self.__statements: obj = statement() if obj is not None: - obj.reset() + obj._reset() for cursor_ref in self._cursors: cursor = cursor_ref() @@ -778,13 +776,13 @@ except ValueError: pass if self.__statement: - self.__statement.reset() + self.__statement._reset() def close(self): self.__connection._check_thread() self.__connection._check_closed() if self.__statement: - self.__statement.reset() + self.__statement._reset() self.__statement = None self.__closed = True @@ -798,8 +796,15 @@ self.__connection._check_thread() self.__connection._check_closed() + def __check_cursor_wrap(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + self.__check_cursor() + return func(self, *args, **kwargs) + return wrapper + + @__check_cursor_wrap def execute(self, sql, params=None): - self.__check_cursor() self.__locked = True try: self.__description = None @@ -808,43 +813,43 @@ sql, self.row_factory) if self.__connection._isolation_level is not None: - if self.__statement.kind == _DDL: + if self.__statement._kind == Statement._DDL: if self.__connection._in_transaction: self.__connection.commit() - elif self.__statement.kind == _DML: + elif self.__statement._kind == Statement._DML: if not self.__connection._in_transaction: self.__connection._begin() - self.__statement.set_params(params) + self.__statement._set_params(params) # Actually execute the SQL statement - ret = sqlite.sqlite3_step(self.__statement.statement) + ret = sqlite.sqlite3_step(self.__statement._statement) if ret not in (SQLITE_DONE, SQLITE_ROW): - self.__statement.reset() + self.__statement._reset() self.__connection._in_transaction = \ not sqlite.sqlite3_get_autocommit(self.__connection._db) raise self.__connection._get_exception(ret) - if self.__statement.kind == _DML: - self.__statement.reset() + if self.__statement._kind == Statement._DML: + self.__statement._reset() - if self.__statement.kind == _DQL and ret == SQLITE_ROW: + if self.__statement._kind == Statement._DQL and ret == SQLITE_ROW: self.__statement._build_row_cast_map() self.__statement._readahead(self) else: - self.__statement.item = None - self.__statement.exhausted = True + self.__statement._item = None + self.__statement._exhausted = True self.__rowcount = -1 - if self.__statement.kind == _DML: + if self.__statement._kind == Statement._DML: self.__rowcount = sqlite.sqlite3_changes(self.__connection._db) finally: self.__locked = False return self + @__check_cursor_wrap def executemany(self, sql, many_params): - self.__check_cursor() self.__locked = True try: self.__description = None @@ -852,7 +857,7 @@ self.__statement = self.__connection._statement_cache.get( sql, self.row_factory) - if self.__statement.kind == _DML: + if self.__statement._kind == Statement._DML: if self.__connection._isolation_level is not None: if not self.__connection._in_transaction: self.__connection._begin() @@ -861,15 +866,15 @@ self.__rowcount = 0 for params in many_params: - self.__statement.set_params(params) - ret = sqlite.sqlite3_step(self.__statement.statement) + self.__statement._set_params(params) + ret = sqlite.sqlite3_step(self.__statement._statement) if ret != SQLITE_DONE: - self.__statement.reset() + self.__statement._reset() self.__connection._in_transaction = \ not sqlite.sqlite3_get_autocommit(self.__connection._db) raise self.__connection._get_exception(ret) self.__rowcount += sqlite.sqlite3_changes(self.__connection._db) - self.__statement.reset() + self.__statement._reset() finally: self.__locked = False @@ -926,7 +931,7 @@ return None try: - return self.__statement.next(self) + return self.__statement._next(self) except StopIteration: return None @@ -980,56 +985,66 @@ class Statement(object): - statement = None + _DML, _DQL, _DDL = range(3) + + _statement = None def __init__(self, connection, sql): + self.__con = connection + if not isinstance(sql, str): raise ValueError("sql must be a string") - self.con = connection - self.sql = sql # DEBUG ONLY first_word = self._statement_kind = sql.lstrip().split(" ")[0].upper() if first_word in ("INSERT", "UPDATE", "DELETE", "REPLACE"): - self.kind = _DML + self._kind = Statement._DML elif first_word in ("SELECT", "PRAGMA"): - self.kind = _DQL + self._kind = Statement._DQL else: - self.kind = _DDL - self.exhausted = False - self.in_use = False - # - # set by set_row_factory - self.row_factory = None + self._kind = Statement._DDL - self.statement = c_void_p() + self._in_use = False + self._exhausted = False + self._row_factory = None + + self._statement = c_void_p() next_char = c_char_p() sql_char = sql - ret = sqlite.sqlite3_prepare_v2(self.con._db, sql_char, -1, byref(self.statement), byref(next_char)) - if ret == SQLITE_OK and self.statement.value is None: + ret = sqlite.sqlite3_prepare_v2(self.__con._db, sql_char, -1, byref(self._statement), byref(next_char)) + if ret == SQLITE_OK and self._statement.value is None: # an empty statement, we work around that, as it's the least trouble - ret = sqlite.sqlite3_prepare_v2(self.con._db, "select 42", -1, byref(self.statement), byref(next_char)) - self.kind = _DQL + ret = sqlite.sqlite3_prepare_v2(self.__con._db, "select 42", -1, byref(self._statement), byref(next_char)) + self._kind = Statement._DQL if ret != SQLITE_OK: - raise self.con._get_exception(ret) - self.con._remember_statement(self) + raise self.__con._get_exception(ret) + self.__con._remember_statement(self) next_char = next_char.value.decode('utf-8') if _check_remaining_sql(next_char): raise Warning("One and only one statement required: %r" % (next_char,)) - # sql_char should remain alive until here - self._build_row_cast_map() + def __del__(self): + if self._statement: + sqlite.sqlite3_finalize(self._statement) - def set_row_factory(self, row_factory): - self.row_factory = row_factory + def _finalize(self): + sqlite.sqlite3_finalize(self._statement) + self._statement = None + self._in_use = False + + def _reset(self): + ret = sqlite.sqlite3_reset(self._statement) + self._in_use = False + self._exhausted = False + return ret def _build_row_cast_map(self): - self.row_cast_map = [] - for i in range(sqlite.sqlite3_column_count(self.statement)): + self.__row_cast_map = [] + for i in range(sqlite.sqlite3_column_count(self._statement)): converter = None - if self.con._detect_types & PARSE_COLNAMES: - colname = sqlite.sqlite3_column_name(self.statement, i) + if self.__con._detect_types & PARSE_COLNAMES: + colname = sqlite.sqlite3_column_name(self._statement, i) if colname is not None: colname = colname.decode('utf-8') type_start = -1 @@ -1041,8 +1056,8 @@ key = colname[type_start:pos] converter = converters[key.upper()] - if converter is None and self.con._detect_types & PARSE_DECLTYPES: - decltype = sqlite.sqlite3_column_decltype(self.statement, i) + if converter is None and self.__con._detect_types & PARSE_DECLTYPES: + decltype = sqlite.sqlite3_column_decltype(self._statement, i) if decltype is not None: decltype = decltype.split()[0] # if multiple words, use first, eg. "INTEGER NOT NULL" => "INTEGER" decltype = decltype.decode('utf-8') @@ -1050,9 +1065,9 @@ decltype = decltype[:decltype.index('(')] converter = converters.get(decltype.upper(), None) - self.row_cast_map.append(converter) + self.__row_cast_map.append(converter) - def set_param(self, idx, param): + def __set_param(self, idx, param): cvt = converters.get(type(param)) if cvt is not None: cvt = param = cvt(param) @@ -1060,32 +1075,32 @@ param = adapt(param) if param is None: - sqlite.sqlite3_bind_null(self.statement, idx) + sqlite.sqlite3_bind_null(self._statement, idx) elif type(param) in (bool, int): if -2147483648 <= param <= 2147483647: - sqlite.sqlite3_bind_int(self.statement, idx, param) + sqlite.sqlite3_bind_int(self._statement, idx, param) else: - sqlite.sqlite3_bind_int64(self.statement, idx, param) + sqlite.sqlite3_bind_int64(self._statement, idx, param) elif type(param) is float: - sqlite.sqlite3_bind_double(self.statement, idx, param) + sqlite.sqlite3_bind_double(self._statement, idx, param) elif isinstance(param, str): param = param.encode('utf-8') - sqlite.sqlite3_bind_text(self.statement, idx, param, len(param), SQLITE_TRANSIENT) + sqlite.sqlite3_bind_text(self._statement, idx, param, len(param), SQLITE_TRANSIENT) elif type(param) in (bytes, memoryview): param = bytes(param) - sqlite.sqlite3_bind_blob(self.statement, idx, param, len(param), SQLITE_TRANSIENT) + sqlite.sqlite3_bind_blob(self._statement, idx, param, len(param), SQLITE_TRANSIENT) else: raise InterfaceError("parameter type %s is not supported" % type(param)) - def set_params(self, params): - ret = sqlite.sqlite3_reset(self.statement) + def _set_params(self, params): + ret = sqlite.sqlite3_reset(self._statement) if ret != SQLITE_OK: - raise self.con._get_exception(ret) - self.mark_dirty() + raise self.__con._get_exception(ret) + self._in_use = True if params is None: - if sqlite.sqlite3_bind_parameter_count(self.statement) != 0: + if sqlite.sqlite3_bind_parameter_count(self._statement) != 0: raise ProgrammingError("wrong number of arguments") return @@ -1096,14 +1111,14 @@ params_type = list if params_type == list: - if len(params) != sqlite.sqlite3_bind_parameter_count(self.statement): + if len(params) != sqlite.sqlite3_bind_parameter_count(self._statement): raise ProgrammingError("wrong number of arguments") for i in range(len(params)): - self.set_param(i+1, params[i]) + self.__set_param(i+1, params[i]) else: - for idx in range(1, sqlite.sqlite3_bind_parameter_count(self.statement) + 1): - param_name = sqlite.sqlite3_bind_parameter_name(self.statement, idx) + for idx in range(1, sqlite.sqlite3_bind_parameter_count(self._statement) + 1): + param_name = sqlite.sqlite3_bind_parameter_name(self._statement, idx) if param_name is None: raise ProgrammingError("need named parameters") param_name = param_name[1:].decode('utf-8') @@ -1111,92 +1126,73 @@ param = params[param_name] except KeyError: raise ProgrammingError("missing parameter %r" % param_name) - self.set_param(idx, param) + self.__set_param(idx, param) - def next(self, cursor): - self.con._check_closed() - self.con._check_thread() - if self.exhausted: + def _next(self, cursor): + self.__con._check_closed() + self.__con._check_thread() + if self._exhausted: raise StopIteration - item = self.item + item = self._item - ret = sqlite.sqlite3_step(self.statement) + ret = sqlite.sqlite3_step(self._statement) if ret == SQLITE_DONE: - self.exhausted = True - self.item = None + self._exhausted = True + self._item = None elif ret != SQLITE_ROW: - exc = self.con._get_exception(ret) - sqlite.sqlite3_reset(self.statement) + exc = self.__con._get_exception(ret) + sqlite.sqlite3_reset(self._statement) raise exc self._readahead(cursor) return item def _readahead(self, cursor): - self.column_count = sqlite.sqlite3_column_count(self.statement) + self.column_count = sqlite.sqlite3_column_count(self._statement) row = [] for i in range(self.column_count): - typ = sqlite.sqlite3_column_type(self.statement, i) + typ = sqlite.sqlite3_column_type(self._statement, i) - converter = self.row_cast_map[i] + converter = self.__row_cast_map[i] if converter is None: if typ == SQLITE_INTEGER: - val = sqlite.sqlite3_column_int64(self.statement, i) + val = sqlite.sqlite3_column_int64(self._statement, i) if -sys.maxsize-1 <= val <= sys.maxsize: val = int(val) elif typ == SQLITE_FLOAT: - val = sqlite.sqlite3_column_double(self.statement, i) + val = sqlite.sqlite3_column_double(self._statement, i) elif typ == SQLITE_BLOB: - blob_len = sqlite.sqlite3_column_bytes(self.statement, i) - blob = sqlite.sqlite3_column_blob(self.statement, i) + blob_len = sqlite.sqlite3_column_bytes(self._statement, i) + blob = sqlite.sqlite3_column_blob(self._statement, i) val = bytes(string_at(blob, blob_len)) elif typ == SQLITE_NULL: val = None elif typ == SQLITE_TEXT: - text_len = sqlite.sqlite3_column_bytes(self.statement, i) - text = sqlite.sqlite3_column_text(self.statement, i) + text_len = sqlite.sqlite3_column_bytes(self._statement, i) + text = sqlite.sqlite3_column_text(self._statement, i) val = string_at(text, text_len) - val = self.con.text_factory(val) + val = self.__con.text_factory(val) else: - blob = sqlite.sqlite3_column_blob(self.statement, i) + blob = sqlite.sqlite3_column_blob(self._statement, i) if not blob: val = None else: - blob_len = sqlite.sqlite3_column_bytes(self.statement, i) + blob_len = sqlite.sqlite3_column_bytes(self._statement, i) val = string_at(blob, blob_len) val = converter(val) row.append(val) row = tuple(row) - if self.row_factory is not None: - row = self.row_factory(cursor, row) - self.item = row - - def reset(self): - self.row_cast_map = None - ret = sqlite.sqlite3_reset(self.statement) - self.in_use = False - self.exhausted = False - return ret - - def finalize(self): - sqlite.sqlite3_finalize(self.statement) - self.statement = None - self.in_use = False - - def mark_dirty(self): - self.in_use = True - - def __del__(self): - if self.statement: - sqlite.sqlite3_finalize(self.statement) + if self._row_factory is not None: + row = self._row_factory(cursor, row) + self._item = row def _get_description(self): - if self.kind == _DML: + if self._kind == Statement._DML: return None desc = [] - for i in range(sqlite.sqlite3_column_count(self.statement)): - col_name = sqlite.sqlite3_column_name(self.statement, i) + for i in range(sqlite.sqlite3_column_count(self._statement)): + col_name = sqlite.sqlite3_column_name(self._statement, i) name = col_name.decode('utf-8').split("[")[0].strip() desc.append((name, None, None, None, None, None, None)) return desc diff --git a/pypy/module/test_lib_pypy/test_sqlite3.py b/pypy/module/test_lib_pypy/test_sqlite3.py --- a/pypy/module/test_lib_pypy/test_sqlite3.py +++ b/pypy/module/test_lib_pypy/test_sqlite3.py @@ -58,6 +58,9 @@ cur.close() con.close() pytest.raises(_sqlite3.ProgrammingError, "cur.close()") + # raises ProgrammingError because should check closed before check args + pytest.raises(_sqlite3.ProgrammingError, "cur.execute(1,2,3,4,5)") + pytest.raises(_sqlite3.ProgrammingError, "cur.executemany(1,2,3,4,5)") @pytest.mark.skipif("not hasattr(sys, 'pypy_translation_info')") def test_cursor_del(): _______________________________________________ pypy-commit mailing list pypy-commit@python.org http://mail.python.org/mailman/listinfo/pypy-commit