Author: cito
Date: Tue Jan 19 10:19:38 2016
New Revision: 768

Log:
Refactoring of DB wrapper test

Many tests repeated the same boilerplate for creating test tables,
so this was factored out into a separate helper method, which
creates temporary tables by default.

Modified:
   trunk/tests/test_classic_dbwrapper.py

Modified: trunk/tests/test_classic_dbwrapper.py
==============================================================================
--- trunk/tests/test_classic_dbwrapper.py       Tue Jan 19 05:51:40 2016        (r767)
+++ trunk/tests/test_classic_dbwrapper.py       Tue Jan 19 10:19:38 2016        (r768)
@@ -10,7 +10,6 @@
 These tests need a database to test against.
 
 """
-
 try:
     import unittest2 as unittest  # for Python < 2.7
 except ImportError:
@@ -315,6 +314,35 @@
         self.doCleanups()
         self.db.close()
 
+    def createTable(self, table, definition,
+            temporary=True, oids=None, values=None):
+        """Create a (temporary) test table and insert the given rows.
+
+        The name is double-quoted unless it contains a double quote
+        (dotted names are always quoted).  Non-temporary tables are
+        dropped first and again on cleanup.  Each item of values is a
+        row given as a list/tuple or as a single column value.
+        """
+        query = self.db.query
+        if '"' not in table or '.' in table:
+            table = '"%s"' % table
+        if not temporary:
+            q = 'drop table if exists %s cascade' % table
+            query(q)
+            self.addCleanup(query, q)
+        kind = 'temporary table' if temporary else 'table'
+        as_query = definition.startswith(('as ', 'AS '))
+        if not as_query and not definition.startswith('('):
+            definition = '(%s)' % definition
+        with_oids = 'with oids' if oids else 'without oids'
+        tail = [with_oids, definition] if as_query else [definition, with_oids]
+        query(' '.join(['create', kind, table] + tail))
+        for row in values or ():
+            if not isinstance(row, (list, tuple)):
+                row = [row]
+            placeholders = ', '.join('$%d' % n for n in range(1, len(row) + 1))
+            query('insert into %s values (%s)' % (table, placeholders), row)
+
     def testClassName(self):
         self.assertEqual(self.db.__class__.__name__, 'DB')
 
@@ -628,13 +656,21 @@
         r = self.db.get_parameter('datestyle')
         self.assertEqual(r, default_datestyle)
 
+    def testCreateTable(self):
+        table = 'test hello world'  # contains spaces, so quote in SQL
+        values = [(2, "World!"), (1, "Hello")]  # inserted out of order
+        self.createTable(table, "n smallint, t varchar",
+            temporary=True, oids=True, values=values)
+        r = self.db.query('select t from "%s" order by n' % table).getresult()
+        r = ', '.join(row[0] for row in r)  # rows sorted by n
+        self.assertEqual(r, "Hello, World!")
+        r = self.db.query('select oid from "%s" limit 1' % table).getresult()
+        self.assertIsInstance(r[0][0], int)  # oid column exists (oids=True)
+
     def testQuery(self):
         query = self.db.query
-        query("drop table if exists test_table")
-        self.addCleanup(query, "drop table test_table")
-        q = "create table test_table (n integer) with oids"
-        r = query(q)
-        self.assertIsNone(r)
+        table = 'test_table'
+        self.createTable(table, "n integer", oids=True)
         q = "insert into test_table values (1)"
         r = query(q)
         self.assertIsInstance(r, int)
@@ -670,10 +706,7 @@
 
     def testQueryWithParams(self):
         query = self.db.query
-        query("drop table if exists test_table")
-        self.addCleanup(query, "drop table test_table")
-        q = "create table test_table (n1 integer, n2 integer) with oids"
-        query(q)
+        self.createTable('test_table', 'n1 integer, n2 integer', oids=True)
         q = "insert into test_table values ($1, $2)"
         r = query(q, (1, 2))
         self.assertIsInstance(r, int)
@@ -710,33 +743,23 @@
         pkey = self.db.pkey
         self.assertRaises(KeyError, pkey, 'test')
         for t in ('pkeytest', 'primary key test'):
-            for n in range(8):
-                query('drop table if exists "%s%d"' % (t, n))
-                self.addCleanup(query, 'drop table "%s%d"' % (t, n))
-            query('create table "%s0" ('
-                "a smallint)" % t)
-            query('create table "%s1" ('
-                "b smallint primary key)" % t)
-            query('create table "%s2" ('
-                "c smallint, d smallint primary key)" % t)
-            query('create table "%s3" ('
-                "e smallint, f smallint, g smallint,"
-                " h smallint, i smallint,"
-                " primary key (f, h))" % t)
-            query('create table "%s4" ('
-                "e smallint, f smallint, g smallint,"
-                " h smallint, i smallint,"
-                " primary key (h, f))" % t)
-            query('create table "%s5" ('
-                "more_than_one_letter varchar primary key)" % t)
-            query('create table "%s6" ('
-                '"with space" date primary key)' % t)
-            query('create table "%s7" ('
-                'a_very_long_column_name varchar,'
-                ' "with space" date,'
-                ' "42" int,'
-                " primary key (a_very_long_column_name,"
-                ' "with space", "42"))' % t)
+            self.createTable('%s0' % t, 'a smallint')
+            self.createTable('%s1' % t, 'b smallint primary key')
+            self.createTable('%s2' % t,
+                'c smallint, d smallint primary key')
+            self.createTable('%s3' % t,
+                'e smallint, f smallint, g smallint, h smallint, i smallint,'
+                ' primary key (f, h)')
+            self.createTable('%s4' % t,
+                'e smallint, f smallint, g smallint, h smallint, i smallint,'
+                ' primary key (h, f)')
+            self.createTable('%s5' % t,
+                'more_than_one_letter varchar primary key')
+            self.createTable('%s6' % t,
+                '"with space" date primary key')
+            self.createTable('%s7' % t,
+                'a_very_long_column_name varchar, "with space" date, "42" int,'
+                ' primary key (a_very_long_column_name, "with space", "42")')
             self.assertRaises(KeyError, pkey, '%s0' % t)
             self.assertEqual(pkey('%s1' % t), 'b')
             self.assertEqual(pkey('%s1' % t, True), ('b',))
@@ -779,9 +802,15 @@
 
     def testGetTables(self):
         get_tables = self.db.get_tables
-        result1 = get_tables()
-        self.assertIsInstance(result1, list)
-        for t in result1:
+        tables = ('A very Special Name', 'A_MiXeD_quoted_NaMe',
+            'Hello, Test World!', 'Zoro', 'a1', 'a2', 'a321',
+            'averyveryveryveryveryveryveryreallyreallylongtablename',
+            'b0', 'b3', 'x', 'xXx', 'xx', 'y', 'z')
+        for t in tables:
+            self.db.query('drop table if exists "%s" cascade' % t)
+        before_tables = get_tables()
+        self.assertIsInstance(before_tables, list)
+        for t in before_tables:
             t = t.split('.', 1)
             self.assertGreaterEqual(len(t), 2)
             if len(t) > 2:
@@ -789,30 +818,16 @@
             t = t[0]
             self.assertNotEqual(t, 'information_schema')
             self.assertFalse(t.startswith('pg_'))
-        tables = ('"A very Special Name"',
-            '"A_MiXeD_quoted_NaMe"', 'a1', 'a2',
-            'A_MiXeD_NaMe', '"another special name"',
-            'averyveryveryveryveryveryverylongtablename',
-            'b0', 'b3', 'x', 'xx', 'xXx', 'y', 'z')
         for t in tables:
-            self.db.query('drop table if exists %s' % t)
-            self.db.query("create table %s"
-                " as select 0" % t)
-        result3 = get_tables()
-        result2 = []
-        for t in result3:
-            if t not in result1:
-                result2.append(t)
-        result3 = []
-        for t in tables:
-            if not t.startswith('"'):
-                t = t.lower()
-            result3.append('public.' + t)
-        self.assertEqual(result2, result3)
-        for t in result2:
-            self.db.query('drop table %s' % t)
-        result2 = get_tables()
-        self.assertEqual(result2, result1)
+            self.createTable(t, 'as select 0', temporary=False)
+        current_tables = get_tables()
+        new_tables = [t for t in current_tables if t not in before_tables]
+        expected_new_tables = ['public.%s' % (
+            '"%s"' % t if ' ' in t or t != t.lower() else t) for t in tables]
+        self.assertEqual(new_tables, expected_new_tables)
+        self.doCleanups()
+        after_tables = get_tables()
+        self.assertEqual(after_tables, before_tables)
 
     def testGetRelations(self):
         get_relations = self.db.get_relations
@@ -844,12 +859,9 @@
             i2='int', i4='int', i8='int', d='num',
             f4='float', f8='float', m='money',
             v4='text', c4='text', t='text'))
-        query = self.db.query
-        query("drop table if exists test_table")
-        self.addCleanup(query, "drop table test_table")
-        query("create table test_table("
-            " n int, alpha smallint, beta bool,"
-            " gamma char(5), tau text, v varchar(3))")
+        self.createTable('test_table',
+            'n int, alpha smallint, beta bool,'
+            ' gamma char(5), tau text, v varchar(3)')
         r = get_attnames('test_table')
         self.assertIsInstance(r, dict)
         self.assertEqual(r, dict(
@@ -858,27 +870,21 @@
 
     def testGetAttnamesWithQuotes(self):
         get_attnames = self.db.get_attnames
-        query = self.db.query
         table = 'test table for get_attnames()'
-        query('drop table if exists "%s"' % table)
-        self.addCleanup(query, 'drop table "%s"' % table)
-        query('create table "%s"('
-            '"Prime!" smallint,'
-            ' "much space" integer, "Questions?" text)' % table)
+        self.createTable(table,
+            '"Prime!" smallint, "much space" integer, "Questions?" text')
         r = get_attnames(table)
         self.assertIsInstance(r, dict)
         self.assertEqual(r, {
             'Prime!': 'int', 'much space': 'int', 'Questions?': 'text'})
         table = 'yet another test table for get_attnames()'
-        query('drop table if exists "%s"' % table)
-        self.addCleanup(query, 'drop table "%s"' % table)
-        self.db.query('create table "%s" ('
+        self.createTable(table,
             'a smallint, b integer, c bigint,'
             ' e numeric, f float, f2 double precision, m money,'
             ' x smallint, y smallint, z smallint,'
             ' Normal_NaMe smallint, "Special Name" smallint,'
             ' t text, u char(2), v varchar(2),'
-            ' primary key (y, u)) with oids' % table)
+            ' primary key (y, u)', oids=True)
         r = get_attnames(table)
         self.assertIsInstance(r, dict)
         self.assertEqual(r, {'a': 'int', 'c': 'int', 'b': 'int',
@@ -889,12 +895,9 @@
 
     def testGetAttnamesWithRegtypes(self):
         get_attnames = self.db.get_attnames
-        query = self.db.query
-        query("drop table if exists test_table")
-        self.addCleanup(query, "drop table test_table")
-        query("create table test_table("
-            " n int, alpha smallint, beta bool,"
-            " gamma char(5), tau text, v varchar(3))")
+        self.createTable('test_table',
+            ' n int, alpha smallint, beta bool,'
+            ' gamma char(5), tau text, v varchar(3)')
         use_regtypes = self.db.use_regtypes
         regtypes = use_regtypes()
         self.assertFalse(regtypes)
@@ -911,9 +914,7 @@
     def testGetAttnamesIsCached(self):
         get_attnames = self.db.get_attnames
         query = self.db.query
-        query("drop table if exists test_table")
-        self.addCleanup(query, "drop table test_table")
-        query("create table test_table(col int)")
+        self.createTable('test_table', 'col int')
         r = get_attnames("test_table")
         self.assertIsInstance(r, dict)
         self.assertEqual(r, dict(col='int'))
@@ -937,11 +938,9 @@
     def testGetAttnamesIsOrdered(self):
         get_attnames = self.db.get_attnames
         query = self.db.query
-        query("drop table if exists test_table")
-        self.addCleanup(query, "drop table test_table")
-        query("create table test_table("
-            " n int, alpha smallint, v varchar(3),"
-            " gamma char(5), tau text, beta bool)")
+        self.createTable('test_table',
+            ' n int, alpha smallint, v varchar(3),'
+            ' gamma char(5), tau text, beta bool')
         r = get_attnames("test_table")
         self.assertIsInstance(r, OrderedDict)
         self.assertEqual(r, OrderedDict([
@@ -972,13 +971,8 @@
         table = 'get_test_table'
         self.assertRaises(TypeError, get)
         self.assertRaises(TypeError, get, table)
-        query('drop table if exists "%s"' % table)
-        self.addCleanup(query, 'drop table "%s"' % table)
-        query('create table "%s" ('
-            "n integer, t text) without oids" % table)
-        for n, t in enumerate('xyz'):
-            query('insert into "%s" values('"%d, '%s')"
-                % (table, n + 1, t))
+        self.createTable(table, 'n integer, t text',
+            values=enumerate('xyz', start=1))
         self.assertRaises(pg.ProgrammingError, get, table, 2)
         r = get(table, 2, 'n')
         self.assertIsInstance(r, dict)
@@ -1030,13 +1024,8 @@
         get = self.db.get
         query = self.db.query
         table = 'get_with_oid_test_table'
-        query('drop table if exists "%s"' % table)
-        self.addCleanup(query, 'drop table "%s"' % table)
-        query('create table "%s" ('
-            "n integer, t text) with oids" % table)
-        for n, t in enumerate('xyz'):
-            query('insert into "%s" values('"%d, '%s')"
-                % (table, n + 1, t))
+        self.createTable(table, 'n integer, t text', oids=True,
+            values=enumerate('xyz', start=1))
         self.assertRaises(pg.ProgrammingError, get, table, 2)
         self.assertRaises(KeyError, get, table, {}, 'oid')
         r = get(table, 2, 'n')
@@ -1099,13 +1088,8 @@
         get = self.db.get
         query = self.db.query
         table = 'get_test_table_1'
-        query('drop table if exists "%s"' % table)
-        self.addCleanup(query, 'drop table "%s"' % table)
-        query('create table "%s" ('
-            "n integer, t text, primary key (n))" % table)
-        for n, t in enumerate('abc'):
-            query('insert into "%s" values('
-                "%d, '%s')" % (table, n + 1, t))
+        self.createTable(table, 'n integer primary key, t text',
+            values=enumerate('abc', start=1))
         self.assertEqual(get(table, 2)['t'], 'b')
         self.assertEqual(get(table, 1, 'n')['t'], 'a')
         self.assertEqual(get(table, 2, ('n',))['t'], 'b')
@@ -1115,15 +1099,10 @@
         self.assertEqual(get(table, ('a',), ('t',))['n'], 1)
         self.assertEqual(get(table, ['c'], ['t'])['n'], 3)
         table = 'get_test_table_2'
-        query('drop table if exists "%s"' % table)
-        self.addCleanup(query, 'drop table "%s"' % table)
-        query('create table "%s" ('
-            "n integer, m integer, t text, primary key (n, m))" % table)
-        for n in range(3):
-            for m in range(2):
-                t = chr(ord('a') + 2 * n + m)
-                query('insert into "%s" values('
-                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
+        self.createTable(table,
+            'n integer, m integer, t text, primary key (n, m)',
+            values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
+                for n in range(3) for m in range(2)])
         self.assertRaises(KeyError, get, table, 2)
         self.assertEqual(get(table, (1, 1))['t'], 'a')
         self.assertEqual(get(table, (1, 2))['t'], 'b')
@@ -1141,13 +1120,9 @@
         get = self.db.get
         query = self.db.query
         table = 'test table for get()'
-        query('drop table if exists "%s"' % table)
-        self.addCleanup(query, 'drop table "%s"' % table)
-        query('create table "%s" ('
-            '"Prime!" smallint primary key,'
-            ' "much space" integer, "Questions?" text)' % table)
-        query('insert into "%s"'
-              " values(17, 1001, 'No!')" % table)
+        self.createTable(table, '"Prime!" smallint primary key,'
+            ' "much space" integer, "Questions?" text',
+            values=[(17, 1001, 'No!')])
         r = get(table, 17)
         self.assertIsInstance(r, dict)
         self.assertEqual(r['Prime!'], 17)
@@ -1165,16 +1140,10 @@
     def testGetLittleBobbyTables(self):
         get = self.db.get
         query = self.db.query
-        query("drop table if exists test_students")
-        self.addCleanup(query, "drop table test_students")
-        query("create table test_students (firstname varchar primary key,"
-            " nickname varchar, grade char(2))")
-        query("insert into test_students values ("
-              "'D''Arcy', 'Darcey', 'A+')")
-        query("insert into test_students values ("
-              "'Sheldon', 'Moonpie', 'A+')")
-        query("insert into test_students values ("
-              "'Robert', 'Little Bobby Tables', 'D-')")
+        self.createTable('test_students',
+            'firstname varchar primary key, nickname varchar, grade char(2)',
+            values=[("D'Arcy", 'Darcey', 'A+'), ('Sheldon', 'Moonpie', 'A+'),
+                    ('Robert', 'Little Bobby Tables', 'D-')])
         r = get('test_students', 'Sheldon')
         self.assertEqual(r, dict(
             firstname="Sheldon", nickname='Moonpie', grade='A+'))
@@ -1207,13 +1176,11 @@
         bool_on = pg.get_bool()
         decimal = pg.get_decimal()
         table = 'insert_test_table'
-        query('drop table if exists "%s"' % table)
-        self.addCleanup(query, 'drop table "%s"' % table)
-        query('create table "%s" ('
-            "i2 smallint, i4 integer, i8 bigint,"
-            " d numeric, f4 real, f8 double precision, m money,"
-            " v4 varchar(4), c4 char(4), t text,"
-            " b boolean, ts timestamp) with oids" % table)
+        self.createTable(table,
+            'i2 smallint, i4 integer, i8 bigint,'
+            ' d numeric, f4 real, f8 double precision, m money,'
+            ' v4 varchar(4), c4 char(4), t text,'
+            ' b boolean, ts timestamp', oids=True)
         oid_table = 'oid(%s)' % table
         tests = [dict(i2=None, i4=None, i8=None),
             (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
@@ -1308,9 +1275,7 @@
     def testInsertWithOid(self):
         insert = self.db.insert
         query = self.db.query
-        query("drop table if exists test_table")
-        self.addCleanup(query, "drop table test_table")
-        query("create table test_table (n int) with oids")
+        self.createTable('test_table', 'n int', oids=True)
         r = insert('test_table', n=1)
         self.assertIsInstance(r, dict)
         self.assertEqual(r['n'], 1)
@@ -1379,11 +1344,8 @@
         insert = self.db.insert
         query = self.db.query
         table = 'test table for insert()'
-        query('drop table if exists "%s"' % table)
-        self.addCleanup(query, 'drop table "%s"' % table)
-        query('create table "%s" ('
-            '"Prime!" smallint primary key,'
-            ' "much space" integer, "Questions?" text)' % table)
+        self.createTable(table, '"Prime!" smallint primary key,'
+            ' "much space" integer, "Questions?" text')
         r = {'Prime!': 11, 'much space': 2002, 'Questions?': 'What?'}
         r = insert(table, r)
         self.assertIsInstance(r, dict)
@@ -1403,13 +1365,8 @@
         self.assertRaises(pg.ProgrammingError, update,
             'test', i2=2, i4=4, i8=8)
         table = 'update_test_table'
-        query('drop table if exists "%s"' % table)
-        self.addCleanup(query, 'drop table "%s"' % table)
-        query('create table "%s" ('
-            "n integer, t text) with oids" % table)
-        for n, t in enumerate('xyz'):
-            query('insert into "%s" values('
-                "%d, '%s')" % (table, n + 1, t))
+        self.createTable(table, 'n integer, t text', oids=True,
+            values=enumerate('xyz', start=1))
         self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
         r = self.db.get(table, 2, 'n')
         r['t'] = 'u'
@@ -1423,10 +1380,7 @@
         update = self.db.update
         get = self.db.get
         query = self.db.query
-        query("drop table if exists test_table")
-        self.addCleanup(query, "drop table test_table")
-        query("create table test_table (n int) with oids")
-        query("insert into test_table values (1)")
+        self.createTable('test_table', 'n int', oids=True, values=[1])
         s = get('test_table', 1, 'n')
         self.assertIsInstance(s, dict)
         self.assertEqual(s['n'], 1)
@@ -1493,17 +1447,28 @@
         r = query(q).getresult()
         self.assertEqual(r, [(1, 3), (5, 9)])
 
+    def testUpdateWithoutOid(self):
+        update = self.db.update
+        query = self.db.query
+        self.assertRaises(pg.ProgrammingError, update,
+            'test', i2=2, i4=4, i8=8)
+        table = 'update_test_table'
+        self.createTable(table, 'n integer primary key, t text', oids=False,
+            values=enumerate('xyz', start=1))
+        r = self.db.get(table, 2)  # no keyname given; table has pkey n
+        r['t'] = 'u'
+        s = update(table, r)  # table was created without oids
+        self.assertEqual(s, r)
+        q = 'select t from "%s" where n=2' % table
+        r = query(q).getresult()[0][0]  # re-read the row from the table
+        self.assertEqual(r, 'u')
+
     def testUpdateWithCompositeKey(self):
         update = self.db.update
         query = self.db.query
         table = 'update_test_table_1'
-        query('drop table if exists "%s"' % table)
-        self.addCleanup(query, 'drop table if exists "%s"' % table)
-        query('create table "%s" ('
-            "n integer, t text, primary key (n))" % table)
-        for n, t in enumerate('abc'):
-            query('insert into "%s" values('
-                "%d, '%s')" % (table, n + 1, t))
+        self.createTable(table, 'n integer primary key, t text',
+            values=enumerate('abc', start=1))
         self.assertRaises(KeyError, update, table, dict(t='b'))
         s = dict(n=2, t='d')
         r = update(table, s)
@@ -1525,14 +1490,10 @@
         self.assertEqual(len(r), 0)
         query('drop table "%s"' % table)
         table = 'update_test_table_2'
-        query('drop table if exists "%s"' % table)
-        query('create table "%s" ('
-            "n integer, m integer, t text, primary key (n, m))" % table)
-        for n in range(3):
-            for m in range(2):
-                t = chr(ord('a') + 2 * n + m)
-                query('insert into "%s" values('
-                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
+        self.createTable(table,
+            'n integer, m integer, t text, primary key (n, m)',
+            values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
+                for n in range(3) for m in range(2)])
         self.assertRaises(KeyError, update, table, dict(n=2, t='b'))
         self.assertEqual(update(table,
             dict(n=2, m=2, t='x'))['t'], 'x')
@@ -1544,13 +1505,9 @@
         update = self.db.update
         query = self.db.query
         table = 'test table for update()'
-        query('drop table if exists "%s"' % table)
-        self.addCleanup(query, 'drop table "%s"' % table)
-        query('create table "%s" ('
-            '"Prime!" smallint primary key,'
-            ' "much space" integer, "Questions?" text)' % table)
-        query('insert into "%s"'
-              " values(13, 3003, 'Why!')" % table)
+        self.createTable(table, '"Prime!" smallint primary key,'
+            ' "much space" integer, "Questions?" text',
+            values=[(13, 3003, 'Why!')])
         r = {'Prime!': 13, 'much space': 7007, 'Questions?': 'When?'}
         r = update(table, r)
         self.assertIsInstance(r, dict)
@@ -1570,10 +1527,7 @@
         self.assertRaises(pg.ProgrammingError, upsert,
             'test', i2=2, i4=4, i8=8)
         table = 'upsert_test_table'
-        query('drop table if exists "%s"' % table)
-        self.addCleanup(query, 'drop table "%s"' % table)
-        query('create table "%s" ('
-            "n integer primary key, t text) with oids" % table)
+        self.createTable(table, 'n integer primary key, t text', oids=True)
         s = dict(n=1, t='x')
         try:
             r = upsert(table, s)
@@ -1643,10 +1597,7 @@
         upsert = self.db.upsert
         get = self.db.get
         query = self.db.query
-        query("drop table if exists test_table")
-        self.addCleanup(query, "drop table test_table")
-        query("create table test_table (n int) with oids")
-        query("insert into test_table values (1)")
+        self.createTable('test_table', 'n int', oids=True, values=[1])
         self.assertRaises(pg.ProgrammingError,
             upsert, 'test_table', dict(n=2))
         r = get('test_table', 1, 'n')
@@ -1725,10 +1676,8 @@
         upsert = self.db.upsert
         query = self.db.query
         table = 'upsert_test_table_2'
-        query('drop table if exists "%s"' % table)
-        self.addCleanup(query, 'drop table "%s"' % table)
-        query('create table "%s" ('
-            "n integer, m integer, t text, primary key (n, m))" % table)
+        self.createTable(table,
+            'n integer, m integer, t text, primary key (n, m)')
         s = dict(n=1, m=2, t='x')
         try:
             r = upsert(table, s)
@@ -1794,11 +1743,8 @@
         upsert = self.db.upsert
         query = self.db.query
         table = 'test table for upsert()'
-        query('drop table if exists "%s"' % table)
-        self.addCleanup(query, 'drop table "%s"' % table)
-        query('create table "%s" ('
-            '"Prime!" smallint primary key,'
-            ' "much space" integer, "Questions?" text)' % table)
+        self.createTable(table, '"Prime!" smallint primary key,'
+            ' "much space" integer, "Questions?" text')
         s = {'Prime!': 31, 'much space': 9009, 'Questions?': 'Yes.'}
         try:
             r = upsert(table, s)
@@ -1824,37 +1770,31 @@
 
     def testClear(self):
         clear = self.db.clear
-        query = self.db.query
         f = False if pg.get_bool() else 'f'
         r = clear('test')
         result = dict(
             i2=0, i4=0, i8=0, d=0, f4=0, f8=0, m=0, v4='', c4='', t='')
         self.assertEqual(r, result)
         table = 'clear_test_table'
-        query('drop table if exists "%s"' % table)
-        self.addCleanup(query, 'drop table "%s"' % table)
-        query('create table "%s" ('
-            "n integer, b boolean, d date, t text) with oids" % table)
+        self.createTable(table,
+            'n integer, f float, b boolean, d date, t text', oids=True)
         r = clear(table)
-        result = dict(n=0, b=f, d='', t='')
+        result = dict(n=0, f=0, b=f, d='', t='')  # b: false per get_bool()
         self.assertEqual(r, result)
-        r['a'] = r['n'] = 1
+        r['a'] = r['f'] = r['n'] = 1  # 'a' is not a table column
         r['d'] = r['t'] = 'x'
         r['b'] = 't'
         r['oid'] = long(1)
         r = clear(table, r)
-        result = dict(a=1, n=0, b=f, d='', t='', oid=long(1))
+        result = dict(a=1, n=0, f=0, b=f, d='', t='', oid=long(1))
         self.assertEqual(r, result)
 
     def testClearWithQuotedNames(self):
         clear = self.db.clear
         query = self.db.query
         table = 'test table for clear()'
-        query('drop table if exists "%s"' % table)
-        self.addCleanup(query, 'drop table "%s"' % table)
-        query('create table "%s" ('
-            '"Prime!" smallint primary key,'
-            ' "much space" integer, "Questions?" text)' % table)
+        self.createTable(table, '"Prime!" smallint primary key,'
+            ' "much space" integer, "Questions?" text')
         r = clear(table)
         self.assertIsInstance(r, dict)
         self.assertEqual(r['Prime!'], 0)
@@ -1867,13 +1807,8 @@
         self.assertRaises(pg.ProgrammingError, delete,
             'test', dict(i2=2, i4=4, i8=8))
         table = 'delete_test_table'
-        query('drop table if exists "%s"' % table)
-        self.addCleanup(query, 'drop table "%s"' % table)
-        query('create table "%s" ('
-            "n integer, t text) with oids" % table)
-        for n, t in enumerate('xyz'):
-            query('insert into "%s" values('
-                "%d, '%s')" % (table, n + 1, t))
+        self.createTable(table, 'n integer, t text', oids=True,
+            values=enumerate('xyz', start=1))
         self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
         r = self.db.get(table, 1, 'n')
         s = delete(table, r)
@@ -1903,11 +1838,7 @@
         delete = self.db.delete
         get = self.db.get
         query = self.db.query
-        query("drop table if exists test_table")
-        self.addCleanup(query, "drop table test_table")
-        query("create table test_table (n int) with oids")
-        for i in range(6):
-            query("insert into test_table values (%d)" % (i + 1))
+        self.createTable('test_table', 'n int', oids=True, values=range(1, 7))
         r = dict(n=3)
         self.assertRaises(pg.ProgrammingError, delete, 'test_table', r)
         s = get('test_table', 1, 'n')
@@ -1995,13 +1926,8 @@
     def testDeleteWithCompositeKey(self):
         query = self.db.query
         table = 'delete_test_table_1'
-        query('drop table if exists "%s"' % table)
-        self.addCleanup(query, 'drop table "%s"' % table)
-        query('create table "%s" ('
-            "n integer, t text, primary key (n))" % table)
-        for n, t in enumerate('abc'):
-            query("insert into %s values("
-                "%d, '%s')" % (table, n + 1, t))
+        self.createTable(table, 'n integer primary key, t text',
+            values=enumerate('abc', start=1))
         self.assertRaises(KeyError, self.db.delete, table, dict(t='b'))
         self.assertEqual(self.db.delete(table, dict(n=2)), 1)
         r = query('select t from "%s" where n=2' % table).getresult()
@@ -2010,15 +1936,10 @@
         r = query('select t from "%s" where n=3' % table).getresult()[0][0]
         self.assertEqual(r, 'c')
         table = 'delete_test_table_2'
-        query('drop table if exists "%s"' % table)
-        self.addCleanup(query, 'drop table "%s"' % table)
-        query('create table "%s" ('
-            "n integer, m integer, t text, primary key (n, m))" % table)
-        for n in range(3):
-            for m in range(2):
-                t = chr(ord('a') + 2 * n + m)
-                query('insert into "%s" values('
-                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
+        self.createTable(table,
+            'n integer, m integer, t text, primary key (n, m)',
+            values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
+                for n in range(3) for m in range(2)])
         self.assertRaises(KeyError, self.db.delete, table, dict(n=2, t='b'))
         self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
         r = [r[0] for r in query('select t from "%s" where n=2'
@@ -2037,13 +1958,9 @@
         delete = self.db.delete
         query = self.db.query
         table = 'test table for delete()'
-        query('drop table if exists "%s"' % table)
-        self.addCleanup(query, 'drop table "%s"' % table)
-        query('create table "%s" ('
-            '"Prime!" smallint primary key,'
-            ' "much space" integer, "Questions?" text)' % table)
-        query('insert into "%s"'
-              " values(19, 5005, 'Yes!')" % table)
+        self.createTable(table, '"Prime!" smallint primary key,'
+            ' "much space" integer, "Questions?" text',
+            values=[(19, 5005, 'Yes!')])
         r = {'Prime!': 17}
         r = delete(table, r)
         self.assertEqual(r, 0)
@@ -2058,16 +1975,10 @@
     def testDeleteReferenced(self):
         delete = self.db.delete
         query = self.db.query
-        query("drop table if exists test_child")
-        query("drop table if exists test_parent")
-        self.addCleanup(query, "drop table test_parent")
-        query("create table test_parent (n smallint primary key)")
-        self.addCleanup(query, "drop table test_child")
-        query("create table test_child ("
-            " n smallint primary key references test_parent (n))")
-        for n in range(3):
-            query("insert into test_parent (n) values (%d)" % n)
-            query("insert into test_child (n) values (%d)" % n)
+        self.createTable('test_parent',
+            'n smallint primary key', values=range(3))
+        self.createTable('test_child',
+            'n smallint primary key references test_parent', values=range(3))
         q = ("select (select count(*) from test_parent),"
             " (select count(*) from test_child)")
         self.assertEqual(query(q).getresult()[0], (3, 3))
@@ -2104,11 +2015,8 @@
         self.assertRaises(TypeError, truncate, 42)
         self.assertRaises(TypeError, truncate, dict(test_table=None))
         query = self.db.query
-        query("drop table if exists test_table")
-        self.addCleanup(query, "drop table test_table")
-        query("create table test_table (n smallint)")
-        for i in range(3):
-            query("insert into test_table values (1)")
+        self.createTable('test_table', 'n smallint',
+            temporary=False, values=[1] * 3)
         q = "select count(*) from test_table"
         r = query(q).getresult()[0][0]
         self.assertEqual(r, 3)
@@ -2122,9 +2030,7 @@
         truncate('public.test_table')
         r = query(q).getresult()[0][0]
         self.assertEqual(r, 0)
-        query("drop table if exists test_table_2")
-        self.addCleanup(query, "drop table test_table_2")
-        query('create table test_table_2 (n smallint)')
+        self.createTable('test_table_2', 'n smallint', temporary=True)
         for t in (list, tuple, set):
             for i in range(3):
                 query("insert into test_table values (1)")
@@ -2141,9 +2047,7 @@
         truncate = self.db.truncate
         self.assertRaises(TypeError, truncate, 'test_table', restart='invalid')
         query = self.db.query
-        query("drop table if exists test_table")
-        self.addCleanup(query, "drop table test_table")
-        query("create table test_table (n serial, t text)")
+        self.createTable('test_table', 'n serial, t text')
         for n in range(3):
             query("insert into test_table (t) values ('test')")
         q = "select count(n), min(n), max(n) from test_table"
@@ -2168,16 +2072,11 @@
         truncate = self.db.truncate
         self.assertRaises(TypeError, truncate, 'test_table', cascade='invalid')
         query = self.db.query
-        query("drop table if exists test_child")
-        query("drop table if exists test_parent")
-        self.addCleanup(query, "drop table test_parent")
-        query("create table test_parent (n smallint primary key)")
-        self.addCleanup(query, "drop table test_child")
-        query("create table test_child ("
-            " n smallint primary key references test_parent (n))")
-        for n in range(3):
-            query("insert into test_parent (n) values (%d)" % n)
-            query("insert into test_child (n) values (%d)" % n)
+        self.createTable('test_parent', 'n smallint primary key',
+            values=range(3))
+        self.createTable('test_child',
+            'n smallint primary key references test_parent (n)',
+            values=range(3))
         q = ("select (select count(*) from test_parent),"
             " (select count(*) from test_child)")
         r = query(q).getresult()[0]
@@ -2211,13 +2110,8 @@
         truncate = self.db.truncate
         self.assertRaises(TypeError, truncate, 'test_table', only='invalid')
         query = self.db.query
-        query("drop table if exists test_child")
-        query("drop table if exists test_parent")
-        self.addCleanup(query, "drop table test_parent")
-        query("create table test_parent (n smallint)")
-        self.addCleanup(query, "drop table test_child")
-        query("create table test_child ("
-            " m smallint) inherits (test_parent)")
+        self.createTable('test_parent', 'n smallint')
+        self.createTable('test_child', 'm smallint) inherits (test_parent')
         for n in range(3):
             query("insert into test_parent (n) values (1)")
             query("insert into test_child (n, m) values (2, 3)")
@@ -2249,18 +2143,12 @@
         self.assertEqual(r, (0, 0))
         self.assertRaises(ValueError, truncate, 'test_parent*', only=True)
         truncate('test_parent*', only=False)
-        query("drop table if exists test_parent_2")
-        self.addCleanup(query, "drop table test_parent_2")
-        query("create table test_parent_2 (n smallint)")
-        query("drop table if exists test_child_2")
-        self.addCleanup(query, "drop table test_child_2")
-        query("create table test_child_2 ("
-            " m smallint) inherits (test_parent_2)")
-        for n in range(3):
-            query("insert into test_parent (n) values (1)")
-            query("insert into test_child (n, m) values (2, 3)")
-            query("insert into test_parent_2 (n) values (1)")
-            query("insert into test_child_2 (n, m) values (2, 3)")
+        self.createTable('test_parent_2', 'n smallint')
+        self.createTable('test_child_2', 'm smallint) inherits (test_parent_2')
+        for t in '', '_2':
+            for n in range(3):
+                query("insert into test_parent%s (n) values (1)" % t)
+                query("insert into test_child%s (n, m) values (2, 3)" % t)
         q = ("select (select count(*) from test_parent),"
             " (select count(*) from test_child),"
             " (select count(*) from test_parent_2),"
@@ -2281,11 +2169,7 @@
         truncate = self.db.truncate
         query = self.db.query
         table = "test table for truncate()"
-        query('drop table if exists "%s"' % table)
-        self.addCleanup(query, 'drop table "%s"' % table)
-        query('create table "%s" (n smallint)' % table)
-        for i in range(3):
-            query('insert into "%s" values (1)' % table)
+        self.createTable(table, 'n smallint', temporary=False, values=[1] * 3)
         q = 'select count(*) from "%s"' % table
         r = query(q).getresult()[0][0]
         self.assertEqual(r, 3)
@@ -2302,9 +2186,7 @@
 
     def testTransaction(self):
         query = self.db.query
-        query("drop table if exists test_table")
-        self.addCleanup(query, "drop table test_table")
-        query("create table test_table (n integer)")
+        self.createTable('test_table', 'n integer', temporary=False)
         self.db.begin()
         query("insert into test_table values (1)")
         query("insert into test_table values (2)")
@@ -2348,9 +2230,7 @@
 
     def testContextManager(self):
         query = self.db.query
-        query("drop table if exists test_table")
-        self.addCleanup(query, "drop table test_table")
-        query("create table test_table (n integer check(n>0))")
+        self.createTable('test_table', 'n integer check(n>0)')
         with self.db:
             query("insert into test_table values (1)")
             query("insert into test_table values (2)")
@@ -2377,9 +2257,7 @@
 
     def testBytea(self):
         query = self.db.query
-        query('drop table if exists bytea_test')
-        self.addCleanup(query, 'drop table bytea_test')
-        query('create table bytea_test (n smallint primary key, data bytea)')
+        self.createTable('bytea_test', 'n smallint primary key, data bytea')
         s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
         r = self.db.escape_bytea(s)
         query('insert into bytea_test values(3,$1)', (r,))
@@ -2396,9 +2274,7 @@
 
     def testInsertUpdateGetBytea(self):
         query = self.db.query
-        query('drop table if exists bytea_test')
-        self.addCleanup(query, 'drop table bytea_test')
-        query('create table bytea_test (n smallint primary key, data bytea)')
+        self.createTable('bytea_test', 'n smallint primary key, data bytea')
         # insert null value
         r = self.db.insert('bytea_test', n=0, data=None)
         self.assertIsInstance(r, dict)
@@ -2458,9 +2334,7 @@
 
     def testUpsertBytea(self):
         query = self.db.query
-        query('drop table if exists bytea_test')
-        self.addCleanup(query, 'drop table bytea_test')
-        query('create table bytea_test (n smallint primary key, data bytea)')
+        self.createTable('bytea_test', 'n smallint primary key, data bytea')
         s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
         r = dict(n=7, data=s)
         try:
_______________________________________________
PyGreSQL mailing list
[email protected]
https://mail.vex.net/mailman/listinfo.cgi/pygresql

Reply via email to