Author: Brian Kearns <bdkea...@gmail.com> Branch: py3k Changeset: r62192:f7086f05b2e3 Date: 2013-03-07 17:43 -0500 http://bitbucket.org/pypy/pypy/changeset/f7086f05b2e3/
Log: merge default diff --git a/lib_pypy/_sqlite3.py b/lib_pypy/_sqlite3.py --- a/lib_pypy/_sqlite3.py +++ b/lib_pypy/_sqlite3.py @@ -34,6 +34,16 @@ import weakref from threading import _get_ident as _thread_get_ident +if sys.version_info[0] >= 3: + StandardError = Exception + long = int + xrange = range + basestring = unicode = str + buffer = memoryview + BLOB_TYPE = bytes +else: + BLOB_TYPE = buffer + names = "sqlite3.dll libsqlite3.so.0 libsqlite3.so libsqlite3.dylib".split() for name in names: try: @@ -243,12 +253,12 @@ ########################################## # SQLite version information -sqlite_version = sqlite.sqlite3_libversion().decode('ascii') +sqlite_version = str(sqlite.sqlite3_libversion().decode('ascii')) -class Error(Exception): +class Error(StandardError): pass -class Warning(Exception): +class Warning(StandardError): pass class InterfaceError(Error): @@ -280,7 +290,17 @@ return factory(database, **kwargs) def unicode_text_factory(x): - return str(x, 'utf-8') + return unicode(x, 'utf-8') + +if sys.version_info[0] < 3: + def OptimizedUnicode(s): + try: + val = unicode(s, "ascii").encode("ascii") + except UnicodeDecodeError: + val = unicode(s, "utf-8") + return val +else: + OptimizedUnicode = unicode_text_factory class _StatementCache(object): @@ -313,7 +333,8 @@ self.__initialized = True self._db = c_void_p() - database = database.encode('utf-8') + if isinstance(database, unicode): + database = database.encode('utf-8') if sqlite.sqlite3_open(database, byref(self._db)) != SQLITE_OK: raise OperationalError("Could not open database") if timeout is not None: @@ -439,7 +460,7 @@ @_check_thread_wrap @_check_closed_wrap def __call__(self, sql): - if not isinstance(sql, str): + if not isinstance(sql, basestring): raise Warning("SQL is of wrong type. Must be string or unicode.") return self._statement_cache.get(sql, self.row_factory) @@ -556,7 +577,8 @@ c_closure = _FUNC(closure) self.__func_cache[callback] = c_closure, closure - name = name.encode('utf-8') + if isinstance(name, unicode): + name = name.encode('utf-8') ret = sqlite.sqlite3_create_function(self._db, name, num_args, SQLITE_UTF8, None, c_closure, @@ -624,7 +646,8 @@ self.__aggregates[cls] = (c_step_callback, c_final_callback, step_callback, final_callback) - name = name.encode('utf-8') + if isinstance(name, unicode): + name = name.encode('utf-8') ret = sqlite.sqlite3_create_function(self._db, name, num_args, SQLITE_UTF8, None, cast(None, _FUNC), @@ -656,7 +679,8 @@ c_collation_callback = _COLLATION(collation_callback) self.__collations[name] = c_collation_callback - name = name.encode('utf-8') + if isinstance(name, unicode): + name = name.encode('utf-8') ret = sqlite.sqlite3_create_collation(self._db, name, SQLITE_UTF8, None, @@ -710,9 +734,10 @@ if ret != SQLITE_OK: raise self._get_exception(ret) - def __get_in_transaction(self): - return self._in_transaction - in_transaction = property(__get_in_transaction) + if sys.version_info[0] >= 3: + def __get_in_transaction(self): + return self._in_transaction + in_transaction = property(__get_in_transaction) def __get_total_changes(self): self._check_closed() @@ -726,7 +751,7 @@ if val is None: self.commit() else: - self.__begin_statement = b"BEGIN " + val.encode('utf-8') + self.__begin_statement = str("BEGIN " + val).encode('utf-8') self._isolation_level = val isolation_level = property(__get_isolation_level, __set_isolation_level) @@ -800,7 +825,7 @@ try: self.__description = None self._reset = False - if not isinstance(sql, str): + if not isinstance(sql, basestring): raise ValueError("operation parameter must be str or unicode") self.__statement = self.__connection._statement_cache.get( sql, self.row_factory) @@ -847,7 +872,7 @@ try: self.__description = None self._reset = False - if not isinstance(sql, str): + if not isinstance(sql, basestring): raise ValueError("operation parameter must be str or unicode") self.__statement = self.__connection._statement_cache.get( sql, self.row_factory) @@ -880,9 +905,9 @@ self._reset = False self.__check_cursor() statement = c_void_p() - if isinstance(sql, str): + if isinstance(sql, unicode): sql = sql.encode('utf-8') - else: + elif not isinstance(sql, str): raise ValueError("script argument must be unicode or string.") c_sql = c_char_p(sql) @@ -982,7 +1007,7 @@ def __init__(self, connection, sql): self.__con = connection - if not isinstance(sql, str): + if not isinstance(sql, basestring): raise ValueError("sql must be a string") first_word = self._statement_kind = sql.lstrip().split(" ")[0].upper() if first_word in ("INSERT", "UPDATE", "DELETE", "REPLACE"): @@ -998,7 +1023,8 @@ self._statement = c_void_p() next_char = c_char_p() - sql = sql.encode('utf-8') + if isinstance(sql, unicode): + sql = sql.encode('utf-8') ret = sqlite.sqlite3_prepare_v2(self.__con._db, sql, -1, byref(self._statement), byref(next_char)) if ret == SQLITE_OK and self._statement.value is None: @@ -1032,7 +1058,7 @@ def _build_row_cast_map(self): self.__row_cast_map = [] - for i in range(sqlite.sqlite3_column_count(self._statement)): + for i in xrange(sqlite.sqlite3_column_count(self._statement)): converter = None if self.__con._detect_types & PARSE_COLNAMES: @@ -1059,6 +1085,19 @@ self.__row_cast_map.append(converter) + if sys.version_info[0] < 3: + def __check_decodable(self, param): + if self.__con.text_factory in (unicode, OptimizedUnicode, + unicode_text_factory): + for c in param: + if ord(c) & 0x80 != 0: + raise self.__con.ProgrammingError( + "You must not use 8-bit bytestrings unless " + "you use a text_factory that can interpret " + "8-bit bytestrings (like text_factory = str). " + "It is highly recommended that you instead " + "just switch your application to Unicode strings.") + def __set_param(self, idx, param): cvt = converters.get(type(param)) if cvt is not None: @@ -1068,17 +1107,20 @@ if param is None: rc = sqlite.sqlite3_bind_null(self._statement, idx) - elif type(param) in (bool, int): + elif isinstance(param, (bool, int, long)): if -2147483648 <= param <= 2147483647: rc = sqlite.sqlite3_bind_int(self._statement, idx, param) else: rc = sqlite.sqlite3_bind_int64(self._statement, idx, param) - elif type(param) is float: + elif isinstance(param, float): rc = sqlite.sqlite3_bind_double(self._statement, idx, param) - elif isinstance(param, str): + elif isinstance(param, unicode): param = param.encode("utf-8") rc = sqlite.sqlite3_bind_text(self._statement, idx, param, len(param), SQLITE_TRANSIENT) - elif type(param) in (bytes, memoryview): + elif isinstance(param, str): + self.__check_decodable(param) + rc = sqlite.sqlite3_bind_text(self._statement, idx, param, len(param), SQLITE_TRANSIENT) + elif isinstance(param, (buffer, bytes)): param = bytes(param) rc = sqlite.sqlite3_bind_blob(self._statement, idx, param, len(param), SQLITE_TRANSIENT) else: @@ -1147,28 +1189,26 @@ def _readahead(self, cursor): self.column_count = sqlite.sqlite3_column_count(self._statement) row = [] - for i in range(self.column_count): + for i in xrange(self.column_count): typ = sqlite.sqlite3_column_type(self._statement, i) converter = self.__row_cast_map[i] if converter is None: - if typ == SQLITE_INTEGER: + if typ == SQLITE_NULL: + val = None + elif typ == SQLITE_INTEGER: val = sqlite.sqlite3_column_int64(self._statement, i) - if -sys.maxsize-1 <= val <= sys.maxsize: - val = int(val) elif typ == SQLITE_FLOAT: val = sqlite.sqlite3_column_double(self._statement, i) - elif typ == SQLITE_BLOB: - blob = sqlite.sqlite3_column_blob(self._statement, i) - blob_len = sqlite.sqlite3_column_bytes(self._statement, i) - val = bytes(string_at(blob, blob_len)) - elif typ == SQLITE_NULL: - val = None elif typ == SQLITE_TEXT: text = sqlite.sqlite3_column_text(self._statement, i) text_len = sqlite.sqlite3_column_bytes(self._statement, i) val = string_at(text, text_len) val = self.__con.text_factory(val) + elif typ == SQLITE_BLOB: + blob = sqlite.sqlite3_column_blob(self._statement, i) + blob_len = sqlite.sqlite3_column_bytes(self._statement, i) + val = BLOB_TYPE(string_at(blob, blob_len)) else: blob = sqlite.sqlite3_column_blob(self._statement, i) if not blob: @@ -1188,7 +1228,7 @@ if self._kind == Statement._DML: return None desc = [] - for i in range(sqlite.sqlite3_column_count(self._statement)): + for i in xrange(sqlite.sqlite3_column_count(self._statement)): name = sqlite.sqlite3_column_name(self._statement, i) if name is not None: name = name.decode('utf-8').split("[")[0].strip() @@ -1277,21 +1317,19 @@ _params = [] for i in range(nargs): typ = sqlite.sqlite3_value_type(params[i]) - if typ == SQLITE_INTEGER: + if typ == SQLITE_NULL: + val = None + elif typ == SQLITE_INTEGER: val = sqlite.sqlite3_value_int64(params[i]) - if -sys.maxsize-1 <= val <= sys.maxsize: - val = int(val) elif typ == SQLITE_FLOAT: val = sqlite.sqlite3_value_double(params[i]) + elif typ == SQLITE_TEXT: + val = sqlite.sqlite3_value_text(params[i]) + val = val.decode('utf-8') elif typ == SQLITE_BLOB: blob = sqlite.sqlite3_value_blob(params[i]) blob_len = sqlite.sqlite3_value_bytes(params[i]) - val = bytes(string_at(blob, blob_len)) - elif typ == SQLITE_NULL: - val = None - elif typ == SQLITE_TEXT: - val = sqlite.sqlite3_value_text(params[i]) - val = val.decode('utf-8') + val = BLOB_TYPE(string_at(blob, blob_len)) else: raise NotImplementedError _params.append(val) @@ -1301,14 +1339,16 @@ def _convert_result(con, val): if val is None: sqlite.sqlite3_result_null(con) - elif isinstance(val, (bool, int)): + elif isinstance(val, (bool, int, long)): sqlite.sqlite3_result_int64(con, int(val)) - elif isinstance(val, str): + elif isinstance(val, float): + sqlite.sqlite3_result_double(con, val) + elif isinstance(val, unicode): val = val.encode('utf-8') sqlite.sqlite3_result_text(con, val, len(val), SQLITE_TRANSIENT) - elif isinstance(val, float): - sqlite.sqlite3_result_double(con, val) - elif isinstance(val, (bytes, memoryview)): + elif isinstance(val, str): + sqlite.sqlite3_result_text(con, val, len(val), SQLITE_TRANSIENT) + elif isinstance(val, (buffer, bytes)): sqlite.sqlite3_result_blob(con, bytes(val), len(val), SQLITE_TRANSIENT) else: raise NotImplementedError @@ -1380,8 +1420,8 @@ microseconds = int(timepart_full[1]) else: microseconds = 0 - return datetime.datetime(year, month, day, - hours, minutes, seconds, microseconds) + return datetime.datetime(year, month, day, hours, minutes, seconds, + microseconds) register_adapter(datetime.date, adapt_date) register_adapter(datetime.datetime, adapt_datetime) @@ -1418,6 +1458,3 @@ return val register_adapters_and_converters() - - -OptimizedUnicode = unicode_text_factory _______________________________________________ pypy-commit mailing list pypy-commit@python.org http://mail.python.org/mailman/listinfo/pypy-commit