This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 636625f  Add insert_args for support transfer replace (#15825)
636625f is described below

commit 636625fdb99e6b7beb1375c5df52b06c09e6bafb
Author: Jun <[email protected]>
AuthorDate: Tue Jul 20 03:04:32 2021 +0800

    Add insert_args for support transfer replace (#15825)
---
 airflow/operators/generic_transfer.py    |  6 +++-
 tests/operators/test_generic_transfer.py | 47 ++++++++++++++++++++++++++++++++
 2 files changed, 52 insertions(+), 1 deletion(-)

diff --git a/airflow/operators/generic_transfer.py b/airflow/operators/generic_transfer.py
index 55f6384..1bdfa79 100644
--- a/airflow/operators/generic_transfer.py
+++ b/airflow/operators/generic_transfer.py
@@ -41,6 +41,8 @@ class GenericTransfer(BaseOperator):
     :param preoperator: sql statement or list of statements to be
         executed prior to loading the data. (templated)
     :type preoperator: str or list[str]
+    :param insert_args: extra params for `insert_rows` method.
+    :type insert_args: dict
     """
 
     template_fields = ('sql', 'destination_table', 'preoperator')
@@ -59,6 +61,7 @@ class GenericTransfer(BaseOperator):
         source_conn_id: str,
         destination_conn_id: str,
         preoperator: Optional[Union[str, List[str]]] = None,
+        insert_args: Optional[dict] = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -67,6 +70,7 @@ class GenericTransfer(BaseOperator):
         self.source_conn_id = source_conn_id
         self.destination_conn_id = destination_conn_id
         self.preoperator = preoperator
+        self.insert_args = insert_args or {}
 
     def execute(self, context):
         source_hook = BaseHook.get_hook(self.source_conn_id)
@@ -82,4 +86,4 @@ class GenericTransfer(BaseOperator):
             destination_hook.run(self.preoperator)
 
         self.log.info("Inserting rows into %s", self.destination_conn_id)
-        destination_hook.insert_rows(table=self.destination_table, rows=results)
+        destination_hook.insert_rows(table=self.destination_table, rows=results, **self.insert_args)
diff --git a/tests/operators/test_generic_transfer.py b/tests/operators/test_generic_transfer.py
index 4780b41..2ce165b 100644
--- a/tests/operators/test_generic_transfer.py
+++ b/tests/operators/test_generic_transfer.py
@@ -18,6 +18,7 @@
 
 import unittest
 from contextlib import closing
+from unittest import mock
 
 import pytest
 from parameterized import parameterized
@@ -73,6 +74,27 @@ class TestMySql(unittest.TestCase):
             )
             op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
 
+    @mock.patch('airflow.hooks.dbapi.DbApiHook.insert_rows')
+    def test_mysql_to_mysql_replace(self, mock_insert):
+        sql = "SELECT * FROM connection LIMIT 10;"
+        op = GenericTransfer(
+            task_id='test_m2m',
+            preoperator=[
+                "DROP TABLE IF EXISTS test_mysql_to_mysql",
+                "CREATE TABLE IF NOT EXISTS test_mysql_to_mysql LIKE connection",
+            ],
+            source_conn_id='airflow_db',
+            destination_conn_id='airflow_db',
+            destination_table="test_mysql_to_mysql",
+            sql=sql,
+            dag=self.dag,
+            insert_args={'replace': True},
+        )
+        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+        assert mock_insert.called
+        _, kwargs = mock_insert.call_args
+        assert 'replace' in kwargs
+
 
 @pytest.mark.backend("postgres")
 class TestPostgres(unittest.TestCase):
@@ -103,3 +125,28 @@ class TestPostgres(unittest.TestCase):
             dag=self.dag,
         )
         op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+    @mock.patch('airflow.hooks.dbapi.DbApiHook.insert_rows')
+    def test_postgres_to_postgres_replace(self, mock_insert):
+        sql = "SELECT id, conn_id, conn_type FROM connection LIMIT 10;"
+        op = GenericTransfer(
+            task_id='test_p2p',
+            preoperator=[
+                "DROP TABLE IF EXISTS test_postgres_to_postgres",
+                "CREATE TABLE IF NOT EXISTS test_postgres_to_postgres (LIKE connection INCLUDING INDEXES)",
+            ],
+            source_conn_id='postgres_default',
+            destination_conn_id='postgres_default',
+            destination_table="test_postgres_to_postgres",
+            sql=sql,
+            dag=self.dag,
+            insert_args={
+                'replace': True,
+                'target_fields': ('id', 'conn_id', 'conn_type'),
+                'replace_index': 'id',
+            },
+        )
+        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+        assert mock_insert.called
+        _, kwargs = mock_insert.call_args
+        assert 'replace' in kwargs

Reply via email to