This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 636625f Add insert_args for support transfer replace (#15825)
636625f is described below
commit 636625fdb99e6b7beb1375c5df52b06c09e6bafb
Author: Jun <[email protected]>
AuthorDate: Tue Jul 20 03:04:32 2021 +0800
Add insert_args for support transfer replace (#15825)
---
airflow/operators/generic_transfer.py | 6 +++-
tests/operators/test_generic_transfer.py | 47 ++++++++++++++++++++++++++++++++
2 files changed, 52 insertions(+), 1 deletion(-)
diff --git a/airflow/operators/generic_transfer.py b/airflow/operators/generic_transfer.py
index 55f6384..1bdfa79 100644
--- a/airflow/operators/generic_transfer.py
+++ b/airflow/operators/generic_transfer.py
@@ -41,6 +41,8 @@ class GenericTransfer(BaseOperator):
:param preoperator: sql statement or list of statements to be
executed prior to loading the data. (templated)
:type preoperator: str or list[str]
+ :param insert_args: extra params for `insert_rows` method.
+ :type insert_args: dict
"""
template_fields = ('sql', 'destination_table', 'preoperator')
@@ -59,6 +61,7 @@ class GenericTransfer(BaseOperator):
source_conn_id: str,
destination_conn_id: str,
preoperator: Optional[Union[str, List[str]]] = None,
+ insert_args: Optional[dict] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -67,6 +70,7 @@ class GenericTransfer(BaseOperator):
self.source_conn_id = source_conn_id
self.destination_conn_id = destination_conn_id
self.preoperator = preoperator
+ self.insert_args = insert_args or {}
def execute(self, context):
source_hook = BaseHook.get_hook(self.source_conn_id)
@@ -82,4 +86,4 @@ class GenericTransfer(BaseOperator):
destination_hook.run(self.preoperator)
self.log.info("Inserting rows into %s", self.destination_conn_id)
- destination_hook.insert_rows(table=self.destination_table,
rows=results)
+ destination_hook.insert_rows(table=self.destination_table,
rows=results, **self.insert_args)
diff --git a/tests/operators/test_generic_transfer.py b/tests/operators/test_generic_transfer.py
index 4780b41..2ce165b 100644
--- a/tests/operators/test_generic_transfer.py
+++ b/tests/operators/test_generic_transfer.py
@@ -18,6 +18,7 @@
import unittest
from contextlib import closing
+from unittest import mock
import pytest
from parameterized import parameterized
@@ -73,6 +74,27 @@ class TestMySql(unittest.TestCase):
)
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
ignore_ti_state=True)
+ @mock.patch('airflow.hooks.dbapi.DbApiHook.insert_rows')
+ def test_mysql_to_mysql_replace(self, mock_insert):
+ sql = "SELECT * FROM connection LIMIT 10;"
+ op = GenericTransfer(
+ task_id='test_m2m',
+ preoperator=[
+ "DROP TABLE IF EXISTS test_mysql_to_mysql",
+ "CREATE TABLE IF NOT EXISTS test_mysql_to_mysql LIKE
connection",
+ ],
+ source_conn_id='airflow_db',
+ destination_conn_id='airflow_db',
+ destination_table="test_mysql_to_mysql",
+ sql=sql,
+ dag=self.dag,
+ insert_args={'replace': True},
+ )
+ op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
ignore_ti_state=True)
+ assert mock_insert.called
+ _, kwargs = mock_insert.call_args
+ assert 'replace' in kwargs
+
@pytest.mark.backend("postgres")
class TestPostgres(unittest.TestCase):
@@ -103,3 +125,28 @@ class TestPostgres(unittest.TestCase):
dag=self.dag,
)
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
ignore_ti_state=True)
+
+ @mock.patch('airflow.hooks.dbapi.DbApiHook.insert_rows')
+ def test_postgres_to_postgres_replace(self, mock_insert):
+ sql = "SELECT id, conn_id, conn_type FROM connection LIMIT 10;"
+ op = GenericTransfer(
+ task_id='test_p2p',
+ preoperator=[
+ "DROP TABLE IF EXISTS test_postgres_to_postgres",
+ "CREATE TABLE IF NOT EXISTS test_postgres_to_postgres (LIKE
connection INCLUDING INDEXES)",
+ ],
+ source_conn_id='postgres_default',
+ destination_conn_id='postgres_default',
+ destination_table="test_postgres_to_postgres",
+ sql=sql,
+ dag=self.dag,
+ insert_args={
+ 'replace': True,
+ 'target_fields': ('id', 'conn_id', 'conn_type'),
+ 'replace_index': 'id',
+ },
+ )
+ op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
ignore_ti_state=True)
+ assert mock_insert.called
+ _, kwargs = mock_insert.call_args
+ assert 'replace' in kwargs