This is an automated email from the ASF dual-hosted git repository. kaxilnaik pushed a commit to branch v1-10-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 111ba6938854c333e8dc8c8e23245c285e0ab162 Author: William Tran <[email protected]> AuthorDate: Thu Apr 30 05:16:18 2020 -0400 [AIRFLOW-4734] Upsert functionality for PostgresHook.insert_rows() (#8625) PostgresHook's parent class, DbApiHook, implements upsert in its insert_rows() method with the replace=True flag. However, the underlying generated SQL is specific to MySQL's "REPLACE INTO" syntax and is not applicable to PostgreSQL. This pulls the SQL generation code for insert/upsert out into a method that is then overridden in the PostgreSQL subclass to generate the "INSERT ... ON CONFLICT DO UPDATE" syntax (available since PostgreSQL 9.5). (cherry picked from commit a28c66f23d373cd0f8bfc765a515f21d4b66a0e9) --- airflow/contrib/hooks/bigquery_hook.py | 2 +- airflow/hooks/dbapi_hook.py | 54 ++++++++++++++++++++++--------- airflow/hooks/postgres_hook.py | 54 +++++++++++++++++++++++++++++++ tests/hooks/test_postgres_hook.py | 59 ++++++++++++++++++++++++++++++++++ 4 files changed, 153 insertions(+), 16 deletions(-) diff --git a/airflow/contrib/hooks/bigquery_hook.py b/airflow/contrib/hooks/bigquery_hook.py index 930d212..07a2ab8 100644 --- a/airflow/contrib/hooks/bigquery_hook.py +++ b/airflow/contrib/hooks/bigquery_hook.py @@ -85,7 +85,7 @@ class BigQueryHook(GoogleCloudBaseHook, DbApiHook): return build( 'bigquery', 'v2', http=http_authorized, cache_discovery=False) - def insert_rows(self, table, rows, target_fields=None, commit_every=1000): + def insert_rows(self, table, rows, target_fields=None, commit_every=1000, **kwargs): """ Insertion is currently unsupported. 
Theoretically, you could use BigQuery's streaming API to insert rows into a table, but this hasn't diff --git a/airflow/hooks/dbapi_hook.py b/airflow/hooks/dbapi_hook.py index 218ff83..ac54881 100644 --- a/airflow/hooks/dbapi_hook.py +++ b/airflow/hooks/dbapi_hook.py @@ -211,8 +211,43 @@ class DbApiHook(BaseHook): """ return self.get_conn().cursor() + @staticmethod + def _generate_insert_sql(table, values, target_fields, replace, **kwargs): + """ + Static helper method that generate the INSERT SQL statement. + The REPLACE variant is specific to MySQL syntax. + + :param table: Name of the target table + :type table: str + :param values: The row to insert into the table + :type values: tuple of cell values + :param target_fields: The names of the columns to fill in the table + :type target_fields: iterable of strings + :param replace: Whether to replace instead of insert + :type replace: bool + :return: The generated INSERT or REPLACE SQL statement + :rtype: str + """ + placeholders = ["%s", ] * len(values) + + if target_fields: + target_fields = ", ".join(target_fields) + target_fields = "({})".format(target_fields) + else: + target_fields = '' + + if not replace: + sql = "INSERT INTO " + else: + sql = "REPLACE INTO " + sql += "{0} {1} VALUES ({2})".format( + table, + target_fields, + ",".join(placeholders)) + return sql + def insert_rows(self, table, rows, target_fields=None, commit_every=1000, - replace=False): + replace=False, **kwargs): """ A generic way to insert a set of tuples into a table, a new transaction is created every commit_every rows @@ -229,11 +264,6 @@ class DbApiHook(BaseHook): :param replace: Whether to replace instead of insert :type replace: bool """ - if target_fields: - target_fields = ", ".join(target_fields) - target_fields = "({})".format(target_fields) - else: - target_fields = '' i = 0 with closing(self.get_conn()) as conn: if self.supports_autocommit: @@ -247,15 +277,9 @@ class DbApiHook(BaseHook): for cell in row: 
lst.append(self._serialize_cell(cell, conn)) values = tuple(lst) - placeholders = ["%s", ] * len(values) - if not replace: - sql = "INSERT INTO " - else: - sql = "REPLACE INTO " - sql += "{0} {1} VALUES ({2})".format( - table, - target_fields, - ",".join(placeholders)) + sql = self._generate_insert_sql( + table, values, target_fields, replace, **kwargs + ) cur.execute(sql, values) if commit_every and i % commit_every == 0: conn.commit() diff --git a/airflow/hooks/postgres_hook.py b/airflow/hooks/postgres_hook.py index a6d6523..4c7a324 100644 --- a/airflow/hooks/postgres_hook.py +++ b/airflow/hooks/postgres_hook.py @@ -177,3 +177,57 @@ class PostgresHook(DbApiHook): client = aws_hook.get_client_type('rds') token = client.generate_db_auth_token(conn.host, port, conn.login) return login, token, port + + @staticmethod + def _generate_insert_sql(table, values, target_fields, replace, **kwargs): + """ + Static helper method that generate the INSERT SQL statement. + The REPLACE variant is specific to MySQL syntax. 
+ + :param table: Name of the target table + :type table: str + :param values: The row to insert into the table + :type values: tuple of cell values + :param target_fields: The names of the columns to fill in the table + :type target_fields: iterable of strings + :param replace: Whether to replace instead of insert + :type replace: bool + :param replace_index: the column or list of column names to act as + index for the ON CONFLICT clause + :type replace_index: str or list + :return: The generated INSERT or REPLACE SQL statement + :rtype: str + """ + placeholders = ["%s", ] * len(values) + replace_index = kwargs.get("replace_index", None) + + if target_fields: + target_fields_fragment = ", ".join(target_fields) + target_fields_fragment = "({})".format(target_fields_fragment) + else: + target_fields_fragment = '' + + sql = "INSERT INTO {0} {1} VALUES ({2})".format( + table, + target_fields_fragment, + ",".join(placeholders)) + + if replace: + if target_fields is None: + raise ValueError("PostgreSQL ON CONFLICT upsert syntax requires column names") + if replace_index is None: + raise ValueError("PostgreSQL ON CONFLICT upsert syntax requires an unique index") + if isinstance(replace_index, str): + replace_index = [replace_index] + replace_index_set = set(replace_index) + + replace_target = [ + "{0} = excluded.{0}".format(col) + for col in target_fields + if col not in replace_index_set + ] + sql += " ON CONFLICT ({0}) DO UPDATE SET {1}".format( + ", ".join(replace_index), + ", ".join(replace_target), + ) + return sql diff --git a/tests/hooks/test_postgres_hook.py b/tests/hooks/test_postgres_hook.py index f706d56..061dca0 100644 --- a/tests/hooks/test_postgres_hook.py +++ b/tests/hooks/test_postgres_hook.py @@ -183,3 +183,62 @@ class TestPostgresHook(unittest.TestCase): results = [line.rstrip().decode("utf-8") for line in f.readlines()] self.assertEqual(sorted(input_data), sorted(results)) + + @pytest.mark.backend("postgres") + def test_insert_rows(self): + table = 
"table" + rows = [("hello",), + ("world",)] + + self.db_hook.insert_rows(table, rows) + + assert self.conn.close.call_count == 1 + assert self.cur.close.call_count == 1 + + commit_count = 2 # The first and last commit + self.assertEqual(commit_count, self.conn.commit.call_count) + + sql = "INSERT INTO {} VALUES (%s)".format(table) + for row in rows: + self.cur.execute.assert_any_call(sql, row) + + @pytest.mark.backend("postgres") + def test_insert_rows_replace(self): + table = "table" + rows = [(1, "hello",), + (2, "world",)] + fields = ("id", "value") + + self.db_hook.insert_rows( + table, rows, fields, replace=True, replace_index=fields[0]) + + assert self.conn.close.call_count == 1 + assert self.cur.close.call_count == 1 + + commit_count = 2 # The first and last commit + self.assertEqual(commit_count, self.conn.commit.call_count) + + sql = "INSERT INTO {0} ({1}, {2}) VALUES (%s,%s) " \ + "ON CONFLICT ({1}) DO UPDATE SET {2} = excluded.{2}".format( + table, fields[0], fields[1]) + for row in rows: + self.cur.execute.assert_any_call(sql, row) + + @pytest.mark.xfail + @pytest.mark.backend("postgres") + def test_insert_rows_replace_missing_target_field_arg(self): + table = "table" + rows = [(1, "hello",), + (2, "world",)] + fields = ("id", "value") + self.db_hook.insert_rows( + table, rows, replace=True, replace_index=fields[0]) + + @pytest.mark.xfail + @pytest.mark.backend("postgres") + def test_insert_rows_replace_missing_replace_index_arg(self): + table = "table" + rows = [(1, "hello",), + (2, "world",)] + fields = ("id", "value") + self.db_hook.insert_rows(table, rows, fields, replace=True)
