Hey Mike, here's a new patch that uses your mapper extension idea instead of the session_context mapper() parameter. I also removed orm.session.current_session and friends, and updated the unitofwork documentation accordingly.
Note that I implemented the new mapper extension method you proposed as get_session() rather than get_session_context(). This allowed me to encapsulate the entire SessionContext interface within the mapper extension. Mapper also grew its own get_session() method, which is used by Query. I'm not sure how many (if any) tests this may have broken; I've always had problems trying to run the tests... is there anything special I need to do? Give it a whirl and let me know what you think. ~ Daniel
Index: doc/build/content/unitofwork.txt =================================================================== --- doc/build/content/unitofwork.txt (revision 1453) +++ doc/build/content/unitofwork.txt (working copy) @@ -56,29 +56,21 @@ {python} session = object_session(obj) -When default session contexts are enabled, the current contextual session can be acquired by the `current_session()` function. This function takes an optional instance argument, which allows session contexts that are specific to particular class hierarchies to return the correct session. When using the "threadlocal" session context, enabled via the *mod* `sqlalchemy.mods.threadlocal`, no instance argument is required: +It is possible to install a default "threadlocal" session context by importing a *mod* called `sqlalchemy.mods.threadlocal`. This mod creates a familiar SA 0.1 keyword `objectstore` in the `sqlalchemy` namespace. The `objectstore` may be used directly like a session; all session actions performed on `sqlalchemy.objectstore` will be *proxied* to the thread-local Session: {python} - # enable the thread-local default session context (only need to call this once per application) + # install 'threadlocal' mod (only need to call this once per application) import sqlalchemy.mods.threadlocal - - # return the Session that is bound to the current thread - session = current_session() -When using the `threadlocal` mod, a familiar SA 0.1 keyword `objectstore` is imported into the `sqlalchemy` namespace. Using `objectstore`, methods can be called which will automatically be *proxied* to the Session that corresponds to `current_session()`: - - {python} - # load the 'threadlocal' mod *first* - import sqlalchemy.mods.threadlocal - # then 'objectstore' is available within the 'sqlalchemy' namespace - from sqlalchemy import * - - # and then this... 
- current_session().flush() - - # is the same as this: + from sqlalchemy import objectstore + + # flush the current thread-local session using the objectstore directly objectstore.flush() + + # which is the same as this (assuming we are still on the same thread): + session = objectstore.get_session() + session.flush() We will now cover some of the key concepts used by Sessions and its underlying Unit of Work. Index: lib/sqlalchemy/ext/sessioncontext.py =================================================================== --- lib/sqlalchemy/ext/sessioncontext.py (revision 1453) +++ lib/sqlalchemy/ext/sessioncontext.py (working copy) @@ -1,4 +1,5 @@ from sqlalchemy.util import ScopedRegistry +from sqlalchemy.orm.mapper import MapperExtension class SessionContext(object): """A simple wrapper for ScopedRegistry that provides a "current" property @@ -31,88 +32,22 @@ current = property(get_current, set_current, del_current, """Property used to get/set/del the session in the current scope""") - def create_metaclass(session_context): - """return a metaclass to be used by objects that wish to be bound to a - thread-local session upon instantiatoin. - - Note non-standard use of session_context rather than self as the name - of the first arguement of this method. - - Usage: - context = SessionContext(...) - class MyClass(object): - __metaclass__ = context.metaclass - ... 
- """ + def _get_mapper_extension(self): try: - return session_context._metaclass + return self._extension except AttributeError: - class metaclass(type): - def __init__(cls, name, bases, dct): - old_init = getattr(cls, "__init__") - def __init__(self, *args, **kwargs): - session_context.current.save(self) - old_init(self, *args, **kwargs) - setattr(cls, "__init__", __init__) - super(metaclass, cls).__init__(name, bases, dct) - session_context._metaclass = metaclass - return metaclass - metaclass = property(create_metaclass) + self._extension = ext = SessionContextExt(self) + return ext + mapper_extension = property(_get_mapper_extension, + doc="""get a mapper extension that implements get_session using this context""") - def create_baseclass(session_context): - """return a baseclass to be used by objects that wish to be bound to a - thread-local session upon instantiatoin. - Note non-standard use of session_context rather than self as the name - of the first arguement of this method. +class SessionContextExt(MapperExtension): + """a mapper extionsion that provides sessions to a mapper using SessionContext""" - Usage: - context = SessionContext(...) - class MyClass(context.baseclass): - ... 
- """ - try: - return session_context._baseclass - except AttributeError: - class baseclass(object): - def __init__(self, *args, **kwargs): - session_context.current.save(self) - super(baseclass, self).__init__(*args, **kwargs) - session_context._baseclass = baseclass - return baseclass - baseclass = property(create_baseclass) - - -def test(): - - def run_test(class_, context): - obj = class_() - assert context.current == get_session(obj) - - # keep a reference so the old session doesn't get gc'd - old_session = context.current - - context.current = create_session() - assert context.current != get_session(obj) - assert old_session == get_session(obj) - - del context.current - assert context.current != get_session(obj) - assert old_session == get_session(obj) - - obj2 = class_() - assert context.current == get_session(obj2) - - # test metaclass - context = SessionContext(create_session) - class MyClass(object): __metaclass__ = context.metaclass - run_test(MyClass, context) - - # test baseclass - context = SessionContext(create_session) - class MyClass(context.baseclass): pass - run_test(MyClass, context) - -if __name__ == "__main__": - test() - print "All tests passed!" + def __init__(self, context): + MapperExtension.__init__(self) + self.context = context + + def get_session(self): + return self.context.current Index: lib/sqlalchemy/mods/threadlocal.py =================================================================== --- lib/sqlalchemy/mods/threadlocal.py (revision 1453) +++ lib/sqlalchemy/mods/threadlocal.py (working copy) @@ -1,5 +1,7 @@ from sqlalchemy import util, engine, mapper -from sqlalchemy.orm import session, current_session +from sqlalchemy.ext.sessioncontext import SessionContext +from sqlalchemy.orm.mapper import global_extensions +from sqlalchemy.orm.session import Session import sqlalchemy import sys, types @@ -10,19 +12,19 @@ from the pool. 
this greatly helps functions that call multiple statements to be able to easily use just one connection without explicit "close" statements on result handles. -on the Session side, the current_session() method will be modified to return a thread-local Session when no arguments -are sent. It will also install module-level methods within the objectstore module, such as flush(), delete(), etc. -which call this method on the thread-local session returned by current_session(). +on the Session side, module-level methods will be installed within the objectstore module, such as flush(), delete(), etc. +which call this method on the thread-local session. - +Note: this mod creates a global, thread-local session context named sqlalchemy.objectstore. All mappers created +while this mod is installed will reference this global context when creating new mapped object instances. """ -class Objectstore(object): +class Objectstore(SessionContext): def __getattr__(self, key): - return getattr(current_session(), key) + return getattr(self.current, key) def get_session(self): - return current_session() - + return self.current + def monkeypatch_query_method(class_, name): def do(self, *args, **kwargs): query = class_.mapper.query() @@ -31,7 +33,7 @@ def monkeypatch_objectstore_method(class_, name): def do(self, *args, **kwargs): - session = current_session() + session = sqlalchemy.objectstore.current getattr(session, name)(self, *args, **kwargs) setattr(class_, name, do) @@ -48,16 +50,18 @@ monkeypatch_query_method(class_, name) for name in ['flush', 'delete', 'expire', 'refresh', 'expunge', 'merge', 'update', 'save_or_update']: monkeypatch_objectstore_method(class_, name) + +def _mapper_extension(): + return SessionContext._get_mapper_extension(sqlalchemy.objectstore) def install_plugin(): - reg = util.ScopedRegistry(session.Session) - session.register_default_session(lambda *args, **kwargs: reg()) + sqlalchemy.objectstore = objectstore = Objectstore(Session) + 
global_extensions.append(_mapper_extension) engine.default_strategy = 'threadlocal' - sqlalchemy.objectstore = Objectstore() sqlalchemy.assign_mapper = assign_mapper def uninstall_plugin(): - session.register_default_session(lambda *args, **kwargs:None) engine.default_strategy = 'plain' + global_extensions.remove(_mapper_extension) install_plugin() Index: lib/sqlalchemy/orm/__init__.py =================================================================== --- lib/sqlalchemy/orm/__init__.py (revision 1453) +++ lib/sqlalchemy/orm/__init__.py (working copy) @@ -14,12 +14,11 @@ from query import Query from util import polymorphic_union import properties -from session import current_session from session import Session as create_session __all__ = ['relation', 'backref', 'eagerload', 'lazyload', 'noload', 'deferred', 'defer', 'undefer', 'mapper', 'clear_mappers', 'sql', 'extension', 'class_mapper', 'object_mapper', 'MapperExtension', 'Query', - 'cascade_mappers', 'polymorphic_union', 'current_session', 'create_session', + 'cascade_mappers', 'polymorphic_union', 'create_session', ] def relation(*args, **kwargs): Index: lib/sqlalchemy/orm/mapper.py =================================================================== --- lib/sqlalchemy/orm/mapper.py (revision 1453) +++ lib/sqlalchemy/orm/mapper.py (working copy) @@ -63,7 +63,7 @@ ext = ext_obj.chain(ext) self.extension = ext - + self.class_ = class_ self.entity_name = entity_name self.class_key = ClassKey(class_, entity_name) @@ -325,6 +325,7 @@ """sets up our classes' overridden __init__ method, this mappers hash key as its '_mapper' property, and our columns as its 'c' property. 
if the class already had a mapper, the old __init__ method is kept the same.""" + ext = self.extension oldinit = self.class_.__init__ def init(self, *args, **kwargs): self._entity_name = kwargs.pop('_sa_entity_name', None) @@ -336,7 +337,9 @@ if kwargs.has_key('_sa_session'): session = kwargs.pop('_sa_session') else: - session = sessionlib.current_session(self) + session = ext.get_session() + if session is EXT_PASS: + session = None if session is not None: session._register_new(self) if oldinit is not None: @@ -349,6 +352,17 @@ mapper_registry[self.class_key] = self if self.entity_name is None: self.class_.c = self.c + + def get_session(self): + """returns the contextual session provided by the mapper extension chain + + raises InvalidRequestError if a session cannot be retrieved from the + extension chain + """ + s = self.extension.get_session() + if s is EXT_PASS: + raise exceptions.InvalidRequestError("No contextual Session is established. Use a MapperExtension that implements get_session or use 'import sqlalchemy.mods.threadlocal' to establish a default thread-local contextual session.") + return s def has_eager(self): """returns True if one of the properties attached to this Mapper is eager loading""" @@ -900,6 +914,14 @@ def chain(self, ext): self.next = ext return self + def get_session(self): + """called to retrieve a contextual Session instance with which to + register a new object. Note: this is not called if a session is + provided with the __init__ params (i.e. 
_sa_session)""" + if self.next is None: + return EXT_PASS + else: + return self.next.get_session() def select_by(self, query, *args, **kwargs): """overrides the select_by method of the Query object""" if self.next is None: Index: lib/sqlalchemy/orm/query.py =================================================================== --- lib/sqlalchemy/orm/query.py (revision 1453) +++ lib/sqlalchemy/orm/query.py (working copy) @@ -27,7 +27,7 @@ self._get_clause = self.mapper._get_clause def _get_session(self): if self._session is None: - return sessionlib.required_current_session() + return self.mapper.get_session() else: return self._session table = property(lambda s:s.mapper.select_table) Index: lib/sqlalchemy/orm/session.py =================================================================== --- lib/sqlalchemy/orm/session.py (revision 1453) +++ lib/sqlalchemy/orm/session.py (working copy) @@ -427,42 +427,19 @@ # acts as a Registry with which to locate Sessions. this is to enable # object instances to be associated with Sessions without having to attach the # actual Session object directly to the object instance. -_sessions = weakref.WeakValueDictionary() +_sessions = weakref.WeakValueDictionary() -def current_session(obj=None): - if hasattr(obj, '__session__'): - return obj.__session__() - else: - return _default_session(obj=obj) - -# deprecated -get_session=current_session - -def required_current_session(obj=None): - s = current_session(obj) - if s is None: - if obj is None: - raise exceptions.InvalidRequestError("No global-level Session context is established. Use 'import sqlalchemy.mods.threadlocal' to establish a default thread-local context.") - else: - raise exceptions.InvalidRequestError("No Session context is established for class '%s', and no global-level Session context is established. Use 'import sqlalchemy.mods.threadlocal' to establish a default thread-local context." 
% (obj.__class__)) - return s - -def _default_session(obj=None): - return None -def register_default_session(callable_): - global _default_session - _default_session = callable_ - def object_session(obj): hashkey = getattr(obj, '_sa_session_id', None) if hashkey is not None: - # ok, return that - try: - return _sessions[hashkey] - except KeyError: - return None - else: - return None + return _sessions.get(hashkey) + return None unitofwork.object_session = object_session + +def get_session(obj=None): + """deprecated""" + if obj is not None: + return object_session(obj) + raise exceptions.InvalidRequestError("get_session() is deprecated, and does not return the thread-local session anymore. Use the SessionContext.mapper_extension or import sqlalchemy.mod.threadlocal to establish a default thread-local context.") Index: test/sessioncontext.py =================================================================== --- test/sessioncontext.py (revision 0) +++ test/sessioncontext.py (revision 0) @@ -0,0 +1,93 @@ +''' +def test(): + def run_test(class_, context): + obj = class_() + assert context.current == object_session(obj) + + # keep a reference so the old session doesn't get gc'd + old_session = context.current + + context.current = Session() + assert context.current != object_session(obj) + assert old_session == object_session(obj) + + del context.current + assert context.current != object_session(obj) + assert old_session == object_session(obj) + + obj2 = class_() + assert context.current == object_session(obj2) + + # test metaclass + context = SessionContext(Session) + class MyClass(object): __contextsession__ = context.get_classmethod() + run_test(MyClass, context) + + # test baseclass + context = SessionContext(Session) + class MyClass(context.baseclass): pass + run_test(MyClass, context) + +if __name__ == "__main__": + test() + print "All tests passed!" 
+''' + +from testbase import PersistTest, AssertMixin +import unittest, sys, os +from sqlalchemy.ext.sessioncontext import SessionContext +from sqlalchemy.orm.session import object_session, Session +from sqlalchemy import * +import testbase + +db = testbase.db + +users = Table('users', db, + Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True), + Column('user_name', String(40)), + mysql_engine='innodb' +) + +User = None + +class HistoryTest(AssertMixin): + def setUpAll(self): + db.echo = False + users.create() + db.echo = testbase.echo + def tearDownAll(self): + db.echo = False + users.drop() + db.echo = testbase.echo + def setUp(self): + clear_mappers() + + def do_test(self, class_, context): + """test session assignment on object creation""" + obj = class_() + assert context.current == object_session(obj) + + # keep a reference so the old session doesn't get gc'd + old_session = context.current + + context.current = Session() + assert context.current != object_session(obj) + assert old_session == object_session(obj) + + new_session = context.current + del context.current + assert context.current != new_session + assert old_session == object_session(obj) + + obj2 = class_() + assert context.current == object_session(obj2) + + def test_mapper_extension(self): + context = SessionContext(Session) + class User(object): pass + User.mapper = mapper(User, users, extension=context.mapper_extension) + self.do_test(User, context) + + +if __name__ == "__main__": + testbase.main()