dabo Commit
Revision 6821
Date: 2011-09-03 11:53:03 -0700 (Sat, 03 Sep 2011)
Author: Ed
Trac: http://trac.dabodev.com/changeset/6821

Changed:
U   trunk/dabo/biz/dBizobj.py
A   trunk/dabo/biz/test/test_mm.py
U   trunk/dabo/db/dCursorMixin.py

Log:
Improved the handling of many-to-many relations. Each bizobj can now have 
multiple MM relations instead of just one.

Added a test script (test_mm.py) that covers all of the many-to-many methods. 
The script also serves as a set of examples of typical usage.



Diff:
Modified: trunk/dabo/biz/dBizobj.py
===================================================================
--- trunk/dabo/biz/dBizobj.py   2011-09-02 22:02:30 UTC (rev 6820)
+++ trunk/dabo/biz/dBizobj.py   2011-09-03 18:53:03 UTC (rev 6821)
@@ -65,6 +65,7 @@
                self._cursorFactory = None
                self.__params = ()              # tuple of params to be merged 
with the sql in the cursor
                self._children = []             # Collection of child bizobjs
+               self._associations = {}         # Dict of many-to-many 
associations, keyed by DataSource
                self._baseClass = dBizobj
                self.__areThereAnyChanges = False       # Used by the 
isChanged() method.
                # Used by the LinkField property
@@ -176,11 +177,13 @@
                return cur
 
 
-       def createCursor(self, key=None):
+       def createCursor(self, key=None, addToCursorCollection=True):
                """
-               Create the cursor that this bizobj will be using for data, and 
store it
+               Create the cursor that this bizobj will be using for data, and 
optionally store it
                in the dictionary for cursors, with the passed value of 'key' 
as its dict key.
-               For independent bizobjs, that key will be None.
+               For independent bizobjs, that key will be None. If creating a 
cursor that will not
+               be used as a data source for this bizobj, as when creating 
many-to-many
+               cursors, pass False for the 'addToCursorCollection' parameter.
 
                Subclasses should override beforeCreateCursor() and/or 
afterCreateCursor()
                instead of overriding this method, if possible. Returning any 
non-empty value
@@ -208,17 +211,17 @@
                        key = self.__currentCursorKey
 
                cf = self._cursorFactory
-               self.__cursors[key] = cf.getCursor(cursorClass)
-               self.__cursors[key].setCursorFactory(cf.getCursor, cursorClass)
-
-               crs = self.__cursors[key]
+               crs = cf.getCursor(cursorClass)
+               crs.setCursorFactory(cf.getCursor, cursorClass)
+               if addToCursorCollection:
+                       self.__cursors[key] = crs
                if _dataStructure is not None:
                        crs._dataStructure = _dataStructure
                crs.BackendObject = cf.getBackendObject()
                crs.sqlManager = self.SqlManager
                crs._bizobj = self
                self._syncCursorProps(crs)
-               if self.RequeryOnLoad:
+               if addToCursorCollection and self.RequeryOnLoad:
                        if self.__cursorsToRequery is None:
                                # We've already passed the bizobj init process
                                crs.requery()
@@ -1675,6 +1678,27 @@
                        child.Parent = self
 
 
+       def addMMBizobj(self, mmBizobj, assocTable, assocPKColThis, 
assocPKColOther,
+                       mmPkCol=None):
+               """
+               Add the passed bizobj to this bizobj in a Many-to-Many 
relationship.
+
+               The reference will be stored, along with a dedicated cursor that 
handles the
+               association table. The Parent of that bizobj is not changed. If 
mmPkCol is not specified,
+               the KeyField of the mmBizobj will be used for the relationship.
+               """
+               if mmBizobj.DataSource not in self._associations:
+                       if mmPkCol is None:
+                               mmPkCol = mmBizobj.KeyField
+                       crs = self.createCursor(key=None, 
addToCursorCollection=False)
+                       crs._isMM = True
+                       crs.createAssociation(mmBizobj.DataSource, mmPkCol, 
assocTable,
+                               assocPKColThis, assocPKColOther)
+                       self._associations[mmBizobj.DataSource] = {
+                                       "bizobj": mmBizobj,
+                                       "cursor": crs}
+
+
        def getAncestorByDataSource(self, ds):
                """
                Given a DataSource, finds the ancestor (parent, grandparent, 
etc.) of
@@ -2058,50 +2082,107 @@
                return self._CurrentCursor.oldVal(fieldName, row)
 
 
-       def mmAssociateValue(self, otherField, otherVal):
+       def _getAssociation(self, bizOrDS):
                """
+               Returns the relevant association dictionary, given either the 
DataSource
+               or the 'other' bizobj. Raises DataSourceNotFoundException if no 
association has been defined.
+               """
+               try:
+                       # Assume this is a DataSource string
+                       return self._associations[bizOrDS]
+               except KeyError:
+                       # Try the bizobj
+                       keys = [k for k in self._associations
+                                       if self._associations[k]["bizobj"] == 
bizOrDS]
+                       try:
+                               ds = keys[0]
+                               return self._associations[ds]
+                       except IndexError:
+                               raise dException.DataSourceNotFoundException(
+                                               _("No many-to-many association 
found for DataSource: '%s'." % bizOrDS))
+               return None
+
+               
+       def mmAssociateValue(self, bizOrDS, otherField, otherVal):
+               """
                Associates the value in the 'other' table of a M-M relationship 
with the
                current record in the bizobj. If that value doesn't exist in 
the other
                table, it is added.
                """
-               self._CurrentCursor.mmAssociateValue(otherField, otherVal)
+               assoc = self._getAssociation(bizOrDS)
+               assoc["cursor"].mmAssociateValue(otherField, otherVal)
 
 
-       def mmDisssociateValue(self, otherField, otherVal):
+       def mmAssociateValues(self, bizOrDS, otherField, listOfValues):
                """
+               Adds association records so that the current record in this 
bizobj is associated
+               with every item in listOfValues. Other existing relationships 
are unaffected.
+               """
+               assoc = self._getAssociation(bizOrDS)
+               assoc["cursor"].mmAssociateValues(otherField, listOfValues)
+
+
+       def mmDisssociateValue(self, bizOrDS, otherField, otherVal):
+               """
                Removes the association between the current record and the 
specified value
                in the 'other' table of a M-M relationship. If no such 
association exists,
                nothing happens.
                """
-               self._CurrentCursor.mmDisssociateValue(otherField, otherVal)
+               assoc = self._getAssociation(bizOrDS)
+               assoc["cursor"].mmDisssociateValue(otherField, otherVal)
 
 
-       def mmDisssociateAll(self):
+       def mmDisssociateValues(self, bizOrDS, otherField, listOfValues):
                """
+               Removes the association between the current record and every 
item in 'listOfValues'
+               in the 'other' table of a M-M relationship. If no such 
association exists,
+               nothing happens.
+               """
+               assoc = self._getAssociation(bizOrDS)
+               assoc["cursor"].mmDisssociateValues(otherField, listOfValues)
+
+
+       def mmDisssociateAll(self, bizOrDS):
+               """
                Removes all associations between the current record and the 
associated
                M-M table.
                """
-               self._CurrentCursor.mmDisssociateAll()
+               assoc = self._getAssociation(bizOrDS)
+               assoc["cursor"].mmDisssociateAll()
 
 
-       def mmSetFullAssociation(self, otherField, listOfValues):
+       def mmSetFullAssociation(self, bizOrDS, otherField, listOfValues):
                """
                Adds and/or removes association records so that the current 
record in this
                bizobj is associated with every item in listOfValues, and none 
other.
                """
-               self._CurrentCursor.mmSetFullAssociation(otherField, 
listOfValues)
+               assoc = self._getAssociation(bizOrDS)
+               assoc["cursor"].mmSetFullAssociation(otherField, listOfValues)
 
 
-       def mmAddToBoth(self, thisField, thisVal, otherField, otherVal):
+       def mmAddToBoth(self, bizOrDS, thisField, thisVal, otherField, 
otherVal):
                """
                Creates an association in a M-M relationship. If the 
relationship
                already exists, nothing changes. Otherwise, this will ensure 
that
                both values exist in their respective tables, and will create 
the 
                entry in the association table.
                """
-               return self._CurrentCursor.mmAddToBoth(thisField, thisVal, 
otherField, otherVal)
+               assoc = self._getAssociation(bizOrDS)
+               return assoc["cursor"].mmAddToBoth(thisField, thisVal, 
otherField, otherVal)
 
 
+       def mmGetAssociatedValues(self, bizOrDS, listOfFields):
+               """
+               Given a relationship, returns the values associated with the 
current
+               record. 'listOfFields' can be either a single field name, or a 
list
+               of fields in the associated table.
+               """
+               if not isinstance(listOfFields, (list, tuple)):
+                       listOfFields = [listOfFields]
+               assoc = self._getAssociation(bizOrDS)
+               return assoc["cursor"].mmGetAssociatedValues(listOfFields)
+
+
        ########## SQL Builder interface section ##############
        def addField(self, exp, alias=None):
                """Add a field to the field clause."""
@@ -2128,7 +2209,8 @@
        def createAssociation(self, mmOtherTable, mmOtherPKCol, assocTable, 
assocPKColThis,
                        assocPKColOther):
                """
-               Create a many-to-many association.
+               Create a many-to-many association. Generally it is better to 
use the 'addMMBizobj()'
+               method, but if you want to set this manually, use this instead 
of defining the JOINs.
 
                :param mmOtherTable: the name of the table for the other half 
of the MM relation
                :param mmOtherPKCol: the name of the PK column in the 
mmOtherTable
@@ -2271,7 +2353,7 @@
        ########## Post-hook interface section ##############
 
        afterNew = _makeHookMethod("afterNew", "a new record is added",
-                       additionalDoc=\
+                       additionalDoc=
 """Use this hook to change field values of newly added records. If
 you change field values here, the memento system will catch it and
 prompt you to save if needed later on. If you want to change field
@@ -2296,7 +2378,7 @@
        afterChildRequery = _makeHookMethod("afterChildRequery",
                        "the child bizobjs are requeried")
        afterChange = _makeHookMethod("afterChange", "a record is changed",
-                       additionalDoc=\
+                       additionalDoc=
 """This hook will be called after a successful save() or delete(). Contrast
 with the afterSave() hook which only gets called after a save(), and the
 afterDelete() which is only called after a delete().""")
@@ -2741,13 +2823,13 @@
 
 
        def _getSQL(self):
-               warnings.warn(_("""This property is deprecated, and will be 
removed in the next version
-of the framework. Use the 'UserSQL' property instead."""), DeprecationWarning, 
1)
+               warnings.warn(_("This property is deprecated, and will be 
removed in the next version "
+                               "of the framework. Use the 'UserSQL' property 
instead."), DeprecationWarning, 1)
                return self.UserSQL
 
        def _setSQL(self, val):
-               warnings.warn(_("""This property is deprecated, and will be 
removed in the next version
-of the framework. Use the 'UserSQL' property instead."""), DeprecationWarning, 
1)
+               warnings.warn(_("This property is deprecated, and will be 
removed in the next version "
+                               "of the framework. Use the 'UserSQL' property 
instead."), DeprecationWarning, 1)
                self.UserSQL = val
 
 

Added: trunk/dabo/biz/test/test_mm.py
===================================================================
--- trunk/dabo/biz/test/test_mm.py                              (rev 0)
+++ trunk/dabo/biz/test/test_mm.py      2011-09-03 18:53:03 UTC (rev 6821)
@@ -0,0 +1,310 @@
+# -*- coding: utf-8 -*-
+import unittest
+import dabo
+import dabo.dException as dException
+from dabo.lib import getRandomUUID
+
+class Test_Many_To_Many(unittest.TestCase):
+       def setUp(self):
+               self.conn = dabo.db.dConnection(DbType="SQLite", 
Database=":memory:")
+               pbiz = self.person_biz = dabo.biz.dBizobj(self.conn)
+               self.crs = self.person_biz.getTempCursor()
+               self.createSchema()
+               pbiz.KeyField = "pkid"
+               pbiz.DataSource = "person"
+               pbiz.requery()
+               comp = self.company_biz = dabo.biz.dBizobj(self.conn)
+               comp.KeyField = "pkid"
+               comp.DataSource = "company"
+               comp.requery()
+               fan_club = self.fan_club_biz = dabo.biz.dBizobj(self.conn)
+               fan_club.KeyField = "pkid"
+               fan_club.DataSource = "fan_club"
+               fan_club.requery()
+               # Set the MM relations
+               pbiz.addMMBizobj(self.company_biz, "employees", "person_id", 
"company_id")
+               pbiz.addMMBizobj(self.fan_club_biz, "membership", "person_id", 
"fan_club_id")
+
+
+       def tearDown(self):
+               self.person_biz = self.company_biz = None
+
+
+       def createSchema(self):
+               self.crs.execute("create table person (pkid INTEGER PRIMARY KEY 
AUTOINCREMENT, first_name TEXT, last_name TEXT);")
+               self.crs.execute("create table company (pkid INTEGER PRIMARY 
KEY AUTOINCREMENT, company TEXT);")
+               self.crs.execute("create table employees (pkid INTEGER PRIMARY 
KEY AUTOINCREMENT, person_id INT, company_id INT);")
+               self.crs.execute("insert into person (first_name, last_name) 
values ('Ed', 'Leafe')")
+               self.crs.execute("insert into person (first_name, last_name) 
values ('Paul', 'McNett')")
+               self.crs.execute("insert into company (company) values ('Acme 
Manufacturing')")
+
+               self.crs.execute("create table fan_club (pkid INTEGER PRIMARY 
KEY AUTOINCREMENT, performer TEXT);")
+               self.crs.execute("create table membership (pkid INTEGER PRIMARY 
KEY AUTOINCREMENT, person_id INT, fan_club_id INT);")
+               self.crs.execute("insert into fan_club (performer) values 
('Green Day')")
+               self.crs.execute("insert into fan_club (performer) values ('The 
Clash')")
+               self.crs.execute("insert into fan_club (performer) values 
('Ramones')")
+               self.crs.execute("insert into fan_club (performer) values ('Pat 
Boone')")
+
+
+       def reccount(self, tbl, filt=None):
+               """Please note that SQL injection is consciously ignored here. 
These
+               are in-memory tables!!
+               """
+               if filt:
+                       self.crs.execute("select count(*) as cnt from %s where 
%s" % (tbl, filt))
+               else:
+                       self.crs.execute("select count(*) as cnt from %s" % tbl)
+               return self.crs.Record.cnt
+
+
+       def test_bad_datasource(self):
+               """Ensure that the proper exception is raised when a DataSource 
is passed
+               that does not correspond to an existing relation.
+               """
+               pbiz = self.person_biz
+               cbiz = self.company_biz
+               pbiz.seek("Leafe", "last_name")
+               self.assertRaises(dException.DataSourceNotFoundException,
+                               pbiz.mmAssociateValue, "dummy", "company", 
"Acme Manufacturing")
+               
+
+       def test_associate(self):
+               """Verify that bizobj.mmAssociateValue() works correctly."""
+               pbiz = self.person_biz
+               cbiz = self.company_biz
+               self.assertEqual(self.reccount("company"), 1)
+
+               pbiz.seek("Leafe", "last_name")
+               pbiz.mmAssociateValue(cbiz, "company", "Acme Manufacturing")
+               # Company count should not have changed
+               self.assertEqual(self.reccount("company"), 1)
+               pbiz.mmAssociateValue(cbiz, "company", "Amalgamated Industries")
+               # Company count should have increased
+               self.assertEqual(self.reccount("company"), 2)
+
+
+       def test_associate_list(self):
+               """Verify that bizobj.mmAssociateValues() works correctly."""
+               pbiz = self.person_biz
+               fbiz = self.fan_club_biz
+               orig_club_count = self.reccount("fan_club")
+               orig_fan_count = self.reccount("membership")
+               self.assertEqual(orig_fan_count, 0)
+
+               pbiz.seek("Leafe", "last_name")
+               pbiz.mmAssociateValues(fbiz, "performer", ["Ramones", "Green 
Day"])
+               # Club count should not have changed
+               self.assertEqual(self.reccount("fan_club"), orig_club_count)
+               # Membership count should have increased
+               self.assertEqual(self.reccount("membership"), orig_fan_count + 
2)
+
+               # Add a list with both existing and new clubs
+               pbiz.mmAssociateValues(fbiz, "performer", ["Ramones", "Black 
Flag"])
+               # Club count should have increased by 1
+               self.assertEqual(self.reccount("fan_club"), orig_club_count + 1)
+               # Membership count should have increased by 1
+               self.assertEqual(self.reccount("membership"), orig_fan_count + 
3)
+
+
+       def test_dissociate(self):
+               """Verify that bizobj.mmDisssociateValue() works correctly."""
+               pbiz = self.person_biz
+               cbiz = self.company_biz
+               pbiz.seek("Leafe", "last_name")
+               leafe_pk = pbiz.getPK()
+               pbiz.mmAssociateValue(cbiz, "company", "Acme Manufacturing")
+               emp_count = self.reccount("employees", "person_id = %s" % 
leafe_pk)
+               self.assertEqual(emp_count, 1)
+               pbiz.mmAssociateValue(cbiz, "company", "Amalgamated Industries")
+               emp_count = self.reccount("employees", "person_id = %s" % 
leafe_pk)
+               self.assertEqual(emp_count, 2)
+               pbiz.mmDisssociateValue(cbiz, "company", "Acme Manufacturing")
+               emp_count = self.reccount("employees", "person_id = %s" % 
leafe_pk)
+               self.assertEqual(emp_count, 1)
+
+
+       def test_dissociate_list(self):
+               """Verify that bizobj.mmDisssociateValues() works correctly."""
+               pbiz = self.person_biz
+               cbiz = self.company_biz
+               pbiz.seek("Leafe", "last_name")
+               leafe_pk = pbiz.getPK()
+               pbiz.mmAssociateValues(cbiz, "company", ["Acme Manufacturing", 
"Amalgamated Industries",
+                               "Dabo Incorporated"])
+               emp_count = self.reccount("employees", "person_id = %s" % 
leafe_pk)
+               self.assertEqual(emp_count, 3)
+               pbiz.mmDisssociateValues(cbiz, "company", ["Acme 
Manufacturing", "Amalgamated Industries"])
+               emp_count = self.reccount("employees", "person_id = %s" % 
leafe_pk)
+               self.assertEqual(emp_count, 1)
+
+
+       def test_dissociateAll(self):
+               """Verify that bizobj.mmDisssociateAll() works correctly."""
+               pbiz = self.person_biz
+               cbiz = self.company_biz
+               pbiz.seek("Leafe", "last_name")
+               leafe_pk = pbiz.getPK()
+               # Add a bunch of related entries
+               pbiz.mmAssociateValue(cbiz, "company", "AAAA")
+               pbiz.mmAssociateValue(cbiz, "company", "BBBB")
+               pbiz.mmAssociateValue(cbiz, "company", "CCCC")
+               pbiz.mmAssociateValue(cbiz, "company", "DDDD")
+               pbiz.mmAssociateValue(cbiz, "company", "EEEE")
+               emp_count = self.reccount("employees", "person_id = %s" % 
leafe_pk)
+               self.assertEqual(emp_count, 5)
+               # Now disassociate all of them
+               pbiz.mmDisssociateAll(cbiz)
+               emp_count = self.reccount("employees", "person_id = %s" % 
leafe_pk)
+               self.assertEqual(emp_count, 0)
+
+
+       def test_full_associate(self):
+               """Verify that bizobj.mmSetFullAssociation() works correctly."""
+               pbiz = self.person_biz
+               cbiz = self.company_biz
+               pbiz.seek("Leafe", "last_name")
+               leafe_pk = pbiz.getPK()
+               pbiz.mmAssociateValue(cbiz, "company", "AAAA")
+               pbiz.mmAssociateValue(cbiz, "company", "BBBB")
+               pbiz.mmAssociateValue(cbiz, "company", "CCCC")
+               pbiz.mmAssociateValue(cbiz, "company", "DDDD")
+               pbiz.mmAssociateValue(cbiz, "company", "EEEE")
+               emp_count = self.reccount("employees", "person_id = %s" % 
leafe_pk)
+               self.assertEqual(emp_count, 5)
+               pbiz.mmSetFullAssociation(cbiz, "company", ["yy", "zz"])
+               emp_count = self.reccount("employees", "person_id = %s" % 
leafe_pk)
+               self.assertEqual(emp_count, 2)
+
+
+       def test_add_to_both(self):
+               """Verify that bizobj.mmAddToBoth() works correctly."""
+               pbiz = self.person_biz
+               cbiz = self.company_biz
+               person_count = self.reccount("person")
+               self.assertEqual(person_count, 2)
+               company_count = self.reccount("company")
+               self.assertEqual(company_count, 1)
+               emp_count = self.reccount("employees")
+               self.assertEqual(emp_count, 0)
+
+               # Add values that exist in both already
+               pbiz.mmAddToBoth(cbiz, "last_name", "Leafe", "company", "Acme 
Manufacturing")
+               # Person and company should be unchanged; there should be one 
more employee
+               person_count = self.reccount("person")
+               self.assertEqual(person_count, 2)
+               company_count = self.reccount("company")
+               self.assertEqual(company_count, 1)
+               emp_count = self.reccount("employees")
+               self.assertEqual(emp_count, 1)
+
+               # Add a new company with an existing person
+               pbiz.mmAddToBoth(cbiz, "last_name", "Leafe", "company", "Dabo 
Incorporated")
+               # Person should be unchanged; there should be one more company 
and employee
+               person_count = self.reccount("person")
+               self.assertEqual(person_count, 2)
+               company_count = self.reccount("company")
+               self.assertEqual(company_count, 2)
+               emp_count = self.reccount("employees")
+
+               # Add a new person with an existing company
+               pbiz.mmAddToBoth(cbiz, "last_name", "Schwartz", "company", 
"Dabo Incorporated")
+               # Person should increase and employee should increase
+               person_count = self.reccount("person")
+               self.assertEqual(person_count, 3)
+               company_count = self.reccount("company")
+               self.assertEqual(company_count, 2)
+               emp_count = self.reccount("employees")
+               self.assertEqual(emp_count, 3)
+
+               # Add the same relation again. Nothing should change.
+               pbiz.mmAddToBoth(cbiz, "last_name", "Schwartz", "company", 
"Dabo Incorporated")
+               person_count = self.reccount("person")
+               self.assertEqual(person_count, 3)
+               company_count = self.reccount("company")
+               self.assertEqual(company_count, 2)
+               emp_count = self.reccount("employees")
+               self.assertEqual(emp_count, 3)
+
+               # Add new values to both; all tables should increase.
+               pbiz.mmAddToBoth(cbiz, "last_name", "Jones", "company", 
"SmithCo")
+               person_count = self.reccount("person")
+               self.assertEqual(person_count, 4)
+               company_count = self.reccount("company")
+               self.assertEqual(company_count, 3)
+               emp_count = self.reccount("employees")
+               self.assertEqual(emp_count, 4)
+
+
+       def test_reverse_mm_relationship(self):
+               """Add the person bizobj to the company bizobj as a MM 
target."""
+               pbiz = self.person_biz
+               cbiz = self.company_biz
+               cbiz.addMMBizobj(pbiz, "employees", "company_id", "person_id")
+               emp_count = self.reccount("employees")
+               self.assertEqual(emp_count, 0)
+               cbiz.seek("Acme Manufacturing", "company")
+               cbiz.mmAssociateValue(pbiz, "last_name", "McNett")
+               # Employee count should have increased
+               emp_count = self.reccount("employees")
+               self.assertEqual(emp_count, 1)
+               self.assertEqual(self.reccount("company"), 1)
+               pbiz.mmAssociateValue(cbiz, "company", "Amalgamated Industries")
+               # Company count should have increased
+               self.assertEqual(self.reccount("company"), 2)
+
+
+       def test_multiple_mm_relationships(self):
+               """Make sure that more than one MM relationship works as 
expected."""   
+               pbiz = self.person_biz
+               cbiz = self.company_biz
+               fbiz = self.fan_club_biz
+               emp_count = self.reccount("employees")
+               orig_fan_count = self.reccount("membership")
+               orig_club_count = self.reccount("fan_club")
+               pbiz.seek("Leafe", "last_name")
+               leafe_pk = pbiz.getPK()
+               pbiz.mmAssociateValues(fbiz, "performer", ["Ramones", "Green 
Day", "The Clash"])
+               new_fan_count = self.reccount("membership")
+               self.assertEqual(new_fan_count, orig_fan_count + 3)
+               self.assertEqual(orig_club_count, self.reccount("fan_club"))
+               # Now add an employment
+               self.assertEqual(self.reccount("employees"), 0)
+               pbiz.mmAssociateValue(cbiz, "company", "Acme Manufacturing")
+               # Employee count should now be 1
+               self.assertEqual(self.reccount("employees"), 1)
+               # Fan club count should not have changed
+               self.assertEqual(new_fan_count, self.reccount("membership"))
+
+               # Add two new clubs that didn't exist
+               pbiz.mmAssociateValues(fbiz, "performer", ["Wire", "Burning 
Spear"])
+               new_fan_count = self.reccount("membership")
+               new_club_count = self.reccount("fan_club")
+               self.assertEqual(new_fan_count, orig_fan_count + 5)
+               self.assertEqual(new_club_count, orig_club_count + 2)
+               # Employee count should not have changed
+               self.assertEqual(self.reccount("employees"), 1)
+
+
+       def test_get_associated_values(self):
+               """Ensure that the mmGetAssociatedValues() method works 
correctly."""
+               pbiz = self.person_biz
+               fbiz = self.fan_club_biz
+               fbiz.addMMBizobj(pbiz, "membership", "fan_club_id", "person_id")
+               fbiz.seek("Ramones", "performer")
+               fbiz.mmAssociateValues(pbiz, "last_name", ["McNett", "Leafe"])
+               recs = fbiz.mmGetAssociatedValues(pbiz, "first_name")
+               self.assertEqual(len(recs), 2)
+               for rec in recs:
+                       self.assertEqual(rec.keys(), ["first_name"])
+                       self.assert_(rec["first_name"] in ("Paul", "Ed"))
+
+               # Check for no associated records
+               fbiz.seek("Pat Boone", "performer")
+               recs = fbiz.mmGetAssociatedValues(pbiz, "first_name")
+               self.assertEqual(len(recs), 0)
+
+
+
+if __name__ == "__main__":
+       suite = unittest.TestLoader().loadTestsFromTestCase(Test_Many_To_Many)
+       unittest.TextTestRunner(verbosity=2).run(suite)

Modified: trunk/dabo/db/dCursorMixin.py
===================================================================
--- trunk/dabo/db/dCursorMixin.py       2011-09-02 22:02:30 UTC (rev 6820)
+++ trunk/dabo/db/dCursorMixin.py       2011-09-03 18:53:03 UTC (rev 6821)
@@ -30,7 +30,7 @@
                if sql and isinstance(sql, basestring) and len(sql) > 0:
                        self.UserSQL = sql
                # Attributes used for M-M relationships
-               # Temporary! until the refactoring
+               self._isMM = False
                self._mmOtherTable = None
                self._mmOtherPKCol = None
                self._assocTable = None
@@ -482,7 +482,6 @@
                ac.IsPrefCursor = self.IsPrefCursor
                ac.KeyField = self.KeyField
                ac.Table = self.Table
-               # Temporary! until the refactoring
                ac._mmOtherTable = self._mmOtherTable
                ac._mmOtherPKCol = self._mmOtherPKCol
                ac._assocTable = self._assocTable
@@ -876,10 +875,13 @@
                If that record is a new unsaved record, return the temp PK 
value. If this is a
                compound PK, return a tuple containing each field's values.
                """
+               if self._isMM:
+                       # This is a cursor for handling many-many relations. 
Get the PK from the bizobj
+                       return self._bizobj.getPK()
+               ret = None
                if self.RowCount <= 0:
                        raise dException.NoRecordsException(
                                        _("No records in dataset '%s'.") % 
self.Table)
-               ret = None
                if row is None:
                        row = self.RowNumber
                rec = self._records[row]
@@ -1155,20 +1157,40 @@
                return self.mmAddToBoth(self.KeyField, self.getPK(), 
otherField, otherVal)
 
 
+       def mmAssociateValues(self, otherField, listOfValues):
+               keyField = self.KeyField
+               pk = self.getPK()
+               for val in listOfValues:
+                       self.mmAddToBoth(keyField, pk, otherField, val)
+
+
        def mmDisssociateValue(self, otherField, otherVal):
                """
                Removes the association between the current record and the 
specified value
                in the 'other' table of a M-M relationship. If no such 
association exists,
                nothing happens.
                """
+               self.mmDisssociateValues(otherField, [otherVal])
+
+
+       def mmDisssociateValues(self, otherField, listOfValues):
+               """
+               Removes the association between the current record and every 
item in 'listOfValues'
+               in the 'other' table of a M-M relationship. If no such 
association exists,
+               nothing happens.
+               """
                thisTable = self.Table
                otherTable = self._mmOtherTable
-               thisPK = self.lookupPKWithAdd(thisField, thisVal, thisTable)
-               otherPK = self.lookupPKWithAdd(otherField, otherVal, otherTable)
-               aux = self.AuxCursor
-               sql = "delete from %s where %s = ? and %s = ?" % 
(self._assocTable,
-                               self._assocPKColThis, self._assocPKColOther)
-               aux.execute(sql, (thisPK, otherPK))
+               thisPK = self.getPK()
+               for otherVal in listOfValues:
+                       otherPK = self.lookupPKWithAdd(otherField, otherVal, 
otherTable)
+                       aux = self.AuxCursor
+                       sql = "delete from %s where %s = ? and %s = ?" % 
(self._assocTable,
+                                       self._assocPKColThis, 
self._assocPKColOther)
+                       try:
+                               aux.execute(sql, (thisPK, otherPK))
+                       except dException.NoRecordsException:
+                               pass
 
 
        def mmDisssociateAll(self):
@@ -1178,7 +1200,10 @@
                """
                aux = self.AuxCursor
                sql = "delete from %s where %s = ?" % (self._assocTable, 
self._assocPKColThis)
-               aux.execute(sql, (self.getPK(),))
+               try:
+                       aux.execute(sql, (self.getPK(),))
+               except dException.NoRecordsException:
+                       pass
 
 
        def mmSetFullAssociation(self, otherField, listOfValues):
@@ -1187,10 +1212,7 @@
                is associated with every item in listOfValues, and none other.
                """
                self.mmDisssociateAll()
-               keyField = self.KeyField
-               pk = self.getPK()
-               for val in listOfValues:
-                       self.mmAddToBoth(keyField, pk, otherField, val)
+               self.mmAssociateValues(otherField, listOfValues)
 
 
        def mmAddToBoth(self, thisField, thisVal, otherField, otherVal):
@@ -1214,6 +1236,28 @@
                        aux.execute(sql, (thisPK, otherPK))
 
 
+       def mmGetAssociatedValues(self, listOfFields):
+               """
+               Returns a dataset containing the values for the specified fields
+               in the records associated with the current record.
+               """
+               aux = self.AuxCursor
+               # Add the related table alias
+               aliased_names = ["%s.%s" % (self._mmOtherTable, fld)
+                               for fld in listOfFields]
+               fldNames = ", ".join(aliased_names)
+               otherPKcol = self._mmOtherPKCol
+               aux.setFromClause(self._assocTable)
+               join = "join %s on %s.%s = %s.%s" % (self._mmOtherTable, 
self._assocTable,
+                               self._assocPKColOther, self._mmOtherTable, 
self._mmOtherPKCol)
+               aux.setJoinClause(join)
+               aux.setFieldClause(fldNames)
+               aux.setWhereClause("%s.%s = ?" % (self._assocTable, 
self._assocPKColThis))
+               params = (self.getPK(),)
+               aux.requery(params)
+               return aux.getDataSet()
+
+
        def getRecordStatus(self, row=None, pk=None):
                """
                Returns a dictionary containing an element for each changed
@@ -2418,7 +2462,6 @@
        def createAssociation(self, mmOtherTable, mmOtherPKCol, assocTable, 
assocPKColThis, assocPKColOther):
                """Create a many-to-many association."""
                # Save locally
-               # Temporary! until the refactoring
                self._mmOtherTable = mmOtherTable
                self._mmOtherPKCol = mmOtherPKCol
                self._assocTable = assocTable



_______________________________________________
Post Messages to: [email protected]
Subscription Maintenance: http://leafe.com/mailman/listinfo/dabo-dev
Searchable Archives: http://leafe.com/archives/search/dabo-dev
This message: 
http://leafe.com/archives/byMID/[email protected]

Reply via email to