Repository: incubator-airflow
Updated Branches:
  refs/heads/master ed9329017 -> 65b6ceae7


[AIRFLOW-2234] Enable insert_rows for PrestoHook

PrestoHook.insert_rows() currently raises
NotImplementedError.
However, Presto 0.126+ allows specifying column names
in INSERT queries,
so we can reuse DbApiHook.insert_rows() almost
as is (commit_every is fixed to 0, because PyHive
does not support transactions yet).
This PR enables insert_rows() for PrestoHook.

Closes #3146 from sekikn/AIRFLOW-2234


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/65b6ceae
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/65b6ceae
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/65b6ceae

Branch: refs/heads/master
Commit: 65b6ceae74c166efe95113ad5aa55004e2ad25c5
Parents: ed93290
Author: Kengo Seki <[email protected]>
Authored: Mon Apr 23 19:01:38 2018 +0200
Committer: Bolke de Bruin <[email protected]>
Committed: Mon Apr 23 19:01:38 2018 +0200

----------------------------------------------------------------------
 airflow/hooks/dbapi_hook.py     |  4 +--
 airflow/hooks/presto_hook.py    | 17 ++++++++--
 tests/hooks/test_dbapi_hook.py  | 64 +++++++++++++++++++++++++++++++++---
 tests/hooks/test_presto_hook.py | 49 +++++++++++++++++++++++++++
 4 files changed, 125 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/65b6ceae/airflow/hooks/dbapi_hook.py
----------------------------------------------------------------------
diff --git a/airflow/hooks/dbapi_hook.py b/airflow/hooks/dbapi_hook.py
index a9f4e43..de0a3a3 100644
--- a/airflow/hooks/dbapi_hook.py
+++ b/airflow/hooks/dbapi_hook.py
@@ -213,8 +213,8 @@ class DbApiHook(BaseHook):
                     for cell in row:
                         l.append(self._serialize_cell(cell, conn))
                     values = tuple(l)
-                    placeholders = ["%s",]*len(values)
-                    sql = "INSERT INTO {0} {1} VALUES ({2});".format(
+                    placeholders = ["%s", ] * len(values)
+                    sql = "INSERT INTO {0} {1} VALUES ({2})".format(
                         table,
                         target_fields,
                         ",".join(placeholders))

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/65b6ceae/airflow/hooks/presto_hook.py
----------------------------------------------------------------------
diff --git a/airflow/hooks/presto_hook.py b/airflow/hooks/presto_hook.py
index 935a9a5..8920448 100644
--- a/airflow/hooks/presto_hook.py
+++ b/airflow/hooks/presto_hook.py
@@ -114,5 +114,18 @@ class PrestoHook(DbApiHook):
         """
         return super(PrestoHook, self).run(self._strip_sql(hql), parameters)
 
-    def insert_rows(self):
-        raise NotImplementedError()
+    # TODO Enable commit_every once PyHive supports transactions.
+    # Unfortunately, PyHive 0.5.1 doesn't support transactions yet,
+    # whereas Presto 0.132+ does.
+    def insert_rows(self, table, rows, target_fields=None):
+        """
+        A generic way to insert a set of tuples into a table.
+
+        :param table: Name of the target table
+        :type table: str
+        :param rows: The rows to insert into the table
+        :type rows: iterable of tuples
+        :param target_fields: The names of the columns to fill in the table
+        :type target_fields: iterable of strings
+        """
+        super(PrestoHook, self).insert_rows(table, rows, target_fields, 0)

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/65b6ceae/tests/hooks/test_dbapi_hook.py
----------------------------------------------------------------------
diff --git a/tests/hooks/test_dbapi_hook.py b/tests/hooks/test_dbapi_hook.py
index 9fcc970..c3ae187 100644
--- a/tests/hooks/test_dbapi_hook.py
+++ b/tests/hooks/test_dbapi_hook.py
@@ -28,17 +28,18 @@ class TestDbApiHook(unittest.TestCase):
 
     def setUp(self):
         super(TestDbApiHook, self).setUp()
-        
+
         self.cur = mock.MagicMock()
-        self.conn = conn = mock.MagicMock()
+        self.conn = mock.MagicMock()
         self.conn.cursor.return_value = self.cur
-        
+        conn = self.conn
+
         class TestDBApiHook(DbApiHook):
             conn_name_attr = 'test_conn_id'
-            
+
             def get_conn(self):
                 return conn
-        
+
         self.db_hook = TestDBApiHook()
 
     def test_get_records(self):
@@ -78,3 +79,56 @@ class TestDbApiHook(unittest.TestCase):
         self.conn.close.assert_called_once()
         self.cur.close.assert_called_once()
         self.cur.execute.assert_called_once_with(statement)
+
+    def test_insert_rows(self):
+        table = "table"
+        rows = [("hello",),
+                ("world",)]
+
+        self.db_hook.insert_rows(table, rows)
+
+        self.conn.close.assert_called_once()
+        self.cur.close.assert_called_once()
+
+        commit_count = 2  # The first and last commit
+        self.assertEqual(commit_count, self.conn.commit.call_count)
+
+        sql = "INSERT INTO {}  VALUES (%s)".format(table)
+        for row in rows:
+            self.cur.execute.assert_any_call(sql, row)
+
+    def test_insert_rows_target_fields(self):
+        table = "table"
+        rows = [("hello",),
+                ("world",)]
+        target_fields = ["field"]
+
+        self.db_hook.insert_rows(table, rows, target_fields)
+
+        self.conn.close.assert_called_once()
+        self.cur.close.assert_called_once()
+
+        commit_count = 2  # The first and last commit
+        self.assertEqual(commit_count, self.conn.commit.call_count)
+
+        sql = "INSERT INTO {} ({}) VALUES (%s)".format(table, target_fields[0])
+        for row in rows:
+            self.cur.execute.assert_any_call(sql, row)
+
+    def test_insert_rows_commit_every(self):
+        table = "table"
+        rows = [("hello",),
+                ("world",)]
+        commit_every = 1
+
+        self.db_hook.insert_rows(table, rows, commit_every=commit_every)
+
+        self.conn.close.assert_called_once()
+        self.cur.close.assert_called_once()
+
+        commit_count = 2 + divmod(len(rows), commit_every)[0]
+        self.assertEqual(commit_count, self.conn.commit.call_count)
+
+        sql = "INSERT INTO {}  VALUES (%s)".format(table)
+        for row in rows:
+            self.cur.execute.assert_any_call(sql, row)

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/65b6ceae/tests/hooks/test_presto_hook.py
----------------------------------------------------------------------
diff --git a/tests/hooks/test_presto_hook.py b/tests/hooks/test_presto_hook.py
new file mode 100644
index 0000000..b01782e
--- /dev/null
+++ b/tests/hooks/test_presto_hook.py
@@ -0,0 +1,49 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import mock
+import unittest
+
+from mock import patch
+
+from airflow.hooks.presto_hook import PrestoHook
+
+
+class TestPrestoHook(unittest.TestCase):
+
+    def setUp(self):
+        super(TestPrestoHook, self).setUp()
+
+        self.cur = mock.MagicMock()
+        self.conn = mock.MagicMock()
+        self.conn.cursor.return_value = self.cur
+        conn = self.conn
+
+        class UnitTestPrestoHook(PrestoHook):
+            conn_name_attr = 'test_conn_id'
+
+            def get_conn(self):
+                return conn
+
+        self.db_hook = UnitTestPrestoHook()
+
+    @patch('airflow.hooks.dbapi_hook.DbApiHook.insert_rows')
+    def test_insert_rows(self, mock_insert_rows):
+        table = "table"
+        rows = [("hello",),
+                ("world",)]
+        target_fields = None
+        self.db_hook.insert_rows(table, rows, target_fields)
+        mock_insert_rows.assert_called_once_with(table, rows, None, 0)

Reply via email to