--- //depot/vendor/sqlalchemy/sqlalchemy/lib/sqlalchemy/orm/dependency.py	2007/03/13 11:30:43
+++ //depot/user/benno/dbrep/sqlalchemy/lib/sqlalchemy/orm/dependency.py	2007/03/21 14:19:25
@@ -290,6 +290,9 @@
         
     def process_dependencies(self, task, deplist, uowcommit, delete = False):
         #print self.mapper.table.name + " " + self.key + " " + repr(len(deplist)) + " process_dep isdelete " + repr(delete) + " direction " + repr(self.direction)
+        # XXX: Use of bind_func at the session layer will break here, I think.
+        # The problem is that we need to work out which engine to use for the
+        # association table based on the objects that are in deplist.
         connection = uowcommit.transaction.connection(self.mapper)
         secondary_delete = []
         secondary_insert = []
--- //depot/vendor/sqlalchemy/sqlalchemy/lib/sqlalchemy/orm/mapper.py	2007/03/13 11:30:43
+++ //depot/user/benno/dbrep/sqlalchemy/lib/sqlalchemy/orm/mapper.py	2007/03/21 14:07:56
@@ -852,17 +852,19 @@
             for obj in objects:
                 self.save_obj([obj], uowtransaction, postupdate=postupdate, post_update_cols=post_update_cols, single=True)
             return
-            
-        connection = uowtransaction.transaction.connection(self)
+        
+        connections = {}
+        for obj in objects:
+            connections[obj] = uowtransaction.transaction.connection(self, obj)
 
         if not postupdate:
             for obj in objects:
                 if not has_identity(obj):
                     for mapper in object_mapper(obj).iterate_to_root():
-                        mapper.extension.before_insert(mapper, connection, obj)
+                        mapper.extension.before_insert(mapper, connections[obj], obj)
                 else:
                     for mapper in object_mapper(obj).iterate_to_root():
-                        mapper.extension.before_update(mapper, connection, obj)
+                        mapper.extension.before_update(mapper, connections[obj], obj)
 
         for obj in objects:
             # detect if we have a "pending" instance (i.e. has no instance_key attached to it),
@@ -996,8 +998,8 @@
                 update.sort(comparator)
                 for rec in update:
                     (obj, params, mapper) = rec
-                    c = connection.execute(statement, params)
-                    mapper._postfetch(connection, table, obj, c, c.last_updated_params())
+                    c = connections[obj].execute(statement, params)
+                    mapper._postfetch(connections[obj], table, obj, c, c.last_updated_params())
 
                     updated_objects.add(obj)
                     rows += c.rowcount
@@ -1012,7 +1014,7 @@
                 insert.sort(comparator)
                 for rec in insert:
                     (obj, params, mapper) = rec
-                    c = connection.execute(statement, params)
+                    c = connections[obj].execute(statement, params)
                     primary_key = c.last_inserted_ids()
                     if primary_key is not None:
                         i = 0
@@ -1020,7 +1022,7 @@
                             if mapper.get_attr_by_column(obj, col) is None and len(primary_key) > i:
                                 mapper.set_attr_by_column(obj, col, primary_key[i])
                             i+=1
-                    mapper._postfetch(connection, table, obj, c, c.last_inserted_params())
+                    mapper._postfetch(connections[obj], table, obj, c, c.last_inserted_params())
                     
                     # synchronize newly inserted ids from one table to the next
                     # TODO: this fires off more than needed, try to organize syncrules
@@ -1037,10 +1039,10 @@
         if not postupdate:
             for obj in inserted_objects:
                 for mapper in object_mapper(obj).iterate_to_root():
-                    mapper.extension.after_insert(mapper, connection, obj)
+                    mapper.extension.after_insert(mapper, connections[obj], obj)
             for obj in updated_objects:
                 for mapper in object_mapper(obj).iterate_to_root():
-                    mapper.extension.after_update(mapper, connection, obj)
+                    mapper.extension.after_update(mapper, connections[obj], obj)
 
     def _postfetch(self, connection, table, obj, resultproxy, params):
         """after an INSERT or UPDATE, asks the returned result if PassiveDefaults fired off on the database side
@@ -1072,9 +1074,11 @@
         if self.__should_log_debug:
             self.__log_debug("delete_obj() start")
 
-        connection = uowtransaction.transaction.connection(self)
+        connections = {}
+        for obj in objects:
+            connections[obj] = uowtransaction.transaction.connection(self, obj)
 
-        [self.extension.before_delete(self, connection, obj) for obj in objects]
+        [self.extension.before_delete(self, connections[obj], obj) for obj in objects]
         deleted_objects = util.Set()
         for table in self.tables.sort(reverse=True):
             if not self._has_pks(table):
@@ -1105,11 +1109,11 @@
                 if self.version_id_col is not None:
                     clause.clauses.append(self.version_id_col == sql.bindparam(self.version_id_col.key, type=self.version_id_col.type))
                 statement = table.delete(clause)
-                c = connection.execute(statement, delete)
+                c = connections[obj].execute(statement, delete)
                 if c.supports_sane_rowcount() and c.rowcount != len(delete):
                     raise exceptions.ConcurrentModificationError("Updated rowcount %d does not match number of objects updated %d" % (c.cursor.rowcount, len(delete)))
                     
-        [self.extension.after_delete(self, connection, obj) for obj in deleted_objects]
+        [self.extension.after_delete(self, connections[obj], obj) for obj in deleted_objects]
 
     def _has_pks(self, table):
         try:
--- //depot/vendor/sqlalchemy/sqlalchemy/lib/sqlalchemy/orm/session.py	2007/03/13 11:30:43
+++ //depot/user/benno/dbrep/sqlalchemy/lib/sqlalchemy/orm/session.py	2007/03/21 14:15:55
@@ -23,12 +23,12 @@
         self.connections = {}
         self.parent = parent
         self.autoflush = autoflush
-    def connection(self, mapper_or_class, entity_name=None):
+    def connection(self, mapper_or_class, obj=None, entity_name=None):
         if isinstance(mapper_or_class, type):
             mapper_or_class = _class_mapper(mapper_or_class, entity_name=entity_name)
         if self.parent is not None:
-            return self.parent.connection(mapper_or_class)
-        engine = self.session.get_bind(mapper_or_class)
+            return self.parent.connection(mapper_or_class, obj)
+        engine = self.session.get_bind(mapper_or_class, obj)
         return self.get_or_add(engine)
     def _begin(self):
         return SessionTransaction(self.session, self)
@@ -92,6 +92,7 @@
         
         self.bind_to = bind_to
         self.binds = {}
+        self.bind_funcs = {}
         self.echo_uow = echo_uow
         self.weak_identity_map = weak_identity_map
         self.transaction = None
@@ -116,11 +117,11 @@
         else:
             self.transaction = SessionTransaction(self, **kwargs)
             return self.transaction
-    def connect(self, mapper=None, **kwargs):
+    def connect(self, mapper=None, obj=None, **kwargs):
         """returns a unique connection corresponding to the given mapper.  this connection
         will not be part of any pre-existing transactional context."""
-        return self.get_bind(mapper).connect(**kwargs)
-    def connection(self, mapper, **kwargs):
+        return self.get_bind(mapper, obj).connect(**kwargs)
+    def connection(self, mapper, obj=None, **kwargs):
         """returns a Connection corresponding to the given mapper.  used by the execute()
         method which performs select operations for Mapper and Query.
         if this Session is transactional, 
@@ -130,19 +131,19 @@
         
         the given **kwargs will be sent to the engine's contextual_connect() method, if no transaction is in progress."""
         if self.transaction is not None:
-            return self.transaction.connection(mapper)
+            return self.transaction.connection(mapper, obj)
         else:
-            return self.get_bind(mapper).contextual_connect(**kwargs)
-    def execute(self, mapper, clause, params, **kwargs):
+            return self.get_bind(mapper, obj).contextual_connect(**kwargs)
+    def execute(self, mapper, clause, params, obj=None, **kwargs):
         """using the given mapper to identify the appropriate Engine or Connection to be used for statement execution, 
         executes the given ClauseElement using the provided parameter dictionary.  Returns a ResultProxy corresponding
         to the execution's results.  If this method allocates a new Connection for the operation, then the ResultProxy's close() 
         method will release the resources of the underlying Connection, otherwise its a no-op.
         """
-        return self.connection(mapper, close_with_result=True).execute(clause, params, **kwargs)
-    def scalar(self, mapper, clause, params, **kwargs):
+        return self.connection(mapper, obj, close_with_result=True).execute(clause, params, **kwargs)
+    def scalar(self, mapper, clause, params, obj=None, **kwargs):
         """works like execute() but returns a scalar result."""
-        return self.connection(mapper, close_with_result=True).scalar(clause, params, **kwargs)
+        return self.connection(mapper, obj, close_with_result=True).scalar(clause, params, **kwargs)
         
     def close(self):
         """closes this Session.  
@@ -173,7 +174,9 @@
         
         All subsequent operations involving this Table will use the given bindto."""
         self.binds[table] = bindto
-    def get_bind(self, mapper):
+    def bind_func(self, mapper_or_table, func):
+        self.bind_funcs[mapper_or_table] = func
+    def get_bind(self, mapper, obj=None):
         """return the Engine or Connection which is used to execute statements on behalf of the given Mapper.
         
         Calling connect() on the return result will always result in a Connection object.  This method 
@@ -195,6 +198,8 @@
         """
         if mapper is None:
             return self.bind_to
+        elif obj is not None and self.bind_funcs.has_key(mapper):
+            return self.bind_funcs[mapper](mapper, obj)
         elif self.binds.has_key(mapper):
             return self.binds[mapper]
         elif self.binds.has_key(mapper.mapped_table):
--- //depot/vendor/sqlalchemy/sqlalchemy/test/orm/session.py	2007/03/13 11:30:43
+++ //depot/user/benno/dbrep/sqlalchemy/test/orm/session.py	2007/03/21 16:04:36
@@ -162,6 +162,25 @@
         assert s.query(Address).selectone().address_id == a.address_id
         assert s.query(User).selectfirst() is None
 
+    def test_bind_func(self):
+        c = testbase.db.connect()
+        class User(object):pass
+        user_mapper = mapper(User, users)
+        u = User()
+        u.user_name = 'fred'
+        def bind_func(mapper, obj):
+            assert mapper is user_mapper
+            assert obj is u
+            return c
+        sess = create_session()
+        sess.bind_func(user_mapper, bind_func)
+        sess.bind_func(users, bind_func)
+        transaction = sess.create_transaction()
+        sess.save(u)
+        transaction.commit()
+        v = users.select(users.c.user_id == u.user_id).execute().fetchone()
+        assert v[0] == u.user_id
+        assert v[1] == u.user_name
         
 class OrphanDeletionTest(AssertMixin):
 
