Author: cito
Date: Mon Jan  4 17:53:44 2016
New Revision: 694

Log:
Add tests for copy_from() and copy_to()

Added:
   trunk/module/tests/test_dbapi20_copy.py   (contents, props changed)

Added: trunk/module/tests/test_dbapi20_copy.py
==============================================================================
--- /dev/null   00:00:00 1970   (empty, because file is newly added)
+++ trunk/module/tests/test_dbapi20_copy.py     Mon Jan  4 17:53:44 2016        
(r694)
@@ -0,0 +1,531 @@
+#! /usr/bin/python
+# -*- coding: utf-8 -*-
+
+"""Test the modern PyGreSQL interface.
+
+Sub-tests for the copy methods.
+
+Contributed by Christoph Zwerschke.
+
+These tests need a database to test against.
+
+"""
+
+try:
+    import unittest2 as unittest  # for Python < 2.7
+except ImportError:
+    import unittest
+
+from collections import Iterable
+
+import pgdb  # the module under test
+
+# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
+# get our information from that.  Otherwise we use the defaults.
+# The current user must have create schema privilege on the database.
+dbname = 'unittest'
+dbhost = None
+dbport = 5432
+
+try:
+    from .LOCAL_PyGreSQL import *
+except (ImportError, ValueError):
+    try:
+        from LOCAL_PyGreSQL import *
+    except ImportError:
+        pass
+
+try:
+    unicode
+except NameError:  # Python >= 3.0
+    unicode = str
+
+
+class InputStream:
+
+    def __init__(self, data):
+        if isinstance(data, unicode):
+            data = data.encode('utf-8')
+        self.data = data or b''
+        self.sizes = []
+
+    def __str__(self):
+        data = self.data
+        if str is unicode:
+            data = data.decode('utf-8')
+        return data
+
+    def __len__(self):
+        return len(self.data)
+
+    def read(self, size=None):
+        if size is None:
+            output, data = self.data, b''
+        else:
+            output, data = self.data[:size], self.data[size:]
+        self.data = data
+        self.sizes.append(size)
+        return output
+
+
+class OutputStream:
+
+    def __init__(self):
+        self.data = b''
+        self.sizes = []
+
+    def __str__(self):
+        data = self.data
+        if str is unicode:
+            data = data.decode('utf-8')
+        return data
+
+    def __len__(self):
+        return len(self.data)
+
+    def write(self, data):
+        if isinstance(data, unicode):
+            data = data.encode('utf-8')
+        self.data += data
+        self.sizes.append(len(data))
+
+
+class TestStreams(unittest.TestCase):
+
+    def test_input(self):
+        stream = InputStream('Hello, Wörld!')
+        self.assertIsInstance(stream.data, bytes)
+        self.assertEqual(stream.data, b'Hello, W\xc3\xb6rld!')
+        self.assertIsInstance(str(stream), str)
+        self.assertEqual(str(stream), 'Hello, Wörld!')
+        self.assertEqual(len(stream), 14)
+        self.assertEqual(stream.read(3), b'Hel')
+        self.assertEqual(stream.read(2), b'lo')
+        self.assertEqual(stream.read(1), b',')
+        self.assertEqual(stream.read(1), b' ')
+        self.assertEqual(stream.read(), b'W\xc3\xb6rld!')
+        self.assertEqual(stream.read(), b'')
+        self.assertEqual(len(stream), 0)
+        self.assertEqual(stream.sizes, [3, 2, 1, 1, None, None])
+
+    def test_output(self):
+        stream = OutputStream()
+        self.assertEqual(len(stream), 0)
+        for chunk in 'Hel', 'lo', ',', ' ', 'Wörld!':
+            stream.write(chunk)
+        self.assertIsInstance(stream.data, bytes)
+        self.assertEqual(stream.data, b'Hello, W\xc3\xb6rld!')
+        self.assertIsInstance(str(stream), str)
+        self.assertEqual(str(stream), 'Hello, Wörld!')
+        self.assertEqual(len(stream), 14)
+        self.assertEqual(stream.sizes, [3, 2, 1, 1, 7])
+
+
+class TestCopy(unittest.TestCase):
+
+    @staticmethod
+    def connect():
+        return pgdb.connect(database=dbname,
+            host='%s:%d' % (dbhost or '', dbport or -1))
+
+    @classmethod
+    def setUpClass(cls):
+        con = cls.connect()
+        cur = con.cursor()
+        cur.execute("set client_min_messages=warning")
+        cur.execute("drop table if exists copytest cascade")
+        cur.execute("create table copytest ("
+            "id smallint primary key, name varchar(64))")
+        cur.close()
+        con.commit()
+        con.close()
+
+    @classmethod
+    def tearDownClass(cls):
+        con = cls.connect()
+        cur = con.cursor()
+        cur.execute("set client_min_messages=warning")
+        cur.execute("drop table if exists copytest cascade")
+        con.commit()
+        con.close()
+
+    def setUp(self):
+        self.con = self.connect()
+        self.cursor = self.con.cursor()
+        self.cursor.execute("set client_encoding=utf8")
+
+    def tearDown(self):
+        try:
+            self.cursor.close()
+        except Exception:
+            pass
+        try:
+            self.con.rollback()
+        except Exception:
+            pass
+        try:
+            self.con.close()
+        except Exception:
+            pass
+
+    data = [(1935, 'Luciano Pavarotti'),
+            (1941, 'Plácido Domingo'),
+            (1946, 'José Carreras')]
+
+    @property
+    def data_text(self):
+        return ''.join('%d\t%s\n' % row for row in self.data)
+
+    @property
+    def data_csv(self):
+        return ''.join('%d,%s\n' % row for row in self.data)
+
+    def truncate_table(self):
+        self.cursor.execute("truncate table copytest")
+
+    @property
+    def table_data(self):
+        self.cursor.execute("select * from copytest")
+        return self.cursor.fetchall()
+
+    def check_table(self):
+        self.assertEqual(self.table_data, self.data)
+
+
+class TestCopyFrom(TestCopy):
+    """Test the copy_from method."""
+
+    def tearDown(self):
+        super(TestCopyFrom, self).tearDown()
+        self.setUp()
+        self.truncate_table()
+        super(TestCopyFrom, self).tearDown()
+
+    def copy_from(self, stream, **options):
+        return self.cursor.copy_from(stream, 'copytest', **options)
+
+    @property
+    def data_file(self):
+        return InputStream(self.data_text)
+
+    def test_bad_params(self):
+        call = self.cursor.copy_from
+        call('0\t', 'copytest'), self.cursor
+        call('1\t', 'copytest',
+             format='text', sep='\t', null='', columns=['id', 'name'])
+        self.assertRaises(TypeError, call)
+        self.assertRaises(TypeError, call, '0\t')
+        self.assertRaises(TypeError, call, '0\t', 42)
+        self.assertRaises(TypeError, call, '0\t', ['copytest'])
+        self.assertRaises(TypeError, call, '0\t', 'copytest', format=42)
+        self.assertRaises(ValueError, call, '0\t', 'copytest', format='bad')
+        self.assertRaises(TypeError, call, '0\t', 'copytest', sep=42)
+        self.assertRaises(ValueError, call, '0\t', 'copytest', sep='bad')
+        self.assertRaises(TypeError, call, '0\t', 'copytest', null=42)
+        self.assertRaises(ValueError, call, '0\t', 'copytest', size='bad')
+        self.assertRaises(TypeError, call, '0\t', 'copytest', columns=42)
+
+    def test_input_string(self):
+        ret = self.copy_from('42\tHello, world!')
+        self.assertIs(ret, self.cursor)
+        self.assertEqual(self.table_data, [(42, 'Hello, world!')])
+
+    def test_input_string_with_newline(self):
+        self.copy_from('42\tHello, world!\n')
+        self.assertEqual(self.table_data, [(42, 'Hello, world!')])
+
+    def test_input_string_multiple_rows(self):
+        ret = self.copy_from(self.data_text)
+        self.assertIs(ret, self.cursor)
+        self.check_table()
+
+    if str is unicode:
+
+        def test_input_bytes(self):
+            self.copy_from(b'42\tHello, world!')
+            self.assertEqual(self.table_data, [(42, 'Hello, world!')])
+            self.truncate_table()
+            self.copy_from(self.data_text.encode('utf-8'))
+            self.check_table()
+
+    if str is not unicode:
+
+        def test_input_unicode(self):
+            self.copy_from(u'43\tWürstel, Käse!')
+            self.assertEqual(self.table_data, [(43, 'Würstel, Käse!')])
+            self.truncate_table()
+            self.copy_from(self.data_text.decode('utf-8'))
+            self.check_table()
+
+    def test_input_iterable(self):
+        self.copy_from(self.data_text.splitlines())
+        self.check_table()
+
+    def test_input_iterable_with_newlines(self):
+        self.copy_from('%s\n' % row for row in self.data_text.splitlines())
+        self.check_table()
+
+    def test_sep(self):
+        stream = ('%d-%s' % row for row in self.data)
+        self.copy_from(stream, sep='-')
+        self.check_table()
+
+    def test_null(self):
+        self.copy_from('0\t\\N')
+        self.assertEqual(self.table_data, [(0, None)])
+        self.assertIsNone(self.table_data[0][1])
+        self.truncate_table()
+        self.copy_from('1\tNix')
+        self.assertEqual(self.table_data, [(1, 'Nix')])
+        self.assertIsNotNone(self.table_data[0][1])
+        self.truncate_table()
+        self.copy_from('2\tNix', null='Nix')
+        self.assertEqual(self.table_data, [(2, None)])
+        self.assertIsNone(self.table_data[0][1])
+        self.truncate_table()
+        self.copy_from('3\t')
+        self.assertEqual(self.table_data, [(3, '')])
+        self.assertIsNotNone(self.table_data[0][1])
+        self.truncate_table()
+        self.copy_from('4\t', null='')
+        self.assertEqual(self.table_data, [(4, None)])
+        self.assertIsNone(self.table_data[0][1])
+
+    def test_columns(self):
+        self.copy_from('1', columns='id')
+        self.copy_from('2', columns=['id'])
+        self.copy_from('3\tThree')
+        self.copy_from('4\tFour', columns='id, name')
+        self.copy_from('5\tFive', columns=['id', 'name'])
+        self.assertEqual(self.table_data, [
+            (1, None), (2, None), (3, 'Three'), (4, 'Four'), (5, 'Five')])
+        self.assertRaises(pgdb.ProgrammingError, self.copy_from,
+            '6\t42', columns=['id', 'age'])
+
+    def test_csv(self):
+        self.copy_from(self.data_csv, format='csv')
+        self.check_table()
+
+    def test_csv_with_sep(self):
+        stream = ('%d;"%s"\n' % row for row in self.data)
+        self.copy_from(stream, format='csv', sep=';')
+        self.check_table()
+
+    def test_binary(self):
+        self.assertRaises(IOError, self.copy_from,
+            b'NOPGCOPY\n', format='binary')
+
+    def test_binary_with_sep(self):
+        self.assertRaises(ValueError, self.copy_from,
+            '', format='binary', sep='\t')
+
+    def test_binary_with_unicode(self):
+        self.assertRaises(ValueError, self.copy_from, u'', format='binary')
+
+    def test_query(self):
+        self.assertRaises(ValueError, self.cursor.copy_from, '', "select null")
+
+    def test_file(self):
+        stream = self.data_file
+        ret = self.copy_from(stream)
+        self.assertIs(ret, self.cursor)
+        self.check_table()
+        self.assertEqual(len(stream), 0)
+        self.assertEqual(stream.sizes, [8192])
+
+    def test_size_positive(self):
+        stream = self.data_file
+        size = 7
+        num_chunks = (len(stream) + size - 1) // size
+        self.copy_from(stream, size=size)
+        self.check_table()
+        self.assertEqual(len(stream), 0)
+        self.assertEqual(stream.sizes, [size] * num_chunks)
+
+    def test_size_negative(self):
+        stream = self.data_file
+        self.copy_from(stream, size=-1)
+        self.check_table()
+        self.assertEqual(len(stream), 0)
+        self.assertEqual(stream.sizes, [None])
+
+
+class TestCopyTo(TestCopy):
+    """Test the copy_to method."""
+
+    @classmethod
+    def setUpClass(cls):
+        super(TestCopyTo, cls).setUpClass()
+        con = cls.connect()
+        cur = con.cursor()
+        cur.execute("insert into copytest values (%d, %s)", cls.data)
+        cur.close()
+        con.commit()
+        con.close()
+
+    def copy_to(self, stream=None, **options):
+        return self.cursor.copy_to(stream, 'copytest', **options)
+
+    @property
+    def data_file(self):
+        return OutputStream()
+
+    def test_bad_params(self):
+        call = self.cursor.copy_to
+        call(None, 'copytest')
+        call(None, 'copytest',
+             format='text', sep='\t', null='', columns=['id', 'name'])
+        self.assertRaises(TypeError, call)
+        self.assertRaises(TypeError, call, None)
+        self.assertRaises(TypeError, call, None, 42)
+        self.assertRaises(TypeError, call, None, ['copytest'])
+        self.assertRaises(TypeError, call, 'bad', 'copytest')
+        self.assertRaises(TypeError, call, None, 'copytest', format=42)
+        self.assertRaises(ValueError, call, None, 'copytest', format='bad')
+        self.assertRaises(TypeError, call, None, 'copytest', sep=42)
+        self.assertRaises(ValueError, call, None, 'copytest', sep='bad')
+        self.assertRaises(TypeError, call, None, 'copytest', null=42)
+        self.assertRaises(TypeError, call, None, 'copytest', decode='bad')
+        self.assertRaises(TypeError, call, None, 'copytest', columns=42)
+
+    def test_generator(self):
+        ret = self.copy_to()
+        self.assertIsInstance(ret, Iterable)
+        rows = list(ret)
+        self.assertEqual(len(rows), 3)
+        rows = ''.join(rows)
+        self.assertIsInstance(rows, str)
+        self.assertEqual(rows, self.data_text)
+
+    if str is unicode:
+
+        def test_generator_bytes(self):
+            ret = self.copy_to(decode=False)
+            self.assertIsInstance(ret, Iterable)
+            rows = list(ret)
+            self.assertEqual(len(rows), 3)
+            rows = b''.join(rows)
+            self.assertIsInstance(rows, bytes)
+            self.assertEqual(rows, self.data_text.encode('utf-8'))
+
+    if str is not unicode:
+
+        def test_generator_unicode(self):
+            ret = self.copy_to(decode=True)
+            self.assertIsInstance(ret, Iterable)
+            rows = list(ret)
+            self.assertEqual(len(rows), 3)
+            rows = ''.join(rows)
+            self.assertIsInstance(rows, unicode)
+            self.assertEqual(rows, self.data_text.decode('utf-8'))
+
+    def test_decode(self):
+        ret_raw = b''.join(self.copy_to(decode=False))
+        ret_decoded = ''.join(self.copy_to(decode=True))
+        self.assertIsInstance(ret_raw, bytes)
+        self.assertIsInstance(ret_decoded, unicode)
+        self.assertEqual(ret_decoded, ret_raw.decode('utf-8'))
+
+    def test_sep(self):
+        ret = list(self.copy_to(sep='-'))
+        self.assertEqual(ret, ['%d-%s\n' % row for row in self.data])
+
+    def test_null(self):
+        data = ['%d\t%s\n' % row for row in self.data]
+        self.cursor.execute('insert into copytest values(4, null)')
+        try:
+            ret = list(self.copy_to())
+            self.assertEqual(ret, data + ['4\t\\N\n'])
+            ret = list(self.copy_to(null='Nix'))
+            self.assertEqual(ret, data + ['4\tNix\n'])
+            ret = list(self.copy_to(null=''))
+            self.assertEqual(ret, data + ['4\t\n'])
+        finally:
+            self.cursor.execute('delete from copytest where id=4')
+
+    def test_columns(self):
+        data_id = ''.join('%d\n' % row[0] for row in self.data)
+        data_name = ''.join('%s\n' % row[1] for row in self.data)
+        ret = ''.join(self.copy_to(columns='id'))
+        self.assertEqual(ret, data_id)
+        ret = ''.join(self.copy_to(columns=['id']))
+        self.assertEqual(ret, data_id)
+        ret = ''.join(self.copy_to(columns='name'))
+        self.assertEqual(ret, data_name)
+        ret = ''.join(self.copy_to(columns=['name']))
+        self.assertEqual(ret, data_name)
+        ret = ''.join(self.copy_to(columns='id, name'))
+        self.assertEqual(ret, self.data_text)
+        ret = ''.join(self.copy_to(columns=['id', 'name']))
+        self.assertEqual(ret, self.data_text)
+        self.assertRaises(pgdb.ProgrammingError, self.copy_to,
+            columns=['id', 'age'])
+
+    def test_csv(self):
+        ret = self.copy_to(format='csv')
+        self.assertIsInstance(ret, Iterable)
+        rows = list(ret)
+        self.assertEqual(len(rows), 3)
+        rows = ''.join(rows)
+        self.assertIsInstance(rows, str)
+        self.assertEqual(rows, self.data_csv)
+
+    def test_csv_with_sep(self):
+        rows = ''.join(self.copy_to(format='csv', sep=';'))
+        self.assertEqual(rows, self.data_csv.replace(',', ';'))
+
+    def test_binary(self):
+        ret = self.copy_to(format='binary')
+        self.assertIsInstance(ret, Iterable)
+        for row in ret:
+            self.assertTrue(row.startswith(b'PGCOPY\n\377\r\n\0'))
+            break
+
+    def test_binary_with_sep(self):
+        self.assertRaises(ValueError, self.copy_to, format='binary', sep='\t')
+
+    def test_binary_with_unicode(self):
+        self.assertRaises(ValueError, self.copy_to,
+            format='binary', decode=True)
+
+    def test_query(self):
+        ret = self.cursor.copy_to(None,
+            "select name||'!' from copytest where id=1941")
+        self.assertIsInstance(ret, Iterable)
+        rows = list(ret)
+        self.assertEqual(len(rows), 1)
+        self.assertIsInstance(rows[0], str)
+        self.assertEqual(rows[0], '%s!\n' % self.data[1][1])
+
+    def test_file(self):
+        stream = self.data_file
+        ret = self.copy_to(stream)
+        self.assertIs(ret, self.cursor)
+        self.assertEqual(str(stream), self.data_text)
+        data = self.data_text
+        if str is unicode:
+            data = data.encode('utf-8')
+        sizes = [len(row) + 1 for row in data.splitlines()]
+        self.assertEqual(stream.sizes, sizes)
+
+
+class TestBinary(TestCopy):
+    """Test the copy_from and copy_to methods with binary data."""
+
+    def test_round_trip(self):
+        # fill table from textual data
+        self.cursor.copy_from(self.data_text, 'copytest', format='text')
+        self.check_table()
+        # get data back in binary format
+        ret = self.cursor.copy_to(None, 'copytest', format='binary')
+        self.assertIsInstance(ret, Iterable)
+        data_binary = b''.join(ret)
+        self.assertTrue(data_binary.startswith(b'PGCOPY\n\377\r\n\0'))
+        self.truncate_table()
+        # fill table from binary data
+        self.cursor.copy_from(data_binary, 'copytest', format='binary')
+        self.check_table()
+
+
+if __name__ == '__main__':
+    unittest.main()
_______________________________________________
PyGreSQL mailing list
[email protected]
https://mail.vex.net/mailman/listinfo.cgi/pygresql

Reply via email to