Hey Mike,
Here's a new patch, which uses your mapper extension idea rather than the
session_context mapper() parameter. I also removed orm.session.current_session
and friends, and updated the unitofwork documentation accordingly.
Note that I have implemented the new mapper extension method you had proposed
as get_session() rather than get_session_context(). This allowed me to
encapsulate the entire SessionContext interface within the mapper extension.
Mapper also grew its own get_session() method, which is used by Query.
I'm not sure how many (if any) tests this may have broken. I've always had
problems trying to run the tests; is there anything special I need to do? Give
it a whirl and let me know what you think.
~ Daniel
Index: doc/build/content/unitofwork.txt
===================================================================
--- doc/build/content/unitofwork.txt (revision 1453)
+++ doc/build/content/unitofwork.txt (working copy)
@@ -56,29 +56,21 @@
{python}
session = object_session(obj)
-When default session contexts are enabled, the current contextual session can
be acquired by the `current_session()` function. This function takes an
optional instance argument, which allows session contexts that are specific to
particular class hierarchies to return the correct session. When using the
"threadlocal" session context, enabled via the *mod*
`sqlalchemy.mods.threadlocal`, no instance argument is required:
+It is possible to install a default "threadlocal" session context by importing
a *mod* called `sqlalchemy.mods.threadlocal`. This mod creates a familiar SA
0.1 keyword `objectstore` in the `sqlalchemy` namespace. The `objectstore` may
be used directly like a session; all session actions performed on
`sqlalchemy.objectstore` will be *proxied* to the thread-local Session:
{python}
- # enable the thread-local default session context (only need to call this
once per application)
+ # install 'threadlocal' mod (only need to call this once per application)
import sqlalchemy.mods.threadlocal
-
- # return the Session that is bound to the current thread
- session = current_session()
-When using the `threadlocal` mod, a familiar SA 0.1 keyword `objectstore` is
imported into the `sqlalchemy` namespace. Using `objectstore`, methods can be
called which will automatically be *proxied* to the Session that corresponds to
`current_session()`:
-
- {python}
- # load the 'threadlocal' mod *first*
- import sqlalchemy.mods.threadlocal
-
# then 'objectstore' is available within the 'sqlalchemy' namespace
- from sqlalchemy import *
-
- # and then this...
- current_session().flush()
-
- # is the same as this:
+ from sqlalchemy import objectstore
+
+ # flush the current thread-local session using the objectstore directly
objectstore.flush()
+
+ # which is the same as this (assuming we are still on the same thread):
+ session = objectstore.get_session()
+ session.flush()
We will now cover some of the key concepts used by Sessions and its underlying
Unit of Work.
Index: lib/sqlalchemy/ext/sessioncontext.py
===================================================================
--- lib/sqlalchemy/ext/sessioncontext.py (revision 1453)
+++ lib/sqlalchemy/ext/sessioncontext.py (working copy)
@@ -1,4 +1,5 @@
from sqlalchemy.util import ScopedRegistry
+from sqlalchemy.orm.mapper import MapperExtension
class SessionContext(object):
"""A simple wrapper for ScopedRegistry that provides a "current" property
@@ -31,88 +32,22 @@
current = property(get_current, set_current, del_current,
"""Property used to get/set/del the session in the current scope""")
- def create_metaclass(session_context):
- """return a metaclass to be used by objects that wish to be bound to a
- thread-local session upon instantiatoin.
-
- Note non-standard use of session_context rather than self as the name
- of the first arguement of this method.
-
- Usage:
- context = SessionContext(...)
- class MyClass(object):
- __metaclass__ = context.metaclass
- ...
- """
- try:
- return session_context._metaclass
- except AttributeError:
- class metaclass(type):
- def __init__(cls, name, bases, dct):
- old_init = getattr(cls, "__init__")
- def __init__(self, *args, **kwargs):
- session_context.current.save(self)
- old_init(self, *args, **kwargs)
- setattr(cls, "__init__", __init__)
- super(metaclass, cls).__init__(name, bases, dct)
- session_context._metaclass = metaclass
- return metaclass
- metaclass = property(create_metaclass)
+ def _get_mapper_extension(self):
+ """return a new SessionContextExt bound to this context.
+ A new instance is built on every access because MapperExtension.chain()
+ stores the next extension on the instance itself; one shared instance
+ would be re-chained (and its chain clobbered) by every mapper using it."""
+ return SessionContextExt(self)
+ mapper_extension = property(_get_mapper_extension,
+ doc="""get a new mapper extension that implements get_session using this context""")
- def create_baseclass(session_context):
- """return a baseclass to be used by objects that wish to be bound to a
- thread-local session upon instantiatoin.
- Note non-standard use of session_context rather than self as the name
- of the first arguement of this method.
+class SessionContextExt(MapperExtension):
+ """a mapper extension that provides sessions to a mapper using SessionContext"""
- Usage:
- context = SessionContext(...)
- class MyClass(context.baseclass):
- ...
- """
- try:
- return session_context._baseclass
- except AttributeError:
- class baseclass(object):
- def __init__(self, *args, **kwargs):
- session_context.current.save(self)
- super(baseclass, self).__init__(*args, **kwargs)
- session_context._baseclass = baseclass
- return baseclass
- baseclass = property(create_baseclass)
-
-
-def test():
-
- def run_test(class_, context):
- obj = class_()
- assert context.current == get_session(obj)
-
- # keep a reference so the old session doesn't get gc'd
- old_session = context.current
-
- context.current = create_session()
- assert context.current != get_session(obj)
- assert old_session == get_session(obj)
-
- del context.current
- assert context.current != get_session(obj)
- assert old_session == get_session(obj)
-
- obj2 = class_()
- assert context.current == get_session(obj2)
-
- # test metaclass
- context = SessionContext(create_session)
- class MyClass(object): __metaclass__ = context.metaclass
- run_test(MyClass, context)
-
- # test baseclass
- context = SessionContext(create_session)
- class MyClass(context.baseclass): pass
- run_test(MyClass, context)
-
-if __name__ == "__main__":
- test()
- print "All tests passed!"
+ def __init__(self, context):
+ MapperExtension.__init__(self)
+ self.context = context
+
+ def get_session(self):
+ return self.context.current  # the session of the context's current scope
Index: lib/sqlalchemy/mods/threadlocal.py
===================================================================
--- lib/sqlalchemy/mods/threadlocal.py (revision 1453)
+++ lib/sqlalchemy/mods/threadlocal.py (working copy)
@@ -1,5 +1,7 @@
from sqlalchemy import util, engine, mapper
-from sqlalchemy.orm import session, current_session
+from sqlalchemy.ext.sessioncontext import SessionContext
+from sqlalchemy.orm.mapper import global_extensions
+from sqlalchemy.orm.session import Session
import sqlalchemy
import sys, types
@@ -10,19 +12,19 @@
from the pool. this greatly helps functions that call multiple statements to
be able to easily use just one connection
without explicit "close" statements on result handles.
-on the Session side, the current_session() method will be modified to return a
thread-local Session when no arguments
-are sent. It will also install module-level methods within the objectstore
module, such as flush(), delete(), etc.
-which call this method on the thread-local session returned by
current_session().
+on the Session side, module-level methods will be installed within the
objectstore module, such as flush(), delete(), etc.
+which call this method on the thread-local session.
-
+Note: this mod creates a global, thread-local session context named
sqlalchemy.objectstore. All mappers created
+while this mod is installed will reference this global context when creating
new mapped object instances.
"""
-class Objectstore(object):
+class Objectstore(SessionContext):  # a SessionContext that also proxies attribute access to its current Session
def __getattr__(self, key):
- return getattr(current_session(), key)
+ return getattr(self.current, key)  # __getattr__ only runs for names not found on the Objectstore itself
def get_session(self):
- return current_session()
-
+ return self.current  # the Session of the current (thread-local) scope
+
def monkeypatch_query_method(class_, name):
def do(self, *args, **kwargs):
query = class_.mapper.query()
@@ -31,7 +33,7 @@
def monkeypatch_objectstore_method(class_, name):
def do(self, *args, **kwargs):
- session = current_session()
+ session = sqlalchemy.objectstore.current
- getattr(session, name)(self, *args, **kwargs)
+ return getattr(session, name)(self, *args, **kwargs)  # pass back the Session method's result instead of always None
setattr(class_, name, do)
@@ -48,16 +50,21 @@
monkeypatch_query_method(class_, name)
for name in ['flush', 'delete', 'expire', 'refresh', 'expunge', 'merge',
'update', 'save_or_update']:
monkeypatch_objectstore_method(class_, name)
+
+def _mapper_extension():
+ # build a new extension per mapper: MapperExtension.chain() stores the next extension
+ # on the instance, so a single shared instance would be re-chained by every mapper
+ from sqlalchemy.ext.sessioncontext import SessionContextExt
+ return SessionContextExt(sqlalchemy.objectstore)
def install_plugin():
- reg = util.ScopedRegistry(session.Session)
- session.register_default_session(lambda *args, **kwargs: reg())
+ sqlalchemy.objectstore = Objectstore(Session)
+ global_extensions.append(_mapper_extension)
engine.default_strategy = 'threadlocal'
- sqlalchemy.objectstore = Objectstore()
sqlalchemy.assign_mapper = assign_mapper
def uninstall_plugin():
- session.register_default_session(lambda *args, **kwargs:None)
engine.default_strategy = 'plain'
+ global_extensions.remove(_mapper_extension)
install_plugin()
Index: lib/sqlalchemy/orm/__init__.py
===================================================================
--- lib/sqlalchemy/orm/__init__.py (revision 1453)
+++ lib/sqlalchemy/orm/__init__.py (working copy)
@@ -14,12 +14,11 @@
from query import Query
from util import polymorphic_union
import properties
-from session import current_session
from session import Session as create_session
__all__ = ['relation', 'backref', 'eagerload', 'lazyload', 'noload',
'deferred', 'defer', 'undefer',
'mapper', 'clear_mappers', 'sql', 'extension', 'class_mapper',
'object_mapper', 'MapperExtension', 'Query',
- 'cascade_mappers', 'polymorphic_union', 'current_session',
'create_session',
+ 'cascade_mappers', 'polymorphic_union', 'create_session',
]
def relation(*args, **kwargs):
Index: lib/sqlalchemy/orm/mapper.py
===================================================================
--- lib/sqlalchemy/orm/mapper.py (revision 1453)
+++ lib/sqlalchemy/orm/mapper.py (working copy)
@@ -63,7 +63,7 @@
ext = ext_obj.chain(ext)
self.extension = ext
-
+
self.class_ = class_
self.entity_name = entity_name
self.class_key = ClassKey(class_, entity_name)
@@ -325,6 +325,7 @@
"""sets up our classes' overridden __init__ method, this mappers hash
key as its
'_mapper' property, and our columns as its 'c' property. if the class
already had a
mapper, the old __init__ method is kept the same."""
+ ext = self.extension
oldinit = self.class_.__init__
def init(self, *args, **kwargs):
self._entity_name = kwargs.pop('_sa_entity_name', None)
@@ -336,7 +337,9 @@
if kwargs.has_key('_sa_session'):
session = kwargs.pop('_sa_session')
else:
- session = sessionlib.current_session(self)
+ session = ext.get_session()
+ if session is EXT_PASS:
+ session = None
if session is not None:
session._register_new(self)
if oldinit is not None:
@@ -349,6 +352,17 @@
mapper_registry[self.class_key] = self
if self.entity_name is None:
self.class_.c = self.c
+
+ def get_session(self):
+ """returns the contextual session provided by the mapper extension chain;
+
+ Query uses this when it was not given an explicit session. raises
+ InvalidRequestError if the extension chain returns EXT_PASS
+ """
+ s = self.extension.get_session()
+ if s is EXT_PASS:  # no extension in the chain supplied a session
+ raise exceptions.InvalidRequestError("No contextual Session is
established. Use a MapperExtension that implements get_session or use 'import
sqlalchemy.mods.threadlocal' to establish a default thread-local contextual
session.")
+ return s
def has_eager(self):
"""returns True if one of the properties attached to this Mapper is
eager loading"""
@@ -900,6 +914,14 @@
def chain(self, ext):
self.next = ext
return self
+ def get_session(self):
+ """called to retrieve a contextual Session: when a new instance is constructed
+ (not called if _sa_session is passed to __init__) and by Mapper.get_session(),
+ which Query uses when it has no explicit session. EXT_PASS means "no session"."""
+ if self.next is None:
+ return EXT_PASS  # end of the chain: no extension supplied a session
+ else:
+ return self.next.get_session()  # defer to the next extension in the chain
def select_by(self, query, *args, **kwargs):
"""overrides the select_by method of the Query object"""
if self.next is None:
Index: lib/sqlalchemy/orm/query.py
===================================================================
--- lib/sqlalchemy/orm/query.py (revision 1453)
+++ lib/sqlalchemy/orm/query.py (working copy)
@@ -27,7 +27,7 @@
self._get_clause = self.mapper._get_clause
def _get_session(self):
if self._session is None:
- return sessionlib.required_current_session()
+ return self.mapper.get_session()  # contextual session from the mapper's extension chain; raises InvalidRequestError if none
else:
return self._session
table = property(lambda s:s.mapper.select_table)
Index: lib/sqlalchemy/orm/session.py
===================================================================
--- lib/sqlalchemy/orm/session.py (revision 1453)
+++ lib/sqlalchemy/orm/session.py (working copy)
@@ -427,42 +427,19 @@
# acts as a Registry with which to locate Sessions. this is to enable
# object instances to be associated with Sessions without having to attach the
# actual Session object directly to the object instance.
-_sessions = weakref.WeakValueDictionary()
+_sessions = weakref.WeakValueDictionary()
-def current_session(obj=None):
- if hasattr(obj, '__session__'):
- return obj.__session__()
- else:
- return _default_session(obj=obj)
-
-# deprecated
-get_session=current_session
-
-def required_current_session(obj=None):
- s = current_session(obj)
- if s is None:
- if obj is None:
- raise exceptions.InvalidRequestError("No global-level Session
context is established. Use 'import sqlalchemy.mods.threadlocal' to establish
a default thread-local context.")
- else:
- raise exceptions.InvalidRequestError("No Session context is
established for class '%s', and no global-level Session context is established.
Use 'import sqlalchemy.mods.threadlocal' to establish a default thread-local
context." % (obj.__class__))
- return s
-
-def _default_session(obj=None):
- return None
-def register_default_session(callable_):
- global _default_session
- _default_session = callable_
-
def object_session(obj):
hashkey = getattr(obj, '_sa_session_id', None)
if hashkey is not None:
- # ok, return that
- try:
- return _sessions[hashkey]
- except KeyError:
- return None
- else:
- return None
+ return _sessions.get(hashkey)  # None if the id is unknown or its Session was collected (weak-valued registry)
+ return None  # the object carries no _sa_session_id
unitofwork.object_session = object_session
+
+def get_session(obj=None):
+ """deprecated; use object_session(obj), or SessionContext.current for a contextual session"""
+ if obj is not None:
+ return object_session(obj)
+ raise exceptions.InvalidRequestError("get_session() is deprecated, and
does not return the thread-local session anymore. Use the
SessionContext.mapper_extension or import sqlalchemy.mods.threadlocal to
establish a default thread-local context.")
Index: test/sessioncontext.py
===================================================================
--- test/sessioncontext.py (revision 0)
+++ test/sessioncontext.py (revision 0)
@@ -0,0 +1,56 @@
+from testbase import PersistTest, AssertMixin
+import unittest, sys, os
+from sqlalchemy.ext.sessioncontext import SessionContext
+from sqlalchemy.orm.session import object_session, Session
+from sqlalchemy import *
+import testbase
+
+db = testbase.db
+
+users = Table('users', db,
+    Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True),
+    Column('user_name', String(40)),
+    mysql_engine='innodb'
+)
+
+class SessionContextTest(AssertMixin):
+    def setUpAll(self):
+        db.echo = False
+        users.create()
+        db.echo = testbase.echo
+    def tearDownAll(self):
+        db.echo = False
+        users.drop()
+        db.echo = testbase.echo
+    def setUp(self):
+        clear_mappers()
+
+    def do_test(self, class_, context):
+        """test session assignment on object creation"""
+        obj = class_()
+        assert context.current == object_session(obj)
+
+        # keep a reference so the old session doesn't get gc'd
+        old_session = context.current
+
+        context.current = Session()
+        assert context.current != object_session(obj)
+        assert old_session == object_session(obj)
+
+        new_session = context.current
+        del context.current
+        assert context.current != new_session
+        assert old_session == object_session(obj)
+
+        obj2 = class_()
+        assert context.current == object_session(obj2)
+
+    def test_mapper_extension(self):
+        context = SessionContext(Session)
+        class User(object): pass
+        User.mapper = mapper(User, users, extension=context.mapper_extension)
+        self.do_test(User, context)
+
+
+if __name__ == "__main__":
+    testbase.main()