Repository: incubator-airflow Updated Branches: refs/heads/master ed9329017 -> 65b6ceae7
[AIRFLOW-2234] Enable insert_rows for PrestoHook PrestoHook.insert_rows() currently raises NotImplementedError. However, Presto 0.126+ allows specifying column names in INSERT queries, so DbApiHook.insert_rows() can be reused almost as is. This PR enables insert_rows() for PrestoHook. Closes #3146 from sekikn/AIRFLOW-2234 Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/65b6ceae Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/65b6ceae Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/65b6ceae Branch: refs/heads/master Commit: 65b6ceae74c166efe95113ad5aa55004e2ad25c5 Parents: ed93290 Author: Kengo Seki <[email protected]> Authored: Mon Apr 23 19:01:38 2018 +0200 Committer: Bolke de Bruin <[email protected]> Committed: Mon Apr 23 19:01:38 2018 +0200 ---------------------------------------------------------------------- airflow/hooks/dbapi_hook.py | 4 +-- airflow/hooks/presto_hook.py | 17 ++++++++-- tests/hooks/test_dbapi_hook.py | 64 +++++++++++++++++++++++++++++++++--- tests/hooks/test_presto_hook.py | 49 +++++++++++++++++++++++++++ 4 files changed, 125 insertions(+), 9 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/65b6ceae/airflow/hooks/dbapi_hook.py ---------------------------------------------------------------------- diff --git a/airflow/hooks/dbapi_hook.py b/airflow/hooks/dbapi_hook.py index a9f4e43..de0a3a3 100644 --- a/airflow/hooks/dbapi_hook.py +++ b/airflow/hooks/dbapi_hook.py @@ -213,8 +213,8 @@ class DbApiHook(BaseHook): for cell in row: l.append(self._serialize_cell(cell, conn)) values = tuple(l) - placeholders = ["%s",]*len(values) - sql = "INSERT INTO {0} {1} VALUES ({2});".format( + placeholders = ["%s", ] * len(values) + sql = "INSERT INTO {0} {1} VALUES ({2})".format( table, target_fields, ",".join(placeholders)) 
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/65b6ceae/airflow/hooks/presto_hook.py ---------------------------------------------------------------------- diff --git a/airflow/hooks/presto_hook.py b/airflow/hooks/presto_hook.py index 935a9a5..8920448 100644 --- a/airflow/hooks/presto_hook.py +++ b/airflow/hooks/presto_hook.py @@ -114,5 +114,18 @@ class PrestoHook(DbApiHook): """ return super(PrestoHook, self).run(self._strip_sql(hql), parameters) - def insert_rows(self): - raise NotImplementedError() + # TODO Enable commit_every once PyHive supports transaction. + # Unfortunately, PyHive 0.5.1 doesn't support transaction for now, + # whereas Presto 0.132+ does. + def insert_rows(self, table, rows, target_fields=None): + """ + A generic way to insert a set of tuples into a table. + + :param table: Name of the target table + :type table: str + :param rows: The rows to insert into the table + :type rows: iterable of tuples + :param target_fields: The names of the columns to fill in the table + :type target_fields: iterable of strings + """ + super(PrestoHook, self).insert_rows(table, rows, target_fields, 0) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/65b6ceae/tests/hooks/test_dbapi_hook.py ---------------------------------------------------------------------- diff --git a/tests/hooks/test_dbapi_hook.py b/tests/hooks/test_dbapi_hook.py index 9fcc970..c3ae187 100644 --- a/tests/hooks/test_dbapi_hook.py +++ b/tests/hooks/test_dbapi_hook.py @@ -28,17 +28,18 @@ class TestDbApiHook(unittest.TestCase): def setUp(self): super(TestDbApiHook, self).setUp() - + self.cur = mock.MagicMock() - self.conn = conn = mock.MagicMock() + self.conn = mock.MagicMock() self.conn.cursor.return_value = self.cur - + conn = self.conn + class TestDBApiHook(DbApiHook): conn_name_attr = 'test_conn_id' - + def get_conn(self): return conn - + self.db_hook = TestDBApiHook() def test_get_records(self): @@ -78,3 +79,56 @@ class TestDbApiHook(unittest.TestCase): 
self.conn.close.assert_called_once() self.cur.close.assert_called_once() self.cur.execute.assert_called_once_with(statement) + + def test_insert_rows(self): + table = "table" + rows = [("hello",), + ("world",)] + + self.db_hook.insert_rows(table, rows) + + self.conn.close.assert_called_once() + self.cur.close.assert_called_once() + + commit_count = 2 # The first and last commit + self.assertEqual(commit_count, self.conn.commit.call_count) + + sql = "INSERT INTO {} VALUES (%s)".format(table) + for row in rows: + self.cur.execute.assert_any_call(sql, row) + + def test_insert_rows_target_fields(self): + table = "table" + rows = [("hello",), + ("world",)] + target_fields = ["field"] + + self.db_hook.insert_rows(table, rows, target_fields) + + self.conn.close.assert_called_once() + self.cur.close.assert_called_once() + + commit_count = 2 # The first and last commit + self.assertEqual(commit_count, self.conn.commit.call_count) + + sql = "INSERT INTO {} ({}) VALUES (%s)".format(table, target_fields[0]) + for row in rows: + self.cur.execute.assert_any_call(sql, row) + + def test_insert_rows_commit_every(self): + table = "table" + rows = [("hello",), + ("world",)] + commit_every = 1 + + self.db_hook.insert_rows(table, rows, commit_every=commit_every) + + self.conn.close.assert_called_once() + self.cur.close.assert_called_once() + + commit_count = 2 + divmod(len(rows), commit_every)[0] + self.assertEqual(commit_count, self.conn.commit.call_count) + + sql = "INSERT INTO {} VALUES (%s)".format(table) + for row in rows: + self.cur.execute.assert_any_call(sql, row) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/65b6ceae/tests/hooks/test_presto_hook.py ---------------------------------------------------------------------- diff --git a/tests/hooks/test_presto_hook.py b/tests/hooks/test_presto_hook.py new file mode 100644 index 0000000..b01782e --- /dev/null +++ b/tests/hooks/test_presto_hook.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the 
Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import mock +import unittest + +from mock import patch + +from airflow.hooks.presto_hook import PrestoHook + + +class TestPrestoHook(unittest.TestCase): + + def setUp(self): + super(TestPrestoHook, self).setUp() + + self.cur = mock.MagicMock() + self.conn = mock.MagicMock() + self.conn.cursor.return_value = self.cur + conn = self.conn + + class UnitTestPrestoHook(PrestoHook): + conn_name_attr = 'test_conn_id' + + def get_conn(self): + return conn + + self.db_hook = UnitTestPrestoHook() + + @patch('airflow.hooks.dbapi_hook.DbApiHook.insert_rows') + def test_insert_rows(self, mock_insert_rows): + table = "table" + rows = [("hello",), + ("world",)] + target_fields = None + self.db_hook.insert_rows(table, rows, target_fields) + mock_insert_rows.assert_called_once_with(table, rows, None, 0)
