Author: cito
Date: Wed Jan 13 19:30:50 2016
New Revision: 739

Log:
Return ordered dict for attributes if possible

Sometimes it's important to know the order of the columns in a table.
By returning an OrderedDict instead of a dict in get_attnames, we can
deliver that information en passant, while staying backward compatible.

Modified:
   trunk/docs/contents/pg/db_wrapper.rst
   trunk/pg.py
   trunk/tests/test_classic_dbwrapper.py

Modified: trunk/docs/contents/pg/db_wrapper.rst
==============================================================================
--- trunk/docs/contents/pg/db_wrapper.rst       Wed Jan 13 17:57:29 2016        (r738)
+++ trunk/docs/contents/pg/db_wrapper.rst       Wed Jan 13 19:30:50 2016        (r739)
@@ -125,6 +125,17 @@
 
 Given the name of a table, digs out the set of attribute names.
 
+Returns a dictionary of attribute names (the names are the keys,
+the values are the names of the attributes' types).
+
+If your Python version supports it (2.7 or 3.1 and later), the dictionary
+will be an OrderedDict with the columns in the order of the table definition.
+
+By default, only a limited number of simple types will be returned.
+You can get the regular PostgreSQL type names instead by enabling them
+with the :meth:`DB.use_regtypes` method.
+
+
 has_table_privilege -- check table privilege
 --------------------------------------------
 

Modified: trunk/pg.py
==============================================================================
--- trunk/pg.py Wed Jan 13 17:57:29 2016        (r738)
+++ trunk/pg.py Wed Jan 13 19:30:50 2016        (r739)
@@ -40,6 +40,11 @@
 from functools import partial
 
 try:
+    from collections import OrderedDict
+except ImportError:  # Python 2.6 or 3.0
+    OrderedDict = dict
+
+try:
     basestring
 except NameError:  # Python >= 3.0
     basestring = (str, bytes)
@@ -553,8 +558,12 @@
         Returns a dictionary of attribute names (the names are the keys,
         the values are the names of the attributes' types).
 
-        If the optional newattnames exists, it must be a dictionary and
-        will become the new attribute names dictionary.
+        If your Python version supports this, the dictionary will be an
+        OrderedDictionary with the column names in the right order.
+
+        If flush is set then the internal cache for attribute names will
+        be flushed. This may be necessary after the database schema or
+        the search path has been changed.
 
         By default, only a limited number of simple types will be returned.
         You can get the regular types after calling use_regtypes(True).
@@ -572,16 +581,15 @@
                 " JOIN pg_type t ON t.oid = a.atttypid"
                 " WHERE a.attrelid = %s::regclass"
                 " AND (a.attnum > 0 OR a.attname = 'oid')"
-                " AND NOT a.attisdropped") % (
+                " AND NOT a.attisdropped ORDER BY a.attnum") % (
                     '::regtype' if self._regtypes else '',
                     self._prepare_qualified_param(table, 1))
             names = self.db.query(q, (table,)).getresult()
             if not names:
                 raise KeyError('Table %s does not exist' % table)
-            if self._regtypes:
-                names = dict(names)
-            else:
-                names = dict((name, _simpletype(typ)) for name, typ in names)
+            if not self._regtypes:
+                names = ((name, _simpletype(typ)) for name, typ in names)
+            names = OrderedDict(names)
             attnames[table] = names  # cache it
         return names
 

Modified: trunk/tests/test_classic_dbwrapper.py
==============================================================================
--- trunk/tests/test_classic_dbwrapper.py       Wed Jan 13 17:57:29 2016        (r738)
+++ trunk/tests/test_classic_dbwrapper.py       Wed Jan 13 19:30:50 2016        (r739)
@@ -48,6 +48,11 @@
 except NameError:  # Python >= 3.0
     unicode = str
 
+try:
+    from collections import OrderedDict
+except ImportError:  # Python 2.6 or 3.0
+    OrderedDict = dict
+
 windows = os.name == 'nt'
 
 # There is a known a bug in libpq under Windows which can cause
@@ -406,6 +411,90 @@
             b'\\x746861742773206be47365')
         self.assertEqual(f(r'\\x4f007073ff21'), b'\\x4f007073ff21')
 
+    def testGetAttnames(self):
+        get_attnames = self.db.get_attnames
+        query = self.db.query
+        query("drop table if exists test_table")
+        query("create table test_table("
+            " n int, alpha smallint, beta bool,"
+            " gamma char(5), tau text, v varchar(3))")
+        r = get_attnames("test_table")
+        self.assertIsInstance(r, dict)
+        self.assertEquals(r, dict(
+            n='int', alpha='int', beta='bool',
+            gamma='text', tau='text', v='text'))
+        query("drop table test_table")
+
+    def testGetAttnamesWithQuotes(self):
+        get_attnames = self.db.get_attnames
+        query = self.db.query
+        table = 'test table for get_attnames()'
+        query('drop table if exists "%s"' % table)
+        query('create table "%s"('
+            '"Prime!" smallint,'
+            '"much space" integer, "Questions?" text)' % table)
+        r = get_attnames(table)
+        self.assertIsInstance(r, dict)
+        self.assertEquals(r, {
+            'Prime!': 'int', 'much space': 'int', 'Questions?': 'text'})
+        query('drop table "%s"' % table)
+
+    def testGetAttnamesWithRegtypes(self):
+        get_attnames = self.db.get_attnames
+        query = self.db.query
+        query("drop table if exists test_table")
+        query("create table test_table("
+            " n int, alpha smallint, beta bool,"
+            " gamma char(5), tau text, v varchar(3))")
+        self.db.use_regtypes(True)
+        try:
+            r = get_attnames("test_table")
+            self.assertIsInstance(r, dict)
+        finally:
+            self.db.use_regtypes(False)
+        self.assertEquals(r, dict(
+            n='integer', alpha='smallint', beta='boolean',
+            gamma='character', tau='text', v='character varying'))
+        query("drop table test_table")
+
+    def testGetAttnamesIsCached(self):
+        get_attnames = self.db.get_attnames
+        query = self.db.query
+        query("drop table if exists test_table")
+        query("create table test_table(col int)")
+        r = get_attnames("test_table")
+        self.assertIsInstance(r, dict)
+        self.assertEquals(r, dict(col='int'))
+        query("drop table test_table")
+        query("create table test_table(col text)")
+        r = get_attnames("test_table")
+        self.assertEquals(r, dict(col='int'))
+        r = get_attnames("test_table", flush=True)
+        self.assertEquals(r, dict(col='text'))
+        query("drop table test_table")
+        r = get_attnames("test_table")
+        self.assertEquals(r, dict(col='text'))
+        self.assertRaises(pg.ProgrammingError,
+            get_attnames, "test_table", flush=True)
+
+    def testGetAttnamesIsOrdered(self):
+        get_attnames = self.db.get_attnames
+        query = self.db.query
+        query("drop table if exists test_table")
+        query("create table test_table("
+            " n int, alpha smallint, v varchar(3),"
+            " gamma char(5), tau text, beta bool)")
+        r = get_attnames("test_table")
+        self.assertIsInstance(r, OrderedDict)
+        self.assertEquals(r, OrderedDict([
+            ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
+            ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
+        query("drop table test_table")
+        if OrderedDict is dict:
+            self.skipTest('OrderedDict is not supported')
+        r = ' '.join(list(r.keys()))
+        self.assertEquals(r, 'n alpha v gamma tau beta')
+
     def testQuery(self):
         query = self.db.query
         query("drop table if exists test_table")
@@ -696,10 +785,10 @@
                     "%d, %d, '%s')" % (table, n + 1, m + 1, t))
         self.assertRaises(pg.ProgrammingError, get, table, 2)
         self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
-        self.assertEqual(get(table, dict(n=1, m=2),
-                             ('n', 'm'))['t'], 'b')
-        self.assertEqual(get(table, dict(n=3, m=2),
-                             frozenset(['n', 'm']))['t'], 'f')
+        r = get(table, dict(n=1, m=2), ('n', 'm'))
+        self.assertEqual(r['t'], 'b')
+        r = get(table, dict(n=3, m=2), frozenset(['n', 'm']))
+        self.assertEqual(r['t'], 'f')
         query('drop table "%s"' % table)
 
     def testGetWithQuotedNames(self):
_______________________________________________
PyGreSQL mailing list
[email protected]
https://mail.vex.net/mailman/listinfo.cgi/pygresql

Reply via email to