Author: Amaury Forgeot d'Arc <amaur...@gmail.com> Branch: stdlib-2.7.3 Changeset: r55670:6749c2482195 Date: 2012-06-14 23:13 +0200 http://bitbucket.org/pypy/pypy/changeset/6749c2482195/
Log: CPython Issue #10811: sqlite: Fix recursive usage of cursors. diff --git a/lib_pypy/_sqlite3.py b/lib_pypy/_sqlite3.py --- a/lib_pypy/_sqlite3.py +++ b/lib_pypy/_sqlite3.py @@ -722,6 +722,19 @@ DML, DQL, DDL = range(3) +class CursorLock(object): + def __init__(self, cursor): + self.cursor = cursor + + def __enter__(self): + if self.cursor.locked: + raise ProgrammingError("Recursive use of cursors not allowed.") + self.cursor.locked = True + + def __exit__(self, *args): + self.cursor.locked = False + + class Cursor(object): def __init__(self, con): if not isinstance(con, Connection): @@ -736,6 +749,7 @@ self.rowcount = -1 self.statement = None self.reset = False + self.locked = False def _check_closed(self): if not getattr(self, 'connection', None): @@ -743,64 +757,72 @@ self.connection._check_thread() self.connection._check_closed() + def _check_and_lock(self): + self._check_closed() + return CursorLock(self) + def execute(self, sql, params=None): - self._description = None - self.reset = False if type(sql) is unicode: sql = sql.encode("utf-8") - self._check_closed() - self.statement = self.connection.statement_cache.get(sql, self, self.row_factory) - if self.connection._isolation_level is not None: - if self.statement.kind == DDL: - self.connection.commit() - elif self.statement.kind == DML: - self.connection._begin() + with self._check_and_lock(): + self._description = None + self.reset = False + self.statement = self.connection.statement_cache.get( + sql, self, self.row_factory) - self.statement.set_params(params) + if self.connection._isolation_level is not None: + if self.statement.kind == DDL: + self.connection.commit() + elif self.statement.kind == DML: + self.connection._begin() - # Actually execute the SQL statement - ret = sqlite.sqlite3_step(self.statement.statement) - if ret not in (SQLITE_DONE, SQLITE_ROW): - self.statement.reset() - raise self.connection._get_exception(ret) + self.statement.set_params(params) - if self.statement.kind == DQL and ret == SQLITE_ROW: - self.statement._build_row_cast_map() - self.statement._readahead(self) - else: - self.statement.item = None - self.statement.exhausted = True + # Actually execute the SQL statement + ret = sqlite.sqlite3_step(self.statement.statement) + if ret not in (SQLITE_DONE, SQLITE_ROW): + self.statement.reset() + raise self.connection._get_exception(ret) - if self.statement.kind == DML: - self.statement.reset() + if self.statement.kind == DQL and ret == SQLITE_ROW: + self.statement._build_row_cast_map() + self.statement._readahead(self) + else: + self.statement.item = None + self.statement.exhausted = True - self.rowcount = -1 - if self.statement.kind == DML: - self.rowcount = sqlite.sqlite3_changes(self.connection.db) + if self.statement.kind == DML: + self.statement.reset() + + self.rowcount = -1 + if self.statement.kind == DML: + self.rowcount = sqlite.sqlite3_changes(self.connection.db) return self def executemany(self, sql, many_params): - self._description = None - self.reset = False if type(sql) is unicode: sql = sql.encode("utf-8") - self._check_closed() - self.statement = self.connection.statement_cache.get(sql, self, self.row_factory) - if self.statement.kind == DML: - self.connection._begin() - else: - raise ProgrammingError("executemany is only for DML statements") + with self._check_and_lock(): + self._description = None + self.reset = False + self.statement = self.connection.statement_cache.get( + sql, self, self.row_factory) - self.rowcount = 0 - for params in many_params: - self.statement.set_params(params) - ret = sqlite.sqlite3_step(self.statement.statement) - if ret != SQLITE_DONE: - raise self.connection._get_exception(ret) - self.rowcount += sqlite.sqlite3_changes(self.connection.db) + if self.statement.kind == DML: + self.connection._begin() + else: + raise ProgrammingError("executemany is only for DML statements") + + self.rowcount = 0 + for params in many_params: + self.statement.set_params(params) + ret = sqlite.sqlite3_step(self.statement.statement) + if ret != SQLITE_DONE: + raise self.connection._get_exception(ret) + self.rowcount += sqlite.sqlite3_changes(self.connection.db) return self _______________________________________________ pypy-commit mailing list pypy-commit@python.org http://mail.python.org/mailman/listinfo/pypy-commit