Author: cito
Date: Tue Jan 12 10:32:32 2016
New Revision: 725

Log:
Improve implementation and test for pkey()

Modified:
   branches/4.x/pg.py
   trunk/pg.py
   trunk/tests/test_classic_dbwrapper.py

Modified: branches/4.x/pg.py
==============================================================================
--- branches/4.x/pg.py  Tue Jan 12 08:46:00 2016        (r724)
+++ branches/4.x/pg.py  Tue Jan 12 10:32:32 2016        (r725)
@@ -558,7 +558,7 @@
 
         If newpkey is set and is not a dictionary then set that
         value as the primary key of the class.  If it is a dictionary
-        then replace the _pkeys dictionary with a copy of it.
+        then replace the internal cache of primary keys with a copy of it.
 
         """
         # First see if the caller is supplying a dictionary

Modified: trunk/pg.py
==============================================================================
--- trunk/pg.py Tue Jan 12 08:46:00 2016        (r724)
+++ trunk/pg.py Tue Jan 12 10:32:32 2016        (r725)
@@ -37,6 +37,7 @@
 
 from decimal import Decimal
 from collections import namedtuple
+from itertools import groupby
 
 try:
     basestring
@@ -565,18 +566,19 @@
 
         If newpkey is set and is not a dictionary then set that
         value as the primary key of the class.  If it is a dictionary
-        then replace the _pkeys dictionary with a copy of it.
+        then replace the internal cache for primary keys with a copy of it.
 
         """
+        add_schema = self._add_schema
+
         # First see if the caller is supplying a dictionary
         if isinstance(newpkey, dict):
             # make sure that all classes have a namespace
-            self._pkeys = dict([
-                (cl if '.' in cl else 'public.' + cl, pkey)
-                for cl, pkey in newpkey.items()])
+            self._pkeys = dict((add_schema(cl), pkey)
+                for cl, pkey in newpkey.items())
             return self._pkeys
 
-        qcl = self._add_schema(cl)  # build fully qualified class name
+        qcl = add_schema(cl)  # build fully qualified class name
         # Check if the caller is supplying a new primary key for the class
         if newpkey:
             self._pkeys[qcl] = newpkey
@@ -585,7 +587,6 @@
         # Get all the primary keys at once
         if qcl not in self._pkeys:
             # if not found, check again in case it was added after we started
-            self._pkeys = {}
             q = ("SELECT s.nspname, r.relname, a.attname"
                 " FROM pg_class r"
                 " JOIN pg_namespace s ON s.oid = r.relnamespace"
@@ -594,14 +595,16 @@
                 " JOIN pg_attribute a ON a.attrelid = r.oid"
                 " AND NOT a.attisdropped"
                 " JOIN pg_index i ON i.indrelid = r.oid"
-                " AND i.indisprimary AND a.attnum = ANY (i.indkey)")
-            for r in self.db.query(q).getresult():
-                cl, pkey = _join_parts(r[:2]), r[2]
-                self._pkeys.setdefault(cl, []).append(pkey)
-            # (only) for composite primary keys, the values will be frozensets
-            for cl, pkey in self._pkeys.items():
-                self._pkeys[cl] = frozenset(pkey) if len(pkey) > 1 else pkey[0]
-            self._do_debug(self._pkeys)
+                " AND i.indisprimary AND a.attnum = ANY (i.indkey)"
+                " ORDER BY 1,2")
+            rows = self.db.query(q).getresult()
+            pkeys = {}
+            for cl, group in groupby(rows, lambda row: row[:2]):
+                cl = _join_parts(cl)
+                pkey = [row[2] for row in group]
+                pkeys[cl] = frozenset(pkey) if len(pkey) > 1 else pkey[0]
+            self._do_debug(pkeys)
+            self._pkeys = pkeys
 
         # will raise an exception if primary key doesn't exist
         return self._pkeys[qcl]

Modified: trunk/tests/test_classic_dbwrapper.py
==============================================================================
--- trunk/tests/test_classic_dbwrapper.py       Tue Jan 12 08:46:00 2016        (r724)
+++ trunk/tests/test_classic_dbwrapper.py       Tue Jan 12 10:32:32 2016        (r725)
@@ -558,31 +558,51 @@
 
     def testPkey(self):
         query = self.db.query
-        for n in range(4):
-            query("drop table if exists pkeytest%d" % n)
-        query("create table pkeytest0 ("
-            "a smallint)")
-        query("create table pkeytest1 ("
-            "b smallint primary key)")
-        query("create table pkeytest2 ("
-            "c smallint, d smallint primary key)")
-        query("create table pkeytest3 ("
-            "e smallint, f smallint, g smallint, "
-            "h smallint, i smallint, "
-            "primary key (f,h))")
         pkey = self.db.pkey
-        self.assertRaises(KeyError, pkey, 'pkeytest0')
-        self.assertEqual(pkey('pkeytest1'), 'b')
-        self.assertEqual(pkey('pkeytest2'), 'd')
-        self.assertEqual(pkey('pkeytest3'), frozenset('fh'))
-        self.assertEqual(pkey('pkeytest0', 'none'), 'none')
-        self.assertEqual(pkey('pkeytest0'), 'none')
-        pkey(None, {'t': 'a', 'n.t': 'b'})
-        self.assertEqual(pkey('t'), 'a')
-        self.assertEqual(pkey('n.t'), 'b')
-        self.assertRaises(KeyError, pkey, 'pkeytest0')
-        for n in range(4):
-            query("drop table pkeytest%d" % n)
+        for t in ('pkeytest', 'primary key test'):
+            for n in range(7):
+                query('drop table if exists "%s%d"' % (t, n))
+            query('create table "%s0" ('
+                "a smallint)" % t)
+            query('create table "%s1" ('
+                "b smallint primary key)" % t)
+            query('create table "%s2" ('
+                "c smallint, d smallint primary key)" % t)
+            query('create table "%s3" ('
+                "e smallint, f smallint, g smallint, "
+                "h smallint, i smallint, "
+                "primary key (f, h))" % t)
+            query('create table "%s4" ('
+                "more_than_one_letter varchar primary key)" % t)
+            query('create table "%s5" ('
+                '"with space" date primary key)' % t)
+            query('create table "%s6" ('
+                'a_very_long_column_name varchar, '
+                '"with space" date, '
+                '"42" int, '
+                "primary key (a_very_long_column_name, "
+                '"with space", "42"))' % t)
+            self.assertRaises(KeyError, pkey, '%s0' % t)
+            self.assertEqual(pkey('%s1' % t), 'b')
+            self.assertEqual(pkey('%s2' % t), 'd')
+            r = pkey('%s3' % t)
+            self.assertIsInstance(r, frozenset)
+            self.assertEqual(r, frozenset('fh'))
+            self.assertEqual(pkey('%s4' % t), 'more_than_one_letter')
+            self.assertEqual(pkey('%s5' % t), 'with space')
+            r = pkey('%s6' % t)
+            self.assertIsInstance(r, frozenset)
+            self.assertEqual(r, frozenset([
+                'a_very_long_column_name', 'with space', '42']))
+            self.assertEqual(pkey('%s0' % t, 'none'), 'none')
+            self.assertEqual(pkey('%s0' % t), 'none')
+            pkey(None, {'%s0' % t: 'a', 'public."%s1"' % t: 'b'})
+            self.assertEqual(pkey('%s0' % t), 'a')
+            self.assertEqual(pkey('%s1' % t), 'b')
+            pkey(None, {})
+            self.assertRaises(KeyError, pkey, '%s0' % t)
+            for n in range(7):
+                query('drop table "%s%d"' % (t, n))
 
     def testGetDatabases(self):
         databases = self.db.get_databases()
_______________________________________________
PyGreSQL mailing list
[email protected]
https://mail.vex.net/mailman/listinfo.cgi/pygresql

Reply via email to