Michael Bayer wrote:
hey Dan -

Just FYI, the current 0.2 Session implementation regarding this looks like the following (Hibernate's "save" is "add", and Hibernate's "update" is "import_"; at the moment both seem like clearer names to me, but maybe you see it differently... possibly not consistent with "delete", yeah, OK...):


Yeah, I wasn't really suggesting that they be changed (although I don't like 
"import_").

The problem with add() and import_() is that they both sound like they're doing 
the same thing. I think update() is preferable to import_(). Then we'd have 
add(), update(), and delete().

One thing that Hibernate enforces is that a given instance is associated with 
at most one Session at any point in time. Although that can sometimes be a 
pain, I think it avoids a whole class of problems with a single object being 
inserted/updated by multiple sessions. How does SQLAlchemy (SA) handle this?

I have attached a patch with a few of my ideas. Among other things, I moved a 
few methods around to make related parts of the code easier to read - I hope 
you don't mind. Also, Hibernate has a nice convenience method, saveOrUpdate(); 
I took the liberty of adding an equivalent method, addOrUpdate(), in my patch.

I also have a bunch of other things I'd like to discuss, but I'm out of time 
for now. Later...

~ Daniel
Index: objectstore.py
===================================================================
--- objectstore.py      (revision 1307)
+++ objectstore.py      (working copy)
@@ -5,6 +5,7 @@
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
 from sqlalchemy import util
+from sqlalchemy.mapping import mapper, object_mapper, class_mapper
 from sqlalchemy.exceptions import *
 import unitofwork, query
 import weakref
@@ -200,45 +201,141 @@
         """deprecated"""
         raise InvalidRequestError("Session.commit() is deprecated.  use 
install_mod('legacy_session') to enable the old behavior")    
 
-    def flush(self, *obj):
+    def flush(self, *objects):
         """flushes all the object modifications present in this session to the 
database.  if object
         arguments are given, then only those objects (and immediate 
dependencies) are flushed."""
-        self.uow.flush(self, *obj)
+        self.uow.flush(self, *objects)
             
-    def refresh(self, *obj):
+    def refresh(self, *objects):
         """reloads the attributes for the given objects from the database, 
clears
         any changes made."""
-        for o in obj:
-            self.uow.refresh(o)
+        if not objects:
+            raise InvalidRequestError("Session.refresh() requires at least one 
positional argument.")
+        for obj in objects:
+            self.uow.refresh(obj)
 
-    def expire(self, *obj):
+    def expire(self, *objects):
         """invalidates the data in the given objects and sets them to refresh 
themselves
         the next time they are requested."""
-        for o in obj:
-            self.uow.expire(o)
+        if not objects:
+            raise InvalidRequestError("Session.expire() requires at least one 
positional argument.")
+        for obj in objects:
+            self.uow.expire(obj)
 
-    def expunge(self, *obj):
+    def expunge(self, *objects):
         """removes the given objects from this Session.  this will free all 
internal references to the objects."""
-        for o in obj:
-            self.uow.expunge(o)
+        if not objects:
+            raise InvalidRequestError("Session.expunge() requires at least one 
positional argument.")
+        for obj in objects:
+            self.uow.expunge(obj)
+    
+    def clear(self):
+        """removes all object instances from this Session.  this is equivalent 
to calling expunge() for all
+        objects in this Session."""
+        self.uow = unitofwork.UnitOfWork()
             
-    def add(self, *obj, **kwargs):
+    def add(self, *objects, **kwargs):
         """adds unsaved objects to this Session.  
         
         The 'entity_name' keyword argument can also be given which will be 
assigned
         to the instances if given.
         """
-        for o in obj:
-            if hasattr(o, '_instance_key'):
-                if not self.uow.has_key(o._instance_key):
-                    raise InvalidRequestError("Instance '%s' is not bound to 
this Session; use session.import(instance)" % repr(o))
+        if not objects:
+            raise InvalidRequestError("Session.add() requires at least one 
positional argument.")
+        for obj in objects:
+            if hasattr(obj, '_instance_key'):
+                if not self.uow.has_key(obj._instance_key):
+                    raise InvalidRequestError("Instance %s is not bound to 
this Session; use session.import(instance)" % repr(obj))
             else:
                 entity_name = kwargs.get('entity_name', None)
                 if entity_name is not None:
-                    m = class_mapper(o.__class__, entity_name=entity_name)
-                    m._assign_entity_name(o)
-                self._register_new(o)
+                    m = class_mapper(obj.__class__, entity_name=entity_name)
+                    m._assign_entity_name(obj)
+                self._register_new(obj)
             
+    def update(self, *objects, **kwargs):
+        """given one or more objects that represent persistent items, adds 
them to this session.
+
+        if an instance corresponding to the identity of the given instance 
already
+        exists within this session, then that instance is returned; the 
returned
+        instance should always be used following this method.
+        
+        if a single object is given, then a single object will be returned.  
however, if
+        multiple objects are given then a list of objects will be returned.
+        
+        if the given instance does not have an _instance_key and also does not 
have all 
+        of its primary key attributes populated, an exception is raised.  
similarly, if no
+        mapper can be located for the given instance, an exception is raised.
+
+        this method should be used for any object instance that is coming from 
serialized
+        storage, or was loaded by a Session other than this one.
+                
+        the keyword parameter entity_name is optional and is used to locate a 
Mapper for this
+        class which also specifies the given entity name.
+        """
+        if not objects:
+            raise InvalidRequestError("Session.update() requires at least one 
positional argument.")
+        rval = []
+        for obj in objects:
+            if obj is None:
+                return None
+            key = getattr(obj, '_instance_key', None)
+            mapper = object_mapper(obj, raiseerror=False)
+            if mapper is None:
+                mapper = class_mapper(obj, entity_name=entity_name)
+            if key is None:
+                ident = mapper.identity(obj)
+                for k in ident:
+                    if k is None:
+                        raise InvalidRequestError("Instance %s does not have a 
full set of identity values, and does not represent a saved entity in the 
database.  Use the add() method to add unsaved instances to this Session." % 
repr(obj))
+                key = mapper.identity_key(*ident)
+            u = self.uow
+            if u.identity_map.has_key(key):
+                rval.append(u.identity_map[key])
+            else:
+                obj._instance_key = key
+                u.identity_map[key] = obj
+                self._bind_to(obj)
+                rval.append(obj)
+        if len(rval) > 1:
+            return rval
+        return rval[0]
+
+    def addOrUpdate(self, *objects, **kwargs):
+        """add or update all given objects to this session
+        
+        see add() and update() for more information
+        """
+        if not objects:
+            raise InvalidRequestError("Session.addOrUpdate() requires at least 
one positional argument.")
+        rval = []
+        for obj in objects:
+            if hasattr(obj, '_instance_key'):
+                rval.append(self.update(obj, **kwargs))
+            else:
+                self.add(obj, **kwargs)
+                rval.append(obj)
+        if len(rval) > 1:
+            return rval
+        return rval[0]
+
+    def delete(self, *objects, **kwargs):
+        """registers the given objects to be deleted upon the next flush().  
If the given objects are not part of this
+        Session, they will be imported.  the objects are expected to either 
have an _instance_key
+        attribute or have all of their primary key attributes populated.
+        
+        the keyword argument 'entity_name' can also be provided which will be 
used by the import."""
+        if not objects:
+            raise InvalidRequestError("Session.delete() requires at least one 
positional argument.")
+        for obj in objects:
+            if not self._is_bound(obj):
+                obj = self.import_(obj, **kwargs)
+            self.uow.register_deleted(obj)
+        
+    def import_instance(self, *args, **kwargs):
+        """deprecated; a synynom for update()"""
+        return self.update(*args, **kwargs)
+
     def _register_new(self, obj):
         self._bind_to(obj)
         self.uow.register_new(obj)
@@ -272,79 +369,11 @@
     new = property(lambda s:s.uow.new)
     modified_lists = property(lambda s:s.uow.modified_lists)
     identity_map = property(lambda s:s.uow.identity_map)
-    
-    def clear(self):
-        """removes all object instances from this Session.  this is equivalent 
to calling expunge() for all
-        objects in this Session."""
-        self.uow = unitofwork.UnitOfWork()
 
-    def delete(self, *obj, **kwargs):
-        """registers the given objects to be deleted upon the next flush().  
If the given objects are not part of this
-        Session, they will be imported.  the objects are expected to either 
have an _instance_key
-        attribute or have all of their primary key attributes populated.
-        
-        the keyword argument 'entity_name' can also be provided which will be 
used by the import."""
-        for o in obj:
-            if not self._is_bound(o):
-                o = self.import_(o, **kwargs)
-            self.uow.register_deleted(o)
-        
-    def import_(self, instance, entity_name=None):
-        """given an instance that represents a saved item, adds it to this 
session.
-        the return value is either the given instance, or if an instance 
corresponding to the 
-        identity of the given instance already exists within this session, 
then that instance is returned;
-        the returned instance should always be used following this method.
-        
-        if the given instance does not have an _instance_key and also does not 
have all 
-        of its primary key attributes populated, an exception is raised.  
similarly, if no
-        mapper can be located for the given instance, an exception is raised.
+get_id_key = Session.get_id_key
 
-        this method should be used for any object instance that is coming from 
a serialized
-        storage, or was loaded by a Session other than this one.
-                
-        the keyword parameter entity_name is optional and is used to locate a 
Mapper for this
-        class which also specifies the given entity name.
-        """
-        if instance is None:
-            return None
-        key = getattr(instance, '_instance_key', None)
-        mapper = object_mapper(instance, raiseerror=False)
-        if mapper is None:
-            mapper = class_mapper(instance, entity_name=entity_name)
-        if key is None:
-            ident = mapper.identity(instance)
-            for k in ident:
-                if k is None:
-                    raise InvalidRequestError("Instance '%s' does not have a 
full set of identity values, and does not represent a saved entity in the 
database.  Use the add() method to add unsaved instances to this Session." % 
str(instance))
-            key = mapper.identity_key(*ident)
-        u = self.uow
-        if u.identity_map.has_key(key):
-            return u.identity_map[key]
-        else:
-            instance._instance_key = key
-            u.identity_map[key] = instance
-            self._bind_to(instance)
-            return instance
-            
-    def import_instance(self, *args, **kwargs):
-        """deprecated; a synynom for import()"""
-        return self.import_(*args, **kwargs)
+get_row_key = Session.get_row_key
 
-def get_id_key(ident, class_, entity_name=None):
-    return Session.get_id_key(ident, class_, entity_name)
-
-def get_row_key(row, class_, primary_key, entity_name=None):
-    return Session.get_row_key(row, class_, primary_key, entity_name)
-
-def mapper(*args, **params):
-    return sqlalchemy.mapping.mapper(*args, **params)
-
-def object_mapper(obj):
-    return sqlalchemy.mapping.object_mapper(obj)
-
-def class_mapper(class_, **kwargs):
-    return sqlalchemy.mapping.class_mapper(class_, **kwargs)
-
 # this is the AttributeManager instance used to provide attribute behavior on 
objects.
 # to all the "global variable police" out there:  its a stateless object.
 global_attributes = unitofwork.global_attributes
@@ -355,23 +384,21 @@
 # actual Session object directly to the object instance.
 _sessions = weakref.WeakValueDictionary() 
 
-def get_session(obj=None, raiseerror=True):
+def get_session(obj, raiseerror=True):
     """returns the Session corrseponding to the given object instance.  By 
default, if the object is not bound
     to any Session, then an error is raised (or None is returned if 
raiseerror=False).  This behavior can be changed
     using the "threadlocal" mod, which will add an additional step to return a 
Session that is bound to the current 
     thread."""
-    if obj is not None:
-        # does it have a hash key ?
-        hashkey = getattr(obj, '_sa_session_id', None)
-        if hashkey is not None:
-            # ok, return that
-            try:
-                return _sessions[hashkey]
-            except KeyError:
-                if raiseerror:
-                    raise InvalidRequestError("Session '%s' referenced by 
object '%s' no longer exists" % (hashkey, repr(obj)))
-                else:
-                    return None
+    hashkey = getattr(obj, '_sa_session_id', None)
+    if hashkey is not None:
+        try:
+            return _sessions[hashkey]
+        except KeyError:
+            if raiseerror:
+                raise InvalidRequestError("Session '%s' referenced by object 
'%s' no longer exists" % (hashkey, repr(obj)))
+    elif raiseerror:
+        raise InvalidRequestError("Object '%s' is not associated with a 
session" % repr(obj))
+    return None
                     
     return _default_session(obj=obj, raiseerror=raiseerror)
 

Reply via email to