Author: cito
Date: Fri Jan 15 09:25:31 2016
New Revision: 748

Log:
Add method truncate() to DB wrapper class

This method can be used to quickly truncate tables.

Since this is pretty useful and will not break anything, I have
also back ported this addition to the 4.x branch.

Everything is well documented and tested, of course.

Modified:
   branches/4.x/docs/contents/changelog.rst
   branches/4.x/docs/contents/pg/db_wrapper.rst
   branches/4.x/pg.py
   branches/4.x/tests/test_classic_dbwrapper.py
   trunk/docs/contents/changelog.rst
   trunk/docs/contents/pg/db_wrapper.rst
   trunk/pg.py
   trunk/tests/test_classic_dbwrapper.py

Modified: branches/4.x/docs/contents/changelog.rst
==============================================================================
--- branches/4.x/docs/contents/changelog.rst    Fri Jan 15 06:06:51 2016        
(r747)
+++ branches/4.x/docs/contents/changelog.rst    Fri Jan 15 09:25:31 2016        
(r748)
@@ -9,6 +9,8 @@
 - Force build to compile with no errors.
 - New methods get_parameters() and set_parameters() in the classic interface
   which can be used to get or set run-time parameters.
+- New method truncate() in the classic interface that can be used to quickly
+  empty a table or a set of tables.
 - Fix decimal point handling.
 - Add option to return boolean values as bool objects.
 - Add option to return money values as string.

Modified: branches/4.x/docs/contents/pg/db_wrapper.rst
==============================================================================
--- branches/4.x/docs/contents/pg/db_wrapper.rst        Fri Jan 15 06:06:51 
2016        (r747)
+++ branches/4.x/docs/contents/pg/db_wrapper.rst        Fri Jan 15 09:25:31 
2016        (r748)
@@ -415,6 +415,35 @@
 The return value is the number of deleted rows (i.e. 0 if the row did not
 exist and 1 if the row was deleted).
 
+truncate -- Quickly empty database tables
+-----------------------------------------
+
+.. method:: DB.truncate(self, table, [restart], [cascade], [only]):
+
+    Empty a table or set of tables
+
+    :param table: the name of the table(s)
+    :type table: str, list or set
+    :param bool restart: whether table sequences should be restarted
+    :param bool cascade: whether referencing tables should also be truncated
+    :param only: whether only parent tables should be truncated
+    :type only: bool or list
+
+This method quickly removes all rows from the given table or set
+of tables.  It has the same effect as an unqualified DELETE on each
+table, but since it does not actually scan the tables it is faster.
+Furthermore, it reclaims disk space immediately, rather than requiring
+a subsequent VACUUM operation. This is most useful on large tables.
+
+If *restart* is set to `True`, sequences owned by columns of the truncated
+table(s) are automatically restarted.  If *cascade* is set to `True`, it
+also truncates all tables that have foreign-key references to any of
+the named tables.  If the parameter *only* is not set to `True`, all the
+descendant tables (if any) will also be truncated. Optionally, a ``*``
+can be specified after the table name to explicitly indicate that
+descendant tables are included.  If the parameter *table* is a list,
+the parameter *only* can also be a list of corresponding boolean values.
+
 escape_literal -- escape a literal string for use within SQL
 ------------------------------------------------------------
 

Modified: branches/4.x/pg.py
==============================================================================
--- branches/4.x/pg.py  Fri Jan 15 06:06:51 2016        (r747)
+++ branches/4.x/pg.py  Fri Jan 15 09:25:31 2016        (r748)
@@ -622,8 +622,8 @@
             parameter = dict.fromkeys(parameter, value)
         elif isinstance(parameter, dict):
             if value is not None:
-                raise ValueError(
-                    'A value must not be set when parameter is a dictionary')
+                raise ValueError('A value must not be specified'
+                    ' when parameter is a dictionary')
         else:
             raise TypeError(
                 'The parameter must be a string, list, set or dict')
@@ -639,8 +639,8 @@
                 raise TypeError('Invalid parameter')
             if param == 'all':
                 if value is not None:
-                    raise ValueError(
-                        "A value must ot be set when parameter is 'all'")
+                    raise ValueError('A value must not be specified'
+                        " when parameter is 'all'")
                 params = {'all': None}
                 break
             params[param] = value
@@ -1099,6 +1099,66 @@
         self._do_debug(q)
         return int(self.db.query(q))
 
+    def truncate(self, table, restart=False, cascade=False, only=False):
+        """Empty a table or set of tables.
+
+        This method quickly removes all rows from the given table or set
+        of tables.  It has the same effect as an unqualified DELETE on each
+        table, but since it does not actually scan the tables it is faster.
+        Furthermore, it reclaims disk space immediately, rather than requiring
+        a subsequent VACUUM operation. This is most useful on large tables.
+
+        If restart is set to True, sequences owned by columns of the truncated
+        table(s) are automatically restarted.  If cascade is set to True, it
+        also truncates all tables that have foreign-key references to any of
+        the named tables.  If the parameter only is not set to True, all the
+        descendant tables (if any) will also be truncated. Optionally, a '*'
+        can be specified after the table name to explicitly indicate that
+        descendant tables are included.  If the parameter table is a list,
+        the parameter only can also be a list of corresponding boolean values.
+        """
+        if isinstance(table, basestring):
+            only = {table: only}
+            table = [table]
+        elif isinstance(table, (list, tuple)):
+            if isinstance(only, (list, tuple)):
+                if len(only) != len(table):
+                    raise ValueError(
+                        'The only option must match the number of tables')
+                only = dict(zip(table, only))
+            else:
+                only = dict.fromkeys(table, only)
+        elif isinstance(table, (set, frozenset)):
+            only = dict.fromkeys(table, only)
+        else:
+            raise TypeError('The table must be a string, list or set')
+        if not (restart is None or isinstance(restart, (bool, int))):
+            raise TypeError('Invalid type for the restart option')
+        if not (cascade is None or isinstance(cascade, (bool, int))):
+            raise TypeError('Invalid type for the cascade option')
+        tables = []
+        for t in table:
+            u = only.get(t)
+            if not (u is None or isinstance(u, (bool, int))):
+                raise TypeError('Invalid type for the only option')
+            if t.endswith('*'):
+                if u:
+                    raise ValueError(
+                        'Contradictory table name and only options')
+                t = t[:-1].rstrip()
+            t = self._add_schema(t)
+            if u:
+                t = 'ONLY %s' % t
+            tables.append(t)
+        q = ['TRUNCATE', ', '.join(tables)]
+        if restart:
+            q.append('RESTART IDENTITY')
+        if cascade:
+            q.append('CASCADE')
+        q = ' '.join(q)
+        self._do_debug(q)
+        return self.db.query(q)
+
     def notification_handler(self, event, callback, arg_dict={}, timeout=None):
         """Get notification handler that will run the given callback."""
         return NotificationHandler(self.db, event, callback, arg_dict, timeout)

Modified: branches/4.x/tests/test_classic_dbwrapper.py
==============================================================================
--- branches/4.x/tests/test_classic_dbwrapper.py        Fri Jan 15 06:06:51 
2016        (r747)
+++ branches/4.x/tests/test_classic_dbwrapper.py        Fri Jan 15 09:25:31 
2016        (r748)
@@ -94,7 +94,7 @@
             'savepoint', 'server_version',
             'set_notice_receiver', 'set_parameter',
             'source', 'start', 'status',
-            'transaction', 'tty',
+            'transaction', 'truncate', 'tty',
             'unescape_bytea', 'update',
             'use_regtypes', 'user',
         ]
@@ -1148,6 +1148,208 @@
         self.assertEqual(r, ['f'])
         query("drop table %s" % table)
 
+    def testTruncate(self):
+        truncate = self.db.truncate
+        self.assertRaises(TypeError, truncate, None)
+        self.assertRaises(TypeError, truncate, 42)
+        self.assertRaises(TypeError, truncate, dict(test_table=None))
+        query = self.db.query
+        query("drop table if exists test_table")
+        query("create table test_table (n smallint)")
+        for i in range(3):
+            query("insert into test_table values (1)")
+        q = "select count(*) from test_table"
+        r = query(q).getresult()[0][0]
+        self.assertEqual(r, 3)
+        truncate('test_table')
+        r = query(q).getresult()[0][0]
+        self.assertEqual(r, 0)
+        for i in range(3):
+            query("insert into test_table values (1)")
+        r = query(q).getresult()[0][0]
+        self.assertEqual(r, 3)
+        truncate('public.test_table')  # schema-qualified name
+        r = query(q).getresult()[0][0]
+        self.assertEqual(r, 0)
+        query("drop table if exists test_table_2")
+        query('create table test_table_2 (n smallint)')
+        for t in (list, tuple, set):
+            for i in range(3):
+                query("insert into test_table values (1)")
+                query("insert into test_table_2 values (2)")
+            q = ("select (select count(*) from test_table),"
+                " (select count(*) from test_table_2)")
+            r = query(q).getresult()[0]
+            self.assertEqual(r, (3, 3))
+            truncate(t(['test_table', 'test_table_2']))  # list, tuple or set
+            r = query(q).getresult()[0]
+            self.assertEqual(r, (0, 0))
+        query("drop table test_table_2")
+        query("drop table test_table")
+
+    def testTruncateRestart(self):
+        truncate = self.db.truncate
+        self.assertRaises(TypeError, truncate, 'test_table', restart='invalid')
+        query = self.db.query
+        query("drop table if exists test_table")
+        query("create table test_table (n serial, t text)")
+        for n in range(3):
+            query("insert into test_table (t) values ('test')")
+        q = "select count(n), min(n), max(n) from test_table"
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (3, 1, 3))
+        truncate('test_table')  # restart is not set here
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (0, None, None))
+        for n in range(3):
+            query("insert into test_table (t) values ('test')")
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (3, 4, 6))  # sequence continued
+        truncate('test_table', restart=True)
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (0, None, None))
+        for n in range(3):
+            query("insert into test_table (t) values ('test')")
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (3, 1, 3))  # sequence restarted
+        query("drop table test_table")
+
+    def testTruncateCascade(self):
+        truncate = self.db.truncate
+        self.assertRaises(TypeError, truncate, 'test_table', cascade='invalid')
+        query = self.db.query
+        query("drop table if exists test_child")
+        query("drop table if exists test_parent")
+        query("create table test_parent (n smallint primary key)")
+        query("create table test_child ("
+            " n smallint primary key references test_parent (n))")
+        for n in range(3):
+            query("insert into test_parent (n) values (%d)" % n)
+            query("insert into test_child (n) values (%d)" % n)
+        q = ("select (select count(*) from test_parent),"
+            " (select count(*) from test_child)")
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (3, 3))
+        self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
+        truncate(['test_parent', 'test_child'])  # both together is allowed
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (0, 0))
+        for n in range(3):
+            query("insert into test_parent (n) values (%d)" % n)
+            query("insert into test_child (n) values (%d)" % n)
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (3, 3))
+        truncate('test_parent', cascade=True)
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (0, 0))
+        for n in range(3):
+            query("insert into test_parent (n) values (%d)" % n)
+            query("insert into test_child (n) values (%d)" % n)
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (3, 3))
+        truncate('test_child')  # no table references test_child
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (3, 0))
+        self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
+        truncate('test_parent', cascade=True)
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (0, 0))
+        query("drop table test_child")
+        query("drop table test_parent")
+
+    def testTruncateOnly(self):
+        truncate = self.db.truncate
+        self.assertRaises(TypeError, truncate, 'test_table', only='invalid')
+        query = self.db.query
+        query("drop table if exists test_child")
+        query("drop table if exists test_parent")
+        query("create table test_parent (n smallint)")
+        query("create table test_child ("
+            " m smallint) inherits (test_parent)")
+        for n in range(3):
+            query("insert into test_parent (n) values (1)")
+            query("insert into test_child (n, m) values (2, 3)")
+        q = ("select (select count(*) from test_parent),"
+            " (select count(*) from test_child)")
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (6, 3))  # parent count includes child rows
+        truncate('test_parent')
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (0, 0))
+        for n in range(3):
+            query("insert into test_parent (n) values (1)")
+            query("insert into test_child (n, m) values (2, 3)")
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (6, 3))
+        truncate('test_parent*')  # '*' explicitly includes descendants
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (0, 0))
+        for n in range(3):
+            query("insert into test_parent (n) values (1)")
+            query("insert into test_child (n, m) values (2, 3)")
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (6, 3))
+        truncate('test_parent', only=True)
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (3, 3))  # only the parent's own rows removed
+        truncate('test_parent', only=False)
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (0, 0))
+        self.assertRaises(ValueError, truncate, 'test_parent*', only=True)
+        truncate('test_parent*', only=False)
+        query("drop table if exists test_parent_2")
+        query("create table test_parent_2 (n smallint)")
+        query("drop table if exists test_child_2")
+        query("create table test_child_2 ("
+            " m smallint) inherits (test_parent_2)")
+        for n in range(3):
+            query("insert into test_parent (n) values (1)")
+            query("insert into test_child (n, m) values (2, 3)")
+            query("insert into test_parent_2 (n) values (1)")
+            query("insert into test_child_2 (n, m) values (2, 3)")
+        q = ("select (select count(*) from test_parent),"
+            " (select count(*) from test_child),"
+            " (select count(*) from test_parent_2),"
+            " (select count(*) from test_child_2)")
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (6, 3, 6, 3))
+        truncate(['test_parent', 'test_parent_2'], only=[False, True])
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (0, 0, 3, 3))  # test_child_2 rows kept
+        truncate(['test_parent', 'test_parent_2'], only=False)
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (0, 0, 0, 0))
+        self.assertRaises(ValueError, truncate,
+            ['test_parent*', 'test_child'], only=[True, False])
+        truncate(['test_parent*', 'test_child'], only=[False, True])
+        query("drop table test_child_2")
+        query("drop table test_parent_2")
+        query("drop table test_child")
+        query("drop table test_parent")
+
+    def testTruncateQuoted(self):
+        truncate = self.db.truncate
+        query = self.db.query
+        table = "test table for truncate()"
+        query('drop table if exists "%s"' % table)
+        query('create table "%s" (n smallint)' % table)
+        for i in range(3):
+            query('insert into "%s" values (1)' % table)
+        q = 'select count(*) from "%s"' % table
+        r = query(q).getresult()[0][0]
+        self.assertEqual(r, 3)
+        truncate(table)  # name contains spaces and parentheses
+        r = query(q).getresult()[0][0]
+        self.assertEqual(r, 0)
+        for i in range(3):
+            query('insert into "%s" values (1)' % table)
+        r = query(q).getresult()[0][0]
+        self.assertEqual(r, 3)
+        truncate('public."%s"' % table)
+        r = query(q).getresult()[0][0]
+        self.assertEqual(r, 0)
+        query('drop table "%s"' % table)
+
     def testTransaction(self):
         query = self.db.query
         query("drop table if exists test_table")

Modified: trunk/docs/contents/changelog.rst
==============================================================================
--- trunk/docs/contents/changelog.rst   Fri Jan 15 06:06:51 2016        (r747)
+++ trunk/docs/contents/changelog.rst   Fri Jan 15 09:25:31 2016        (r748)
@@ -48,6 +48,8 @@
 - Force build to compile with no errors.
 - New methods get_parameters() and set_parameters() in the classic interface
   which can be used to get or set run-time parameters.
+- New method truncate() in the classic interface that can be used to quickly
+  empty a table or a set of tables.
 - Fix decimal point handling.
 - Add option to return boolean values as bool objects.
 - Add option to return money values as string.

Modified: trunk/docs/contents/pg/db_wrapper.rst
==============================================================================
--- trunk/docs/contents/pg/db_wrapper.rst       Fri Jan 15 06:06:51 2016        
(r747)
+++ trunk/docs/contents/pg/db_wrapper.rst       Fri Jan 15 09:25:31 2016        
(r748)
@@ -476,6 +476,35 @@
 The return value is the number of deleted rows (i.e. 0 if the row did not
 exist and 1 if the row was deleted).
 
+truncate -- Quickly empty database tables
+-----------------------------------------
+
+.. method:: DB.truncate(self, table, [restart], [cascade], [only]):
+
+    Empty a table or set of tables
+
+    :param table: the name of the table(s)
+    :type table: str, list or set
+    :param bool restart: whether table sequences should be restarted
+    :param bool cascade: whether referencing tables should also be truncated
+    :param only: whether only parent tables should be truncated
+    :type only: bool or list
+
+This method quickly removes all rows from the given table or set
+of tables.  It has the same effect as an unqualified DELETE on each
+table, but since it does not actually scan the tables it is faster.
+Furthermore, it reclaims disk space immediately, rather than requiring
+a subsequent VACUUM operation. This is most useful on large tables.
+
+If *restart* is set to `True`, sequences owned by columns of the truncated
+table(s) are automatically restarted.  If *cascade* is set to `True`, it
+also truncates all tables that have foreign-key references to any of
+the named tables.  If the parameter *only* is not set to `True`, all the
+descendant tables (if any) will also be truncated. Optionally, a ``*``
+can be specified after the table name to explicitly indicate that
+descendant tables are included.  If the parameter *table* is a list,
+the parameter *only* can also be a list of corresponding boolean values.
+
 escape_literal -- escape a literal string for use within SQL
 ------------------------------------------------------------
 

Modified: trunk/pg.py
==============================================================================
--- trunk/pg.py Fri Jan 15 06:06:51 2016        (r747)
+++ trunk/pg.py Fri Jan 15 09:25:31 2016        (r748)
@@ -577,8 +577,8 @@
                 raise TypeError('Invalid parameter')
             if param == 'all':
                 if value is not None:
-                    raise ValueError(
-                        "A value must ot be set when parameter is 'all'")
+                    raise ValueError('A value must not be specified'
+                        " when parameter is 'all'")
                 params = {'all': None}
                 break
             params[param] = value
@@ -1087,6 +1087,66 @@
         res = self.db.query(q, params)
         return int(res)
 
+    def truncate(self, table, restart=False, cascade=False, only=False):
+        """Empty a table or set of tables.
+
+        This method quickly removes all rows from the given table or set
+        of tables.  It has the same effect as an unqualified DELETE on each
+        table, but since it does not actually scan the tables it is faster.
+        Furthermore, it reclaims disk space immediately, rather than requiring
+        a subsequent VACUUM operation. This is most useful on large tables.
+
+        If restart is set to True, sequences owned by columns of the truncated
+        table(s) are automatically restarted.  If cascade is set to True, it
+        also truncates all tables that have foreign-key references to any of
+        the named tables.  If the parameter only is not set to True, all the
+        descendant tables (if any) will also be truncated. Optionally, a '*'
+        can be specified after the table name to explicitly indicate that
+        descendant tables are included.  If the parameter table is a list,
+        the parameter only can also be a list of corresponding boolean values.
+        """
+        if isinstance(table, basestring):
+            only = {table: only}
+            table = [table]
+        elif isinstance(table, (list, tuple)):
+            if isinstance(only, (list, tuple)):
+                if len(only) != len(table):
+                    raise ValueError(
+                        'The only option must match the number of tables')
+                only = dict(zip(table, only))
+            else:
+                only = dict.fromkeys(table, only)
+        elif isinstance(table, (set, frozenset)):
+            only = dict.fromkeys(table, only)
+        else:
+            raise TypeError('The table must be a string, list or set')
+        if not (restart is None or isinstance(restart, (bool, int))):
+            raise TypeError('Invalid type for the restart option')
+        if not (cascade is None or isinstance(cascade, (bool, int))):
+            raise TypeError('Invalid type for the cascade option')
+        tables = []
+        for t in table:
+            u = only.get(t)
+            if not (u is None or isinstance(u, (bool, int))):
+                raise TypeError('Invalid type for the only option')
+            if t.endswith('*'):
+                if u:
+                    raise ValueError(
+                        'Contradictory table name and only options')
+                t = t[:-1].rstrip()
+            t = self._escape_qualified_name(t)
+            if u:
+                t = 'ONLY %s' % t
+            tables.append(t)
+        q = ['TRUNCATE', ', '.join(tables)]
+        if restart:
+            q.append('RESTART IDENTITY')
+        if cascade:
+            q.append('CASCADE')
+        q = ' '.join(q)
+        self._do_debug(q)
+        return self.db.query(q)
+
     def notification_handler(self, event, callback, arg_dict={}, timeout=None):
         """Get notification handler that will run the given callback."""
         return NotificationHandler(self.db, event, callback, arg_dict, timeout)

Modified: trunk/tests/test_classic_dbwrapper.py
==============================================================================
--- trunk/tests/test_classic_dbwrapper.py       Fri Jan 15 06:06:51 2016        
(r747)
+++ trunk/tests/test_classic_dbwrapper.py       Fri Jan 15 09:25:31 2016        
(r748)
@@ -107,7 +107,7 @@
             'savepoint', 'server_version',
             'set_notice_receiver', 'set_parameter',
             'source', 'start', 'status',
-            'transaction',
+            'transaction', 'truncate',
             'unescape_bytea', 'update', 'upsert',
             'use_regtypes', 'user',
         ]
@@ -1514,6 +1514,208 @@
         r = query('select count(*) from "%s"' % table).getresult()
         self.assertEqual(r[0][0], 0)
 
+    def testTruncate(self):
+        truncate = self.db.truncate
+        self.assertRaises(TypeError, truncate, None)
+        self.assertRaises(TypeError, truncate, 42)
+        self.assertRaises(TypeError, truncate, dict(test_table=None))
+        query = self.db.query
+        query("drop table if exists test_table")
+        self.addCleanup(query, "drop table test_table")
+        query("create table test_table (n smallint)")
+        for i in range(3):
+            query("insert into test_table values (1)")
+        q = "select count(*) from test_table"
+        r = query(q).getresult()[0][0]
+        self.assertEqual(r, 3)
+        truncate('test_table')
+        r = query(q).getresult()[0][0]
+        self.assertEqual(r, 0)
+        for i in range(3):
+            query("insert into test_table values (1)")
+        r = query(q).getresult()[0][0]
+        self.assertEqual(r, 3)
+        truncate('public.test_table')  # schema-qualified name
+        r = query(q).getresult()[0][0]
+        self.assertEqual(r, 0)
+        query("drop table if exists test_table_2")
+        self.addCleanup(query, "drop table test_table_2")
+        query('create table test_table_2 (n smallint)')
+        for t in (list, tuple, set):
+            for i in range(3):
+                query("insert into test_table values (1)")
+                query("insert into test_table_2 values (2)")
+            q = ("select (select count(*) from test_table),"
+                " (select count(*) from test_table_2)")
+            r = query(q).getresult()[0]
+            self.assertEqual(r, (3, 3))
+            truncate(t(['test_table', 'test_table_2']))  # list, tuple or set
+            r = query(q).getresult()[0]
+            self.assertEqual(r, (0, 0))
+
+    def testTruncateRestart(self):
+        truncate = self.db.truncate
+        self.assertRaises(TypeError, truncate, 'test_table', restart='invalid')
+        query = self.db.query
+        query("drop table if exists test_table")
+        self.addCleanup(query, "drop table test_table")
+        query("create table test_table (n serial, t text)")
+        for n in range(3):
+            query("insert into test_table (t) values ('test')")
+        q = "select count(n), min(n), max(n) from test_table"
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (3, 1, 3))
+        truncate('test_table')  # restart is not set here
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (0, None, None))
+        for n in range(3):
+            query("insert into test_table (t) values ('test')")
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (3, 4, 6))  # sequence continued
+        truncate('test_table', restart=True)
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (0, None, None))
+        for n in range(3):
+            query("insert into test_table (t) values ('test')")
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (3, 1, 3))  # sequence restarted
+
+    def testTruncateCascade(self):
+        truncate = self.db.truncate
+        self.assertRaises(TypeError, truncate, 'test_table', cascade='invalid')
+        query = self.db.query
+        query("drop table if exists test_child")
+        query("drop table if exists test_parent")
+        self.addCleanup(query, "drop table test_parent")
+        query("create table test_parent (n smallint primary key)")
+        self.addCleanup(query, "drop table test_child")
+        query("create table test_child ("
+            " n smallint primary key references test_parent (n))")
+        for n in range(3):
+            query("insert into test_parent (n) values (%d)" % n)
+            query("insert into test_child (n) values (%d)" % n)
+        q = ("select (select count(*) from test_parent),"
+            " (select count(*) from test_child)")
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (3, 3))
+        self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
+        truncate(['test_parent', 'test_child'])  # both together is allowed
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (0, 0))
+        for n in range(3):
+            query("insert into test_parent (n) values (%d)" % n)
+            query("insert into test_child (n) values (%d)" % n)
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (3, 3))
+        truncate('test_parent', cascade=True)
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (0, 0))
+        for n in range(3):
+            query("insert into test_parent (n) values (%d)" % n)
+            query("insert into test_child (n) values (%d)" % n)
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (3, 3))
+        truncate('test_child')  # no table references test_child
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (3, 0))
+        self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
+        truncate('test_parent', cascade=True)
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (0, 0))
+
+    def testTruncateOnly(self):
+        truncate = self.db.truncate
+        self.assertRaises(TypeError, truncate, 'test_table', only='invalid')
+        query = self.db.query
+        query("drop table if exists test_child")
+        query("drop table if exists test_parent")
+        self.addCleanup(query, "drop table test_parent")
+        query("create table test_parent (n smallint)")
+        self.addCleanup(query, "drop table test_child")
+        query("create table test_child ("
+            " m smallint) inherits (test_parent)")
+        for n in range(3):
+            query("insert into test_parent (n) values (1)")
+            query("insert into test_child (n, m) values (2, 3)")
+        q = ("select (select count(*) from test_parent),"
+            " (select count(*) from test_child)")
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (6, 3))  # parent count includes child rows
+        truncate('test_parent')
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (0, 0))
+        for n in range(3):
+            query("insert into test_parent (n) values (1)")
+            query("insert into test_child (n, m) values (2, 3)")
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (6, 3))
+        truncate('test_parent*')  # '*' explicitly includes descendants
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (0, 0))
+        for n in range(3):
+            query("insert into test_parent (n) values (1)")
+            query("insert into test_child (n, m) values (2, 3)")
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (6, 3))
+        truncate('test_parent', only=True)
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (3, 3))  # only the parent's own rows removed
+        truncate('test_parent', only=False)
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (0, 0))
+        self.assertRaises(ValueError, truncate, 'test_parent*', only=True)
+        truncate('test_parent*', only=False)
+        query("drop table if exists test_parent_2")
+        self.addCleanup(query, "drop table test_parent_2")
+        query("create table test_parent_2 (n smallint)")
+        query("drop table if exists test_child_2")
+        self.addCleanup(query, "drop table test_child_2")
+        query("create table test_child_2 ("
+            " m smallint) inherits (test_parent_2)")
+        for n in range(3):
+            query("insert into test_parent (n) values (1)")
+            query("insert into test_child (n, m) values (2, 3)")
+            query("insert into test_parent_2 (n) values (1)")
+            query("insert into test_child_2 (n, m) values (2, 3)")
+        q = ("select (select count(*) from test_parent),"
+            " (select count(*) from test_child),"
+            " (select count(*) from test_parent_2),"
+            " (select count(*) from test_child_2)")
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (6, 3, 6, 3))
+        truncate(['test_parent', 'test_parent_2'], only=[False, True])
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (0, 0, 3, 3))  # test_child_2 rows kept
+        truncate(['test_parent', 'test_parent_2'], only=False)
+        r = query(q).getresult()[0]
+        self.assertEqual(r, (0, 0, 0, 0))
+        self.assertRaises(ValueError, truncate,
+            ['test_parent*', 'test_child'], only=[True, False])
+        truncate(['test_parent*', 'test_child'], only=[False, True])
+
+    def testTruncateQuoted(self):
+        truncate = self.db.truncate
+        query = self.db.query
+        table = "test table for truncate()"
+        query('drop table if exists "%s"' % table)
+        self.addCleanup(query, 'drop table "%s"' % table)
+        query('create table "%s" (n smallint)' % table)
+        for i in range(3):
+            query('insert into "%s" values (1)' % table)
+        q = 'select count(*) from "%s"' % table
+        r = query(q).getresult()[0][0]
+        self.assertEqual(r, 3)
+        truncate(table)  # name contains spaces and parentheses
+        r = query(q).getresult()[0][0]
+        self.assertEqual(r, 0)
+        for i in range(3):
+            query('insert into "%s" values (1)' % table)
+        r = query(q).getresult()[0][0]
+        self.assertEqual(r, 3)
+        truncate('public."%s"' % table)
+        r = query(q).getresult()[0][0]
+        self.assertEqual(r, 0)
+
     def testTransaction(self):
         query = self.db.query
         query("drop table if exists test_table")
_______________________________________________
PyGreSQL mailing list
[email protected]
https://mail.vex.net/mailman/listinfo.cgi/pygresql

Reply via email to