Author: cito
Date: Fri Jan 15 09:25:31 2016
New Revision: 748
Log:
Add method truncate() to DB wrapper class
This method can be used to quickly truncate tables.
Since this is pretty useful and will not break anything, I have
also backported this addition to the 4.x branch.
Everything is well documented and tested, of course.
Modified:
branches/4.x/docs/contents/changelog.rst
branches/4.x/docs/contents/pg/db_wrapper.rst
branches/4.x/pg.py
branches/4.x/tests/test_classic_dbwrapper.py
trunk/docs/contents/changelog.rst
trunk/docs/contents/pg/db_wrapper.rst
trunk/pg.py
trunk/tests/test_classic_dbwrapper.py
Modified: branches/4.x/docs/contents/changelog.rst
==============================================================================
--- branches/4.x/docs/contents/changelog.rst Fri Jan 15 06:06:51 2016
(r747)
+++ branches/4.x/docs/contents/changelog.rst Fri Jan 15 09:25:31 2016
(r748)
@@ -9,6 +9,8 @@
- Force build to compile with no errors.
- New methods get_parameters() and set_parameters() in the classic interface
which can be used to get or set run-time parameters.
+- New method truncate() in the classic interface that can be used to quickly
+ empty a table or a set of tables.
- Fix decimal point handling.
- Add option to return boolean values as bool objects.
- Add option to return money values as string.
Modified: branches/4.x/docs/contents/pg/db_wrapper.rst
==============================================================================
--- branches/4.x/docs/contents/pg/db_wrapper.rst Fri Jan 15 06:06:51
2016 (r747)
+++ branches/4.x/docs/contents/pg/db_wrapper.rst Fri Jan 15 09:25:31
2016 (r748)
@@ -415,6 +415,35 @@
The return value is the number of deleted rows (i.e. 0 if the row did not
exist and 1 if the row was deleted).
+truncate -- Quickly empty database tables
+-----------------------------------------
+
+.. method:: DB.truncate(table, [restart], [cascade], [only])
+
+ Empty a table or set of tables
+
+ :param table: the name of the table(s)
+ :type table: str, list or set
+ :param bool restart: whether table sequences should be restarted
+ :param bool cascade: whether tables referencing them via foreign keys should also be truncated
+ :param only: whether only parent tables should be truncated
+ :type only: bool or list
+
+This method quickly removes all rows from the given table or set
+of tables. It has the same effect as an unqualified DELETE on each
+table, but since it does not actually scan the tables it is faster.
+Furthermore, it reclaims disk space immediately, rather than requiring
+a subsequent VACUUM operation. This is most useful on large tables.
+
+If *restart* is set to `True`, sequences owned by columns of the truncated
+table(s) are automatically restarted. If *cascade* is set to `True`, it
+also truncates all tables that have foreign-key references to any of
+the named tables. If the parameter *only* is not set to `True`, all the
+descendant tables (if any) will also be truncated. Optionally, a ``*``
+can be specified after the table name to explicitly indicate that
+descendant tables are included. If the parameter *table* is a list,
+the parameter *only* can also be a list of corresponding boolean values.
+
escape_literal -- escape a literal string for use within SQL
------------------------------------------------------------
Modified: branches/4.x/pg.py
==============================================================================
--- branches/4.x/pg.py Fri Jan 15 06:06:51 2016 (r747)
+++ branches/4.x/pg.py Fri Jan 15 09:25:31 2016 (r748)
@@ -622,8 +622,8 @@
parameter = dict.fromkeys(parameter, value)
elif isinstance(parameter, dict):
if value is not None:
- raise ValueError(
- 'A value must not be set when parameter is a dictionary')
+ raise ValueError('A value must not be specified'
+ ' when parameter is a dictionary')
else:
raise TypeError(
'The parameter must be a string, list, set or dict')
@@ -639,8 +639,8 @@
raise TypeError('Invalid parameter')
if param == 'all':
if value is not None:
- raise ValueError(
- "A value must ot be set when parameter is 'all'")
+ raise ValueError('A value must ot be specified'
+ " when parameter is 'all'")
params = {'all': None}
break
params[param] = value
@@ -1099,6 +1099,62 @@
self._do_debug(q)
return int(self.db.query(q))
+ def truncate(self, table, restart=False, cascade=False, only=False):
+ """Empty a table or set of tables.
+
+ This method quickly removes all rows from the given table or set
+ of tables. It has the same effect as an unqualified DELETE on each
+ table, but since it does not actually scan the tables it is faster.
+ Furthermore, it reclaims disk space immediately, rather than requiring
+ a subsequent VACUUM operation. This is most useful on large tables.
+
+ If restart is set to True, sequences owned by columns of the truncated
+ table(s) are automatically restarted. If cascade is set to True, it
+ also truncates all tables that have foreign-key references to any of
+ the named tables. If the parameter only is not set to True, all the
+ descendant tables (if any) will also be truncated. Optionally, a '*'
+ can be specified after the table name to explicitly indicate that
+ descendant tables are included.
+ """
+ if isinstance(table, basestring):
+ only = {table: only}
+ table = [table]
+ elif isinstance(table, (list, tuple)):
+ if isinstance(only, (list, tuple)):
+ only = dict(zip(table, only))
+ else:
+ only = dict.fromkeys(table, only)
+ elif isinstance(table, (set, frozenset)):
+ only = dict.fromkeys(table, only)
+ else:
+ raise TypeError('The table must be a string, list or set')
+ if not (restart is None or isinstance(restart, (bool, int))):
+ raise TypeError('Invalid type for the restart option')
+ if not (cascade is None or isinstance(cascade, (bool, int))):
+ raise TypeError('Invalid type for the cascade option')
+ tables = []
+ for t in table:
+ u = only.get(t)
+ if not (u is None or isinstance(u, (bool, int))):
+ raise TypeError('Invalid type for the only option')
+ if t.endswith('*'):
+ if u:
+ raise ValueError(
+ 'Contradictory table name and only options')
+ t = t[:-1].rstrip()
+ t = self._add_schema(t)
+ if u:
+ t = 'ONLY %s' % t
+ tables.append(t)
+ q = ['TRUNCATE', ', '.join(tables)]
+ if restart:
+ q.append('RESTART IDENTITY')
+ if cascade:
+ q.append('CASCADE')
+ q = ' '.join(q)
+ self._do_debug(q)
+ return self.query(q)
+
def notification_handler(self, event, callback, arg_dict=None, timeout=None):
    """Get notification handler that will run the given callback.

    If no arg_dict is given, a new empty dictionary is created for each
    call, so that handlers never share one default dictionary object.
    """
    if arg_dict is None:
        arg_dict = {}
    return NotificationHandler(self.db, event, callback, arg_dict, timeout)
Modified: branches/4.x/tests/test_classic_dbwrapper.py
==============================================================================
--- branches/4.x/tests/test_classic_dbwrapper.py Fri Jan 15 06:06:51
2016 (r747)
+++ branches/4.x/tests/test_classic_dbwrapper.py Fri Jan 15 09:25:31
2016 (r748)
@@ -94,7 +94,7 @@
'savepoint', 'server_version',
'set_notice_receiver', 'set_parameter',
'source', 'start', 'status',
- 'transaction', 'tty',
+ 'transaction', 'truncate', 'tty',
'unescape_bytea', 'update',
'use_regtypes', 'user',
]
@@ -1148,6 +1148,208 @@
self.assertEqual(r, ['f'])
query("drop table %s" % table)
+ def testTruncate(self):
+ truncate = self.db.truncate
+ self.assertRaises(TypeError, truncate, None)
+ self.assertRaises(TypeError, truncate, 42)
+ self.assertRaises(TypeError, truncate, dict(test_table=None))
+ query = self.db.query
+ query("drop table if exists test_table")
+ query("create table test_table (n smallint)")
+ for i in range(3):
+ query("insert into test_table values (1)")
+ q = "select count(*) from test_table"
+ r = query(q).getresult()[0][0]
+ self.assertEqual(r, 3)
+ truncate('test_table')
+ r = query(q).getresult()[0][0]
+ self.assertEqual(r, 0)
+ for i in range(3):
+ query("insert into test_table values (1)")
+ r = query(q).getresult()[0][0]
+ self.assertEqual(r, 3)
+ truncate('public.test_table')
+ r = query(q).getresult()[0][0]
+ self.assertEqual(r, 0)
+ query("drop table if exists test_table_2")
+ query('create table test_table_2 (n smallint)')
+ for t in (list, tuple, set):
+ for i in range(3):
+ query("insert into test_table values (1)")
+ query("insert into test_table_2 values (2)")
+ q = ("select (select count(*) from test_table),"
+ " (select count(*) from test_table_2)")
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (3, 3))
+ truncate(t(['test_table', 'test_table_2']))
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (0, 0))
+ query("drop table test_table_2")
+ query("drop table test_table")
+
+ def testTruncateRestart(self):
+ truncate = self.db.truncate
+ self.assertRaises(TypeError, truncate, 'test_table', restart='invalid')
+ query = self.db.query
+ query("drop table if exists test_table")
+ query("create table test_table (n serial, t text)")
+ for n in range(3):
+ query("insert into test_table (t) values ('test')")
+ q = "select count(n), min(n), max(n) from test_table"
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (3, 1, 3))
+ truncate('test_table')
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (0, None, None))
+ for n in range(3):
+ query("insert into test_table (t) values ('test')")
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (3, 4, 6))
+ truncate('test_table', restart=True)
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (0, None, None))
+ for n in range(3):
+ query("insert into test_table (t) values ('test')")
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (3, 1, 3))
+ query("drop table test_table")
+
+ def testTruncateCascade(self):
+ truncate = self.db.truncate
+ self.assertRaises(TypeError, truncate, 'test_table', cascade='invalid')
+ query = self.db.query
+ query("drop table if exists test_child")
+ query("drop table if exists test_parent")
+ query("create table test_parent (n smallint primary key)")
+ query("create table test_child ("
+ " n smallint primary key references test_parent (n))")
+ for n in range(3):
+ query("insert into test_parent (n) values (%d)" % n)
+ query("insert into test_child (n) values (%d)" % n)
+ q = ("select (select count(*) from test_parent),"
+ " (select count(*) from test_child)")
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (3, 3))
+ self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
+ truncate(['test_parent', 'test_child'])
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (0, 0))
+ for n in range(3):
+ query("insert into test_parent (n) values (%d)" % n)
+ query("insert into test_child (n) values (%d)" % n)
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (3, 3))
+ truncate('test_parent', cascade=True)
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (0, 0))
+ for n in range(3):
+ query("insert into test_parent (n) values (%d)" % n)
+ query("insert into test_child (n) values (%d)" % n)
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (3, 3))
+ truncate('test_child')
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (3, 0))
+ self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
+ truncate('test_parent', cascade=True)
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (0, 0))
+ query("drop table test_child")
+ query("drop table test_parent")
+
+ def testTruncateOnly(self):
+ truncate = self.db.truncate
+ self.assertRaises(TypeError, truncate, 'test_table', only='invalid')
+ query = self.db.query
+ query("drop table if exists test_child")
+ query("drop table if exists test_parent")
+ query("create table test_parent (n smallint)")
+ query("create table test_child ("
+ " m smallint) inherits (test_parent)")
+ for n in range(3):
+ query("insert into test_parent (n) values (1)")
+ query("insert into test_child (n, m) values (2, 3)")
+ q = ("select (select count(*) from test_parent),"
+ " (select count(*) from test_child)")
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (6, 3))
+ truncate('test_parent')
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (0, 0))
+ for n in range(3):
+ query("insert into test_parent (n) values (1)")
+ query("insert into test_child (n, m) values (2, 3)")
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (6, 3))
+ truncate('test_parent*')
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (0, 0))
+ for n in range(3):
+ query("insert into test_parent (n) values (1)")
+ query("insert into test_child (n, m) values (2, 3)")
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (6, 3))
+ truncate('test_parent', only=True)
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (3, 3))
+ truncate('test_parent', only=False)
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (0, 0))
+ self.assertRaises(ValueError, truncate, 'test_parent*', only=True)
+ truncate('test_parent*', only=False)
+ query("drop table if exists test_parent_2")
+ query("create table test_parent_2 (n smallint)")
+ query("drop table if exists test_child_2")
+ query("create table test_child_2 ("
+ " m smallint) inherits (test_parent_2)")
+ for n in range(3):
+ query("insert into test_parent (n) values (1)")
+ query("insert into test_child (n, m) values (2, 3)")
+ query("insert into test_parent_2 (n) values (1)")
+ query("insert into test_child_2 (n, m) values (2, 3)")
+ q = ("select (select count(*) from test_parent),"
+ " (select count(*) from test_child),"
+ " (select count(*) from test_parent_2),"
+ " (select count(*) from test_child_2)")
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (6, 3, 6, 3))
+ truncate(['test_parent', 'test_parent_2'], only=[False, True])
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (0, 0, 3, 3))
+ truncate(['test_parent', 'test_parent_2'], only=False)
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (0, 0, 0, 0))
+ self.assertRaises(ValueError, truncate,
+ ['test_parent*', 'test_child'], only=[True, False])
+ truncate(['test_parent*', 'test_child'], only=[False, True])
+ query("drop table test_child_2")
+ query("drop table test_parent_2")
+ query("drop table test_child")
+ query("drop table test_parent")
+
+ def testTruncateQuoted(self):
+ truncate = self.db.truncate
+ query = self.db.query
+ table = "test table for truncate()"
+ query('drop table if exists "%s"' % table)
+ query('create table "%s" (n smallint)' % table)
+ for i in range(3):
+ query('insert into "%s" values (1)' % table)
+ q = 'select count(*) from "%s"' % table
+ r = query(q).getresult()[0][0]
+ self.assertEqual(r, 3)
+ truncate(table)
+ r = query(q).getresult()[0][0]
+ self.assertEqual(r, 0)
+ for i in range(3):
+ query('insert into "%s" values (1)' % table)
+ r = query(q).getresult()[0][0]
+ self.assertEqual(r, 3)
+ truncate('public."%s"' % table)
+ r = query(q).getresult()[0][0]
+ self.assertEqual(r, 0)
+ query('drop table "%s"' % table)
+
def testTransaction(self):
query = self.db.query
query("drop table if exists test_table")
Modified: trunk/docs/contents/changelog.rst
==============================================================================
--- trunk/docs/contents/changelog.rst Fri Jan 15 06:06:51 2016 (r747)
+++ trunk/docs/contents/changelog.rst Fri Jan 15 09:25:31 2016 (r748)
@@ -48,6 +48,8 @@
- Force build to compile with no errors.
- New methods get_parameters() and set_parameters() in the classic interface
which can be used to get or set run-time parameters.
+- New method truncate() in the classic interface that can be used to quickly
+ empty a table or a set of tables.
- Fix decimal point handling.
- Add option to return boolean values as bool objects.
- Add option to return money values as string.
Modified: trunk/docs/contents/pg/db_wrapper.rst
==============================================================================
--- trunk/docs/contents/pg/db_wrapper.rst Fri Jan 15 06:06:51 2016
(r747)
+++ trunk/docs/contents/pg/db_wrapper.rst Fri Jan 15 09:25:31 2016
(r748)
@@ -476,6 +476,35 @@
The return value is the number of deleted rows (i.e. 0 if the row did not
exist and 1 if the row was deleted).
+truncate -- Quickly empty database tables
+-----------------------------------------
+
+.. method:: DB.truncate(table, [restart], [cascade], [only])
+
+ Empty a table or set of tables
+
+ :param table: the name of the table(s)
+ :type table: str, list or set
+ :param bool restart: whether table sequences should be restarted
+ :param bool cascade: whether tables referencing them via foreign keys should also be truncated
+ :param only: whether only parent tables should be truncated
+ :type only: bool or list
+
+This method quickly removes all rows from the given table or set
+of tables. It has the same effect as an unqualified DELETE on each
+table, but since it does not actually scan the tables it is faster.
+Furthermore, it reclaims disk space immediately, rather than requiring
+a subsequent VACUUM operation. This is most useful on large tables.
+
+If *restart* is set to `True`, sequences owned by columns of the truncated
+table(s) are automatically restarted. If *cascade* is set to `True`, it
+also truncates all tables that have foreign-key references to any of
+the named tables. If the parameter *only* is not set to `True`, all the
+descendant tables (if any) will also be truncated. Optionally, a ``*``
+can be specified after the table name to explicitly indicate that
+descendant tables are included. If the parameter *table* is a list,
+the parameter *only* can also be a list of corresponding boolean values.
+
escape_literal -- escape a literal string for use within SQL
------------------------------------------------------------
Modified: trunk/pg.py
==============================================================================
--- trunk/pg.py Fri Jan 15 06:06:51 2016 (r747)
+++ trunk/pg.py Fri Jan 15 09:25:31 2016 (r748)
@@ -577,8 +577,8 @@
raise TypeError('Invalid parameter')
if param == 'all':
if value is not None:
- raise ValueError(
- "A value must ot be set when parameter is 'all'")
+ raise ValueError('A value must ot be specified'
+ " when parameter is 'all'")
params = {'all': None}
break
params[param] = value
@@ -1087,6 +1087,62 @@
res = self.db.query(q, params)
return int(res)
+ def truncate(self, table, restart=False, cascade=False, only=False):
+ """Empty a table or set of tables.
+
+ This method quickly removes all rows from the given table or set
+ of tables. It has the same effect as an unqualified DELETE on each
+ table, but since it does not actually scan the tables it is faster.
+ Furthermore, it reclaims disk space immediately, rather than requiring
+ a subsequent VACUUM operation. This is most useful on large tables.
+
+ If restart is set to True, sequences owned by columns of the truncated
+ table(s) are automatically restarted. If cascade is set to True, it
+ also truncates all tables that have foreign-key references to any of
+ the named tables. If the parameter only is not set to True, all the
+ descendant tables (if any) will also be truncated. Optionally, a '*'
+ can be specified after the table name to explicitly indicate that
+ descendant tables are included.
+ """
+ if isinstance(table, basestring):
+ only = {table: only}
+ table = [table]
+ elif isinstance(table, (list, tuple)):
+ if isinstance(only, (list, tuple)):
+ only = dict(zip(table, only))
+ else:
+ only = dict.fromkeys(table, only)
+ elif isinstance(table, (set, frozenset)):
+ only = dict.fromkeys(table, only)
+ else:
+ raise TypeError('The table must be a string, list or set')
+ if not (restart is None or isinstance(restart, (bool, int))):
+ raise TypeError('Invalid type for the restart option')
+ if not (cascade is None or isinstance(cascade, (bool, int))):
+ raise TypeError('Invalid type for the cascade option')
+ tables = []
+ for t in table:
+ u = only.get(t)
+ if not (u is None or isinstance(u, (bool, int))):
+ raise TypeError('Invalid type for the only option')
+ if t.endswith('*'):
+ if u:
+ raise ValueError(
+ 'Contradictory table name and only options')
+ t = t[:-1].rstrip()
+ t = self._escape_qualified_name(t)
+ if u:
+ t = 'ONLY %s' % t
+ tables.append(t)
+ q = ['TRUNCATE', ', '.join(tables)]
+ if restart:
+ q.append('RESTART IDENTITY')
+ if cascade:
+ q.append('CASCADE')
+ q = ' '.join(q)
+ self._do_debug(q)
+ return self.query(q)
+
def notification_handler(self, event, callback, arg_dict=None, timeout=None):
    """Get notification handler that will run the given callback.

    If no arg_dict is given, a new empty dictionary is created for each
    call, so that handlers never share one default dictionary object.
    """
    if arg_dict is None:
        arg_dict = {}
    return NotificationHandler(self.db, event, callback, arg_dict, timeout)
Modified: trunk/tests/test_classic_dbwrapper.py
==============================================================================
--- trunk/tests/test_classic_dbwrapper.py Fri Jan 15 06:06:51 2016
(r747)
+++ trunk/tests/test_classic_dbwrapper.py Fri Jan 15 09:25:31 2016
(r748)
@@ -107,7 +107,7 @@
'savepoint', 'server_version',
'set_notice_receiver', 'set_parameter',
'source', 'start', 'status',
- 'transaction',
+ 'transaction', 'truncate',
'unescape_bytea', 'update', 'upsert',
'use_regtypes', 'user',
]
@@ -1514,6 +1514,208 @@
r = query('select count(*) from "%s"' % table).getresult()
self.assertEqual(r[0][0], 0)
+ def testTruncate(self):
+ truncate = self.db.truncate
+ self.assertRaises(TypeError, truncate, None)
+ self.assertRaises(TypeError, truncate, 42)
+ self.assertRaises(TypeError, truncate, dict(test_table=None))
+ query = self.db.query
+ query("drop table if exists test_table")
+ self.addCleanup(query, "drop table test_table")
+ query("create table test_table (n smallint)")
+ for i in range(3):
+ query("insert into test_table values (1)")
+ q = "select count(*) from test_table"
+ r = query(q).getresult()[0][0]
+ self.assertEqual(r, 3)
+ truncate('test_table')
+ r = query(q).getresult()[0][0]
+ self.assertEqual(r, 0)
+ for i in range(3):
+ query("insert into test_table values (1)")
+ r = query(q).getresult()[0][0]
+ self.assertEqual(r, 3)
+ truncate('public.test_table')
+ r = query(q).getresult()[0][0]
+ self.assertEqual(r, 0)
+ query("drop table if exists test_table_2")
+ self.addCleanup(query, "drop table test_table_2")
+ query('create table test_table_2 (n smallint)')
+ for t in (list, tuple, set):
+ for i in range(3):
+ query("insert into test_table values (1)")
+ query("insert into test_table_2 values (2)")
+ q = ("select (select count(*) from test_table),"
+ " (select count(*) from test_table_2)")
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (3, 3))
+ truncate(t(['test_table', 'test_table_2']))
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (0, 0))
+
+ def testTruncateRestart(self):
+ truncate = self.db.truncate
+ self.assertRaises(TypeError, truncate, 'test_table', restart='invalid')
+ query = self.db.query
+ query("drop table if exists test_table")
+ self.addCleanup(query, "drop table test_table")
+ query("create table test_table (n serial, t text)")
+ for n in range(3):
+ query("insert into test_table (t) values ('test')")
+ q = "select count(n), min(n), max(n) from test_table"
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (3, 1, 3))
+ truncate('test_table')
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (0, None, None))
+ for n in range(3):
+ query("insert into test_table (t) values ('test')")
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (3, 4, 6))
+ truncate('test_table', restart=True)
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (0, None, None))
+ for n in range(3):
+ query("insert into test_table (t) values ('test')")
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (3, 1, 3))
+
+ def testTruncateCascade(self):
+ truncate = self.db.truncate
+ self.assertRaises(TypeError, truncate, 'test_table', cascade='invalid')
+ query = self.db.query
+ query("drop table if exists test_child")
+ query("drop table if exists test_parent")
+ self.addCleanup(query, "drop table test_parent")
+ query("create table test_parent (n smallint primary key)")
+ self.addCleanup(query, "drop table test_child")
+ query("create table test_child ("
+ " n smallint primary key references test_parent (n))")
+ for n in range(3):
+ query("insert into test_parent (n) values (%d)" % n)
+ query("insert into test_child (n) values (%d)" % n)
+ q = ("select (select count(*) from test_parent),"
+ " (select count(*) from test_child)")
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (3, 3))
+ self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
+ truncate(['test_parent', 'test_child'])
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (0, 0))
+ for n in range(3):
+ query("insert into test_parent (n) values (%d)" % n)
+ query("insert into test_child (n) values (%d)" % n)
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (3, 3))
+ truncate('test_parent', cascade=True)
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (0, 0))
+ for n in range(3):
+ query("insert into test_parent (n) values (%d)" % n)
+ query("insert into test_child (n) values (%d)" % n)
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (3, 3))
+ truncate('test_child')
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (3, 0))
+ self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
+ truncate('test_parent', cascade=True)
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (0, 0))
+
+ def testTruncateOnly(self):
+ truncate = self.db.truncate
+ self.assertRaises(TypeError, truncate, 'test_table', only='invalid')
+ query = self.db.query
+ query("drop table if exists test_child")
+ query("drop table if exists test_parent")
+ self.addCleanup(query, "drop table test_parent")
+ query("create table test_parent (n smallint)")
+ self.addCleanup(query, "drop table test_child")
+ query("create table test_child ("
+ " m smallint) inherits (test_parent)")
+ for n in range(3):
+ query("insert into test_parent (n) values (1)")
+ query("insert into test_child (n, m) values (2, 3)")
+ q = ("select (select count(*) from test_parent),"
+ " (select count(*) from test_child)")
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (6, 3))
+ truncate('test_parent')
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (0, 0))
+ for n in range(3):
+ query("insert into test_parent (n) values (1)")
+ query("insert into test_child (n, m) values (2, 3)")
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (6, 3))
+ truncate('test_parent*')
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (0, 0))
+ for n in range(3):
+ query("insert into test_parent (n) values (1)")
+ query("insert into test_child (n, m) values (2, 3)")
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (6, 3))
+ truncate('test_parent', only=True)
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (3, 3))
+ truncate('test_parent', only=False)
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (0, 0))
+ self.assertRaises(ValueError, truncate, 'test_parent*', only=True)
+ truncate('test_parent*', only=False)
+ query("drop table if exists test_parent_2")
+ self.addCleanup(query, "drop table test_parent_2")
+ query("create table test_parent_2 (n smallint)")
+ query("drop table if exists test_child_2")
+ self.addCleanup(query, "drop table test_child_2")
+ query("create table test_child_2 ("
+ " m smallint) inherits (test_parent_2)")
+ for n in range(3):
+ query("insert into test_parent (n) values (1)")
+ query("insert into test_child (n, m) values (2, 3)")
+ query("insert into test_parent_2 (n) values (1)")
+ query("insert into test_child_2 (n, m) values (2, 3)")
+ q = ("select (select count(*) from test_parent),"
+ " (select count(*) from test_child),"
+ " (select count(*) from test_parent_2),"
+ " (select count(*) from test_child_2)")
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (6, 3, 6, 3))
+ truncate(['test_parent', 'test_parent_2'], only=[False, True])
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (0, 0, 3, 3))
+ truncate(['test_parent', 'test_parent_2'], only=False)
+ r = query(q).getresult()[0]
+ self.assertEqual(r, (0, 0, 0, 0))
+ self.assertRaises(ValueError, truncate,
+ ['test_parent*', 'test_child'], only=[True, False])
+ truncate(['test_parent*', 'test_child'], only=[False, True])
+
+ def testTruncateQuoted(self):
+ truncate = self.db.truncate
+ query = self.db.query
+ table = "test table for truncate()"
+ query('drop table if exists "%s"' % table)
+ self.addCleanup(query, 'drop table "%s"' % table)
+ query('create table "%s" (n smallint)' % table)
+ for i in range(3):
+ query('insert into "%s" values (1)' % table)
+ q = 'select count(*) from "%s"' % table
+ r = query(q).getresult()[0][0]
+ self.assertEqual(r, 3)
+ truncate(table)
+ r = query(q).getresult()[0][0]
+ self.assertEqual(r, 0)
+ for i in range(3):
+ query('insert into "%s" values (1)' % table)
+ r = query(q).getresult()[0][0]
+ self.assertEqual(r, 3)
+ truncate('public."%s"' % table)
+ r = query(q).getresult()[0][0]
+ self.assertEqual(r, 0)
+
def testTransaction(self):
query = self.db.query
query("drop table if exists test_table")
_______________________________________________
PyGreSQL mailing list
[email protected]
https://mail.vex.net/mailman/listinfo.cgi/pygresql