Repository: incubator-airflow Updated Branches: refs/heads/master 6c3c8f445 -> 3d8c3db7f
[AIRFLOW-2638] dbapi_hook: support REPLACE INTO Sometimes, it's desirable to be able to use REPLACE INTO instead of INSERT INTO for the insert_rows method (e.g. when importing the same data multiple times). This adds an optional parameter to the insert_rows method that flips the generated SQL statement from "INSERT INTO" to "REPLACE INTO". Closes #3517 from flokli/dbapi_hook-replace Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/3d8c3db7 Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/3d8c3db7 Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/3d8c3db7 Branch: refs/heads/master Commit: 3d8c3db7f1f02b0737a23ef79eeb6c4fd0abcef7 Parents: 6c3c8f4 Author: Florian Klink <[email protected]> Authored: Tue Jun 19 10:04:36 2018 +0200 Committer: Fokko Driesprong <[email protected]> Committed: Tue Jun 19 10:04:36 2018 +0200 ---------------------------------------------------------------------- airflow/hooks/dbapi_hook.py | 11 +++++++++-- tests/hooks/test_dbapi_hook.py | 39 ++++++++++++++++++++++++++----------- 2 files changed, 37 insertions(+), 13 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/3d8c3db7/airflow/hooks/dbapi_hook.py ---------------------------------------------------------------------- diff --git a/airflow/hooks/dbapi_hook.py b/airflow/hooks/dbapi_hook.py index 358360d..5b50ade 100644 --- a/airflow/hooks/dbapi_hook.py +++ b/airflow/hooks/dbapi_hook.py @@ -205,7 +205,8 @@ class DbApiHook(BaseHook): """ return self.get_conn().cursor() - def insert_rows(self, table, rows, target_fields=None, commit_every=1000): + def insert_rows(self, table, rows, target_fields=None, commit_every=1000, + replace=False): """ A generic way to insert a set of tuples into a table, a new transaction is created every commit_every rows @@ -219,6 +220,8 @@ class
DbApiHook(BaseHook): :param commit_every: The maximum number of rows to insert in one transaction. Set to 0 to insert all rows in one transaction. :type commit_every: int + :param replace: Whether to replace instead of insert + :type replace: bool """ if target_fields: target_fields = ", ".join(target_fields) @@ -239,7 +242,11 @@ class DbApiHook(BaseHook): lst.append(self._serialize_cell(cell, conn)) values = tuple(lst) placeholders = ["%s", ] * len(values) - sql = "INSERT INTO {0} {1} VALUES ({2})".format( + if not replace: + sql = "INSERT INTO " + else: + sql = "REPLACE INTO " + sql += "{0} {1} VALUES ({2})".format( table, target_fields, ",".join(placeholders)) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/3d8c3db7/tests/hooks/test_dbapi_hook.py ---------------------------------------------------------------------- diff --git a/tests/hooks/test_dbapi_hook.py b/tests/hooks/test_dbapi_hook.py index c3ae187..3484ee9 100644 --- a/tests/hooks/test_dbapi_hook.py +++ b/tests/hooks/test_dbapi_hook.py @@ -7,9 +7,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. 
You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -46,36 +46,36 @@ class TestDbApiHook(unittest.TestCase): statement = "SQL" rows = [("hello",), ("world",)] - + self.cur.fetchall.return_value = rows - + self.assertEqual(rows, self.db_hook.get_records(statement)) - + self.conn.close.assert_called_once() self.cur.close.assert_called_once() self.cur.execute.assert_called_once_with(statement) - + def test_get_records_parameters(self): statement = "SQL" parameters = ["X", "Y", "Z"] rows = [("hello",), ("world",)] - + self.cur.fetchall.return_value = rows self.assertEqual(rows, self.db_hook.get_records(statement, parameters)) - + self.conn.close.assert_called_once() self.cur.close.assert_called_once() self.cur.execute.assert_called_once_with(statement, parameters) - + def test_get_records_exception(self): statement = "SQL" self.cur.fetchall.side_effect = RuntimeError('Great Problems') - + with self.assertRaises(RuntimeError): self.db_hook.get_records(statement) - + self.conn.close.assert_called_once() self.cur.close.assert_called_once() self.cur.execute.assert_called_once_with(statement) @@ -97,6 +97,23 @@ class TestDbApiHook(unittest.TestCase): for row in rows: self.cur.execute.assert_any_call(sql, row) + def test_insert_rows_replace(self): + table = "table" + rows = [("hello",), + ("world",)] + + self.db_hook.insert_rows(table, rows, replace=True) + + self.conn.close.assert_called_once() + self.cur.close.assert_called_once() + + commit_count = 2 # The first and last commit + self.assertEqual(commit_count, self.conn.commit.call_count) + + sql = "REPLACE INTO {} VALUES (%s)".format(table) + for row in rows: + self.cur.execute.assert_any_call(sql, row) + def test_insert_rows_target_fields(self): table = "table" rows = [("hello",),
