Hey Mike,

Here's a new patch, which uses your mapper extension idea rather than the 
session_context mapper() parameter. I also removed orm.session.current_session 
and friends, and updated the unitofwork documentation accordingly.

Note that I have implemented the new mapper extension method you had proposed 
as get_session() rather than get_session_context(). This allowed me to 
encapsulate the entire SessionContext interface within the mapper extension. 
Mapper also grew its own get_session() method, which is used by Query.

I'm not sure how many (if any) tests this may have broken. I've always had 
problems running the tests; is there anything special I need to do? Give 
it a whirl and let me know what you think.

~ Daniel
Index: doc/build/content/unitofwork.txt
===================================================================
--- doc/build/content/unitofwork.txt    (revision 1453)
+++ doc/build/content/unitofwork.txt    (working copy)
@@ -56,29 +56,21 @@
     {python}
     session = object_session(obj)
 
-When default session contexts are enabled, the current contextual session can 
be acquired by the `current_session()` function.  This function takes an 
optional instance argument, which allows session contexts that are specific to 
particular class hierarchies to return the correct session.  When using the 
"threadlocal" session context, enabled via the *mod* 
`sqlalchemy.mods.threadlocal`, no instance argument is required:
+It is possible to install a default "threadlocal" session context by importing 
a *mod* called `sqlalchemy.mods.threadlocal`.  This mod creates a familiar SA 
0.1 keyword `objectstore` in the `sqlalchemy` namespace.  The `objectstore` may 
be used directly like a session; all session actions performed on 
`sqlalchemy.objectstore` will be *proxied* to the thread-local Session:
 
     {python}
-    # enable the thread-local default session context (only need to call this 
once per application)
+    # install 'threadlocal' mod (only need to call this once per application)
     import sqlalchemy.mods.threadlocal
-    
-    # return the Session that is bound to the current thread
-    session = current_session()
 
-When using the `threadlocal` mod, a familiar SA 0.1 keyword `objectstore` is 
imported into the `sqlalchemy` namespace.  Using `objectstore`, methods can be 
called which will automatically be *proxied* to the Session that corresponds to 
`current_session()`:
-
-    {python}
-    # load the 'threadlocal' mod *first*
-    import sqlalchemy.mods.threadlocal
-    
     # then 'objectstore' is available within the 'sqlalchemy' namespace
-    from sqlalchemy import *
-    
-    # and then this...
-    current_session().flush()
-    
-    # is the same as this:
+    from sqlalchemy import objectstore
+
+    # flush the current thread-local session using the objectstore directly
     objectstore.flush()
+    
+    # which is the same as this (assuming we are still on the same thread):
+    session = objectstore.get_session()
+    session.flush()
 
 We will now cover some of the key concepts used by Sessions and its underlying 
Unit of Work.
 
Index: lib/sqlalchemy/ext/sessioncontext.py
===================================================================
--- lib/sqlalchemy/ext/sessioncontext.py        (revision 1453)
+++ lib/sqlalchemy/ext/sessioncontext.py        (working copy)
@@ -1,4 +1,5 @@
 from sqlalchemy.util import ScopedRegistry
+from sqlalchemy.orm.mapper import MapperExtension
 
 class SessionContext(object):
     """A simple wrapper for ScopedRegistry that provides a "current" property
@@ -31,88 +32,22 @@
     current = property(get_current, set_current, del_current,
         """Property used to get/set/del the session in the current scope""")
 
-    def create_metaclass(session_context):
-        """return a metaclass to be used by objects that wish to be bound to a
-        thread-local session upon instantiatoin.
-
-        Note non-standard use of session_context rather than self as the name
-        of the first arguement of this method.
-
-        Usage:
-            context = SessionContext(...)
-            class MyClass(object):
-                __metaclass__ = context.metaclass
-                ...
-        """
+    def _get_mapper_extension(self):
         try:
-            return session_context._metaclass
+            return self._extension
         except AttributeError:
-            class metaclass(type):
-                def __init__(cls, name, bases, dct):
-                    old_init = getattr(cls, "__init__")
-                    def __init__(self, *args, **kwargs):
-                        session_context.current.save(self)
-                        old_init(self, *args, **kwargs)
-                    setattr(cls, "__init__", __init__)
-                    super(metaclass, cls).__init__(name, bases, dct)
-            session_context._metaclass = metaclass
-            return metaclass
-    metaclass = property(create_metaclass)
+            self._extension = ext = SessionContextExt(self)
+            return ext
+    mapper_extension = property(_get_mapper_extension,
+        doc="""get a mapper extension that implements get_session using this 
context""")
 
-    def create_baseclass(session_context):
-        """return a baseclass to be used by objects that wish to be bound to a
-        thread-local session upon instantiatoin.
 
-        Note non-standard use of session_context rather than self as the name
-        of the first arguement of this method.
+class SessionContextExt(MapperExtension):
+    """MapperExtension whose get_session() returns the current session of 
the wrapped SessionContext"""
 
-        Usage:
-            context = SessionContext(...)
-            class MyClass(context.baseclass):
-                ...
-        """
-        try:
-            return session_context._baseclass
-        except AttributeError:
-            class baseclass(object):
-                def __init__(self, *args, **kwargs):
-                    session_context.current.save(self)
-                    super(baseclass, self).__init__(*args, **kwargs)
-            session_context._baseclass = baseclass
-            return baseclass
-    baseclass = property(create_baseclass)
-
-
-def test():
-
-    def run_test(class_, context):
-        obj = class_()
-        assert context.current == get_session(obj)
-
-        # keep a reference so the old session doesn't get gc'd
-        old_session = context.current
-
-        context.current = create_session()
-        assert context.current != get_session(obj)
-        assert old_session == get_session(obj)
-
-        del context.current
-        assert context.current != get_session(obj)
-        assert old_session == get_session(obj)
-
-        obj2 = class_()
-        assert context.current == get_session(obj2)
-
-    # test metaclass
-    context = SessionContext(create_session)
-    class MyClass(object): __metaclass__ = context.metaclass
-    run_test(MyClass, context)
-
-    # test baseclass
-    context = SessionContext(create_session)
-    class MyClass(context.baseclass): pass
-    run_test(MyClass, context)
-
-if __name__ == "__main__":
-    test()
-    print "All tests passed!"
+    def __init__(self, context):
+        MapperExtension.__init__(self)
+        self.context = context
+    
+    def get_session(self):
+        return self.context.current
Index: lib/sqlalchemy/mods/threadlocal.py
===================================================================
--- lib/sqlalchemy/mods/threadlocal.py  (revision 1453)
+++ lib/sqlalchemy/mods/threadlocal.py  (working copy)
@@ -1,5 +1,7 @@
 from sqlalchemy import util, engine, mapper
-from sqlalchemy.orm import session, current_session
+from sqlalchemy.ext.sessioncontext import SessionContext
+from sqlalchemy.orm.mapper import global_extensions
+from sqlalchemy.orm.session import Session
 import sqlalchemy
 import sys, types
 
@@ -10,19 +12,19 @@
 from the pool.  this greatly helps functions that call multiple statements to 
be able to easily use just one connection
 without explicit "close" statements on result handles.
 
-on the Session side, the current_session() method will be modified to return a 
thread-local Session when no arguments
-are sent.  It will also install module-level methods within the objectstore 
module, such as flush(), delete(), etc.
-which call this method on the thread-local session returned by 
current_session().
+on the Session side, a global sqlalchemy.objectstore object is created which proxies 
methods such as flush(), delete(), etc.
+to the thread-local session.
 
-
+Note: this mod creates a global, thread-local session context named 
sqlalchemy.objectstore. All mappers created
+while this mod is installed will reference this global context when creating 
new mapped object instances.
 """
 
-class Objectstore(object):
+class Objectstore(SessionContext):
     def __getattr__(self, key):
-        return getattr(current_session(), key)
+        return getattr(self.current, key)
     def get_session(self):
-        return current_session()
-        
+        return self.current
+
 def monkeypatch_query_method(class_, name):
     def do(self, *args, **kwargs):
         query = class_.mapper.query()
@@ -31,7 +33,7 @@
 
 def monkeypatch_objectstore_method(class_, name):
     def do(self, *args, **kwargs):
-        session = current_session()
+        session = sqlalchemy.objectstore.current
         getattr(session, name)(self, *args, **kwargs)
     setattr(class_, name, do)
     
@@ -48,16 +50,24 @@
         monkeypatch_query_method(class_, name)
     for name in ['flush', 'delete', 'expire', 'refresh', 'expunge', 'merge', 
'update', 'save_or_update']:
         monkeypatch_objectstore_method(class_, name)
+
+def _mapper_extension():
+    # return a NEW extension object on every call: the mapper chains the
+    # returned object (chain() assigns its 'next'), so one shared cached
+    # instance would be re-linked to whichever mapper was created last
+    from sqlalchemy.ext.sessioncontext import SessionContextExt
+    return SessionContextExt(sqlalchemy.objectstore)
     
 def install_plugin():
-    reg = util.ScopedRegistry(session.Session)
-    session.register_default_session(lambda *args, **kwargs: reg())
+    sqlalchemy.objectstore = Objectstore(Session)
+    global_extensions.append(_mapper_extension)
     engine.default_strategy = 'threadlocal'
-    sqlalchemy.objectstore = Objectstore()
     sqlalchemy.assign_mapper = assign_mapper
 
 def uninstall_plugin():
-    session.register_default_session(lambda *args, **kwargs:None)
     engine.default_strategy = 'plain'
+    if _mapper_extension in global_extensions:
+        global_extensions.remove(_mapper_extension)
     
 install_plugin()
Index: lib/sqlalchemy/orm/__init__.py
===================================================================
--- lib/sqlalchemy/orm/__init__.py      (revision 1453)
+++ lib/sqlalchemy/orm/__init__.py      (working copy)
@@ -14,12 +14,11 @@
 from query import Query
 from util import polymorphic_union
 import properties
-from session import current_session
 from session import Session as create_session
 
 __all__ = ['relation', 'backref', 'eagerload', 'lazyload', 'noload', 
'deferred', 'defer', 'undefer',
         'mapper', 'clear_mappers', 'sql', 'extension', 'class_mapper', 
'object_mapper', 'MapperExtension', 'Query', 
-        'cascade_mappers', 'polymorphic_union', 'current_session', 
'create_session',  
+        'cascade_mappers', 'polymorphic_union', 'create_session',  
         ]
 
 def relation(*args, **kwargs):
Index: lib/sqlalchemy/orm/mapper.py
===================================================================
--- lib/sqlalchemy/orm/mapper.py        (revision 1453)
+++ lib/sqlalchemy/orm/mapper.py        (working copy)
@@ -63,7 +63,7 @@
                 ext = ext_obj.chain(ext)
             
         self.extension = ext
-
+        
         self.class_ = class_
         self.entity_name = entity_name
         self.class_key = ClassKey(class_, entity_name)
@@ -325,6 +325,7 @@
         """sets up our classes' overridden __init__ method, this mappers hash 
key as its
         '_mapper' property, and our columns as its 'c' property.  if the class 
already had a
         mapper, the old __init__ method is kept the same."""
+        ext = self.extension
         oldinit = self.class_.__init__
         def init(self, *args, **kwargs):
             self._entity_name = kwargs.pop('_sa_entity_name', None)
@@ -336,7 +337,9 @@
             if kwargs.has_key('_sa_session'):
                 session = kwargs.pop('_sa_session')
             else:
-                session = sessionlib.current_session(self)
+                session = ext.get_session()
+                if session is EXT_PASS:
+                    session = None
             if session is not None:
                 session._register_new(self)
             if oldinit is not None:
@@ -349,6 +352,17 @@
         mapper_registry[self.class_key] = self
         if self.entity_name is None:
             self.class_.c = self.c
+
+    def get_session(self):
+        """returns the contextual session provided by the mapper extension 
chain
+        
+        raises InvalidRequestError if a session cannot be retrieved from the
+        extension chain
+        """
+        s = self.extension.get_session()
+        if s is EXT_PASS:
+            raise exceptions.InvalidRequestError("No contextual Session is 
established.  Use a MapperExtension that implements get_session or use 'import 
sqlalchemy.mods.threadlocal' to establish a default thread-local contextual 
session.")
+        return s
     
     def has_eager(self):
         """returns True if one of the properties attached to this Mapper is 
eager loading"""
@@ -900,6 +914,14 @@
     def chain(self, ext):
         self.next = ext
         return self    
+    def get_session(self):
+        """called to retrieve a contextual Session instance with which to
+        register a new object. Note: this is not called if a session is 
+        provided with the __init__ params (i.e. _sa_session)"""
+        if self.next is None:
+            return EXT_PASS
+        else:
+            return self.next.get_session()
     def select_by(self, query, *args, **kwargs):
         """overrides the select_by method of the Query object"""
         if self.next is None:
Index: lib/sqlalchemy/orm/query.py
===================================================================
--- lib/sqlalchemy/orm/query.py (revision 1453)
+++ lib/sqlalchemy/orm/query.py (working copy)
@@ -27,7 +27,7 @@
         self._get_clause = self.mapper._get_clause
     def _get_session(self):
         if self._session is None:
-            return sessionlib.required_current_session()
+            return self.mapper.get_session()
         else:
             return self._session
     table = property(lambda s:s.mapper.select_table)
Index: lib/sqlalchemy/orm/session.py
===================================================================
--- lib/sqlalchemy/orm/session.py       (revision 1453)
+++ lib/sqlalchemy/orm/session.py       (working copy)
@@ -427,42 +427,19 @@
 # acts as a Registry with which to locate Sessions.  this is to enable
 # object instances to be associated with Sessions without having to attach the
 # actual Session object directly to the object instance.
-_sessions = weakref.WeakValueDictionary() 
+_sessions = weakref.WeakValueDictionary()
 
-def current_session(obj=None):
-    if hasattr(obj, '__session__'):
-        return obj.__session__()
-    else:
-        return _default_session(obj=obj)
-        
-# deprecated
-get_session=current_session
-
-def required_current_session(obj=None):
-    s = current_session(obj)
-    if s is None:
-        if obj is None:
-            raise exceptions.InvalidRequestError("No global-level Session 
context is established.  Use 'import sqlalchemy.mods.threadlocal' to establish 
a default thread-local context.")
-        else:
-            raise exceptions.InvalidRequestError("No Session context is 
established for class '%s', and no global-level Session context is established. 
 Use 'import sqlalchemy.mods.threadlocal' to establish a default thread-local 
context." % (obj.__class__))
-    return s
-    
-def _default_session(obj=None):
-    return None
-def register_default_session(callable_):
-    global _default_session
-    _default_session = callable_            
-    
 def object_session(obj):
     hashkey = getattr(obj, '_sa_session_id', None)
     if hashkey is not None:
-        # ok, return that
-        try:
-            return _sessions[hashkey]
-        except KeyError:
-            return None
-    else:
-        return None
+        return _sessions.get(hashkey)
+    return None
 
 unitofwork.object_session = object_session
 
+
+def get_session(obj=None):
+    """deprecated; use object_session(obj) instead"""
+    if obj is not None:
+        return object_session(obj)
+    raise exceptions.InvalidRequestError("get_session() is deprecated, and 
does not return the thread-local session anymore. Use the 
SessionContext.mapper_extension or import sqlalchemy.mods.threadlocal to 
establish a default thread-local context.")
Index: test/sessioncontext.py
===================================================================
--- test/sessioncontext.py      (revision 0)
+++ test/sessioncontext.py      (revision 0)
@@ -0,0 +1,57 @@
+from testbase import PersistTest, AssertMixin
+import unittest, sys, os
+from sqlalchemy.ext.sessioncontext import SessionContext
+from sqlalchemy.orm.session import object_session, Session
+from sqlalchemy import *
+import testbase
+
+db = testbase.db
+
+users = Table('users', db,
+    Column('user_id', Integer, Sequence('user_id_seq', optional=True), 
primary_key = True),
+    Column('user_name', String(40)),
+    mysql_engine='innodb'
+)
+
+class SessionContextTest(AssertMixin):
+    """tests for SessionContext.mapper_extension"""
+    def setUpAll(self):
+        db.echo = False
+        users.create()
+        db.echo = testbase.echo
+    def tearDownAll(self):
+        db.echo = False
+        users.drop()
+        db.echo = testbase.echo
+    def setUp(self):
+        clear_mappers()
+        
+    def do_test(self, class_, context):
+        """test session assignment on object creation"""
+        obj = class_()
+        assert context.current == object_session(obj)
+
+        # keep a reference so the old session doesn't get gc'd
+        old_session = context.current
+
+        context.current = Session()
+        assert context.current != object_session(obj)
+        assert old_session == object_session(obj)
+
+        new_session = context.current
+        del context.current
+        assert context.current != new_session
+        assert old_session == object_session(obj)
+        
+        obj2 = class_()
+        assert context.current == object_session(obj2)
+    
+    def test_mapper_extension(self):
+        context = SessionContext(Session)
+        class User(object): pass
+        User.mapper = mapper(User, users, extension=context.mapper_extension)
+        self.do_test(User, context)
+
+
+if __name__ == "__main__":
+    testbase.main()        

Reply via email to