Repository: incubator-airflow
Updated Branches:
  refs/heads/master c5a5ae947 -> c27098b8d


[AIRFLOW-59] Implement bulk_dump and bulk_load for the Postgres hook

This PR implements the bulk_dump and bulk_load methods that
PostgresHook inherits from DbApiHook, matching the existing
MySqlHook implementations.

Closes #3456 from sekikn/AIRFLOW-59


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/c27098b8
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/c27098b8
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/c27098b8

Branch: refs/heads/master
Commit: c27098b8d31fee7177f37108a6c2fb7c7ad37170
Parents: c5a5ae9
Author: Kengo Seki <[email protected]>
Authored: Sat Jun 9 22:14:51 2018 +0200
Committer: Fokko Driesprong <[email protected]>
Committed: Sat Jun 9 22:14:51 2018 +0200

----------------------------------------------------------------------
 airflow/hooks/postgres_hook.py    | 12 ++++++++++
 tests/hooks/test_postgres_hook.py | 43 ++++++++++++++++++++++++++++++++++
 2 files changed, 55 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/c27098b8/airflow/hooks/postgres_hook.py
----------------------------------------------------------------------
diff --git a/airflow/hooks/postgres_hook.py b/airflow/hooks/postgres_hook.py
index 0395e70..dd250bf 100644
--- a/airflow/hooks/postgres_hook.py
+++ b/airflow/hooks/postgres_hook.py
@@ -82,6 +82,37 @@ class PostgresHook(DbApiHook):
                     f.truncate(f.tell())
                     conn.commit()
 
+    def bulk_load(self, table, tmp_file):
+        """
+        Loads a tab-delimited file into a database table
+        using ``COPY <table> FROM STDIN``.
+
+        :param table: name of the table to load into. It is interpolated
+            into the COPY statement as-is (not quoted), so it must come
+            from a trusted source.
+        :type table: str
+        :param tmp_file: path of the file to read. ``copy_expert`` opens it
+            in ``r+`` mode, so it must already exist.
+        :type tmp_file: str
+        """
+        self.copy_expert("COPY {table} FROM STDIN".format(table=table), tmp_file)
+
+    def bulk_dump(self, table, tmp_file):
+        """
+        Dumps a database table into a tab-delimited file
+        using ``COPY <table> TO STDOUT``.
+
+        :param table: name of the table to dump. It is interpolated
+            into the COPY statement as-is (not quoted), so it must come
+            from a trusted source.
+        :type table: str
+        :param tmp_file: path of the file to write. ``copy_expert`` opens it
+            in ``r+`` mode, so it must already exist; any existing content
+            past the dumped data is truncated.
+        :type tmp_file: str
+        """
+        self.copy_expert("COPY {table} TO STDOUT".format(table=table), tmp_file)
+
     @staticmethod
     def _serialize_cell(cell, conn):
         """

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/c27098b8/tests/hooks/test_postgres_hook.py
----------------------------------------------------------------------
diff --git a/tests/hooks/test_postgres_hook.py b/tests/hooks/test_postgres_hook.py
index 2828f8b..0520239 100644
--- a/tests/hooks/test_postgres_hook.py
+++ b/tests/hooks/test_postgres_hook.py
@@ -21,6 +21,8 @@
 import mock
 import unittest
 
+from tempfile import NamedTemporaryFile
+
 from airflow.hooks.postgres_hook import PostgresHook
 
 
@@ -56,3 +58,64 @@ class TestPostgresHook(unittest.TestCase):
             self.conn.commit.assert_called_once()
             self.cur.copy_expert.assert_called_once_with(statement, m.return_value)
             self.assertEqual(m.call_args[0], (filename, "r+"))
+
+    def test_bulk_load(self):
+        """bulk_load() inserts one row per line of the input file."""
+        hook = PostgresHook()
+        table = "t"
+        input_data = ["foo", "bar", "baz"]
+
+        # A psycopg2 connection used as a context manager only ends the
+        # transaction; it does not close the connection, so close it here.
+        conn = hook.get_conn()
+        try:
+            with conn.cursor() as cur:
+                cur.execute("DROP TABLE IF EXISTS {}".format(table))
+                cur.execute("CREATE TABLE {} (c VARCHAR)".format(table))
+                conn.commit()
+
+                with NamedTemporaryFile() as f:
+                    f.write("\n".join(input_data).encode("utf-8"))
+                    f.flush()
+                    hook.bulk_load(table, f.name)
+
+                cur.execute("SELECT * FROM {}".format(table))
+                results = [row[0] for row in cur.fetchall()]
+
+                # Don't leave the fixture table behind for other tests.
+                cur.execute("DROP TABLE {}".format(table))
+                conn.commit()
+        finally:
+            conn.close()
+
+        self.assertEqual(sorted(input_data), sorted(results))
+
+    def test_bulk_dump(self):
+        """bulk_dump() writes one line per table row to the output file."""
+        hook = PostgresHook()
+        table = "t"
+        input_data = ["foo", "bar", "baz"]
+
+        # A psycopg2 connection used as a context manager only ends the
+        # transaction; it does not close the connection, so close it here.
+        conn = hook.get_conn()
+        try:
+            with conn.cursor() as cur:
+                cur.execute("DROP TABLE IF EXISTS {}".format(table))
+                cur.execute("CREATE TABLE {} (c VARCHAR)".format(table))
+                values = ",".join("('{}')".format(data) for data in input_data)
+                cur.execute("INSERT INTO {} VALUES {}".format(table, values))
+                conn.commit()
+
+                with NamedTemporaryFile() as f:
+                    hook.bulk_dump(table, f.name)
+                    f.seek(0)
+                    results = [line.rstrip().decode("utf-8") for line in f]
+
+                # Don't leave the fixture table behind for other tests.
+                cur.execute("DROP TABLE {}".format(table))
+                conn.commit()
+        finally:
+            conn.close()
+
+        self.assertEqual(sorted(input_data), sorted(results))

Reply via email to