This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 4d54cda411 Make conn id parameters templated in GenericTransfer and
also allow passing hook parameters like in BaseSQLOperator (#42891)
4d54cda411 is described below
commit 4d54cda4114125bb671b0bfccddc73b646855a2d
Author: David Blain <[email protected]>
AuthorDate: Thu Oct 24 11:34:56 2024 +0200
Make conn id parameters templated in GenericTransfer and also allow passing
hook parameters like in BaseSQLOperator (#42891)
* refactored: Added a hook_params argument to the BaseHook.get_hook method, made
the conn_ids of GenericTransfer templated, and allowed passing hook params for each
connection in GenericTransfer
---------
Co-authored-by: David Blain <[email protected]>
---
airflow/hooks/base.py | 5 +-
generated/provider_dependencies.json | 1 +
.../airflow/providers/common/sql/operators/sql.py | 16 +++++-
.../standard}/operators/generic_transfer.py | 35 +++++++++++--
.../src/airflow/providers/standard/provider.yaml | 2 +
providers/tests/common/sql/operators/test_sql.py | 21 +++++++-
.../standard}/operators/test_generic_transfer.py | 60 +++++++++++++++++++++-
tests_common/test_utils/compat.py | 2 +
8 files changed, 131 insertions(+), 11 deletions(-)
diff --git a/airflow/hooks/base.py b/airflow/hooks/base.py
index e82c838c8c..8f95d7bfe1 100644
--- a/airflow/hooks/base.py
+++ b/airflow/hooks/base.py
@@ -67,15 +67,16 @@ class BaseHook(LoggingMixin):
return conn
@classmethod
- def get_hook(cls, conn_id: str) -> BaseHook:
+ def get_hook(cls, conn_id: str, hook_params: dict | None = None) ->
BaseHook:
"""
Return default hook for this connection id.
:param conn_id: connection id
+ :param hook_params: hook parameters
:return: default hook for this connection
"""
connection = cls.get_connection(conn_id)
- return connection.get_hook()
+ return connection.get_hook(hook_params=hook_params)
def get_conn(self) -> Any:
"""Return connection for the hook."""
diff --git a/generated/provider_dependencies.json
b/generated/provider_dependencies.json
index 2f284cc4de..dc7aedc3b9 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -1279,6 +1279,7 @@
},
"standard": {
"deps": [
+ "apache-airflow-providers-common-sql>=1.18.0",
"apache-airflow>=2.10.0"
],
"devel-deps": [],
diff --git a/providers/src/airflow/providers/common/sql/operators/sql.py
b/providers/src/airflow/providers/common/sql/operators/sql.py
index fa1539e725..dae389be02 100644
--- a/providers/src/airflow/providers/common/sql/operators/sql.py
+++ b/providers/src/airflow/providers/common/sql/operators/sql.py
@@ -145,13 +145,25 @@ class BaseSQLOperator(BaseOperator):
self.hook_params = hook_params or {}
self.retry_on_failure = retry_on_failure
+ @classmethod
+ # TODO: can be removed once Airflow min version for this provider is 3.0.0
or higher
+ def get_hook(cls, conn_id: str, hook_params: dict | None = None) ->
BaseHook:
+ """
+ Return default hook for this connection id.
+
+ :param conn_id: connection id
+ :param hook_params: hook parameters
+ :return: default hook for this connection
+ """
+ connection = BaseHook.get_connection(conn_id)
+ return connection.get_hook(hook_params=hook_params)
+
@cached_property
def _hook(self):
"""Get DB Hook based on connection type."""
conn_id = getattr(self, self.conn_id_field)
self.log.debug("Get connection for %s", conn_id)
- conn = BaseHook.get_connection(conn_id)
- hook = conn.get_hook(hook_params=self.hook_params)
+ hook = self.get_hook(conn_id=conn_id, hook_params=self.hook_params)
if not isinstance(hook, DbApiHook):
raise AirflowException(
f"You are trying to use `common-sql` with
{hook.__class__.__name__},"
diff --git a/airflow/operators/generic_transfer.py
b/providers/src/airflow/providers/standard/operators/generic_transfer.py
similarity index 76%
rename from airflow/operators/generic_transfer.py
rename to providers/src/airflow/providers/standard/operators/generic_transfer.py
index a808e23997..255b08c54c 100644
--- a/airflow/operators/generic_transfer.py
+++ b/providers/src/airflow/providers/standard/operators/generic_transfer.py
@@ -38,14 +38,21 @@ class GenericTransfer(BaseOperator):
:param sql: SQL query to execute against the source database. (templated)
:param destination_table: target table. (templated)
- :param source_conn_id: source connection
- :param destination_conn_id: destination connection
+ :param source_conn_id: source connection. (templated)
+ :param destination_conn_id: destination connection. (templated)
:param preoperator: sql statement or list of statements to be
executed prior to loading the data. (templated)
:param insert_args: extra params for `insert_rows` method.
"""
- template_fields: Sequence[str] = ("sql", "destination_table",
"preoperator")
+ template_fields: Sequence[str] = (
+ "source_conn_id",
+ "destination_conn_id",
+ "sql",
+ "destination_table",
+ "preoperator",
+ "insert_args",
+ )
template_ext: Sequence[str] = (
".sql",
".hql",
@@ -59,7 +66,9 @@ class GenericTransfer(BaseOperator):
sql: str,
destination_table: str,
source_conn_id: str,
+ source_hook_params: dict | None = None,
destination_conn_id: str,
+ destination_hook_params: dict | None = None,
preoperator: str | list[str] | None = None,
insert_args: dict | None = None,
**kwargs,
@@ -68,13 +77,29 @@ class GenericTransfer(BaseOperator):
self.sql = sql
self.destination_table = destination_table
self.source_conn_id = source_conn_id
+ self.source_hook_params = source_hook_params
self.destination_conn_id = destination_conn_id
+ self.destination_hook_params = destination_hook_params
self.preoperator = preoperator
self.insert_args = insert_args or {}
+ @classmethod
+ def get_hook(cls, conn_id: str, hook_params: dict | None = None) ->
BaseHook:
+ """
+ Return default hook for this connection id.
+
+ :param conn_id: connection id
+ :param hook_params: hook parameters
+ :return: default hook for this connection
+ """
+ connection = BaseHook.get_connection(conn_id)
+ return connection.get_hook(hook_params=hook_params)
+
def execute(self, context: Context):
- source_hook = BaseHook.get_hook(self.source_conn_id)
- destination_hook = BaseHook.get_hook(self.destination_conn_id)
+ source_hook = self.get_hook(conn_id=self.source_conn_id,
hook_params=self.source_hook_params)
+ destination_hook = self.get_hook(
+ conn_id=self.destination_conn_id,
hook_params=self.destination_hook_params
+ )
self.log.info("Extracting data from %s", self.source_conn_id)
self.log.info("Executing: \n %s", self.sql)
diff --git a/providers/src/airflow/providers/standard/provider.yaml
b/providers/src/airflow/providers/standard/provider.yaml
index b3111d62b1..6231e6c3d5 100644
--- a/providers/src/airflow/providers/standard/provider.yaml
+++ b/providers/src/airflow/providers/standard/provider.yaml
@@ -29,6 +29,7 @@ versions:
dependencies:
- apache-airflow>=2.10.0
+ - apache-airflow-providers-common-sql>=1.18.0
integrations:
- integration-name: Standard
@@ -43,6 +44,7 @@ operators:
- airflow.providers.standard.operators.datetime
- airflow.providers.standard.operators.weekday
- airflow.providers.standard.operators.bash
+ - airflow.providers.standard.operators.generic_transfer
sensors:
- integration-name: Standard
diff --git a/providers/tests/common/sql/operators/test_sql.py
b/providers/tests/common/sql/operators/test_sql.py
index 5026a5ec98..9ff5cebb76 100644
--- a/providers/tests/common/sql/operators/test_sql.py
+++ b/providers/tests/common/sql/operators/test_sql.py
@@ -18,13 +18,14 @@
from __future__ import annotations
import datetime
+import inspect
from unittest import mock
from unittest.mock import MagicMock
import pytest
from airflow import DAG
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning
from airflow.models import Connection, DagRun, TaskInstance as TI, XCom
from airflow.operators.empty import EmptyOperator
from airflow.providers.common.sql.hooks.sql import fetch_all_handler
@@ -45,6 +46,7 @@ from airflow.utils.session import create_session
from airflow.utils.state import State
from tests_common.test_utils.compat import AIRFLOW_V_2_8_PLUS,
AIRFLOW_V_3_0_PLUS
+from tests_common.test_utils.providers import get_provider_min_airflow_version
if AIRFLOW_V_3_0_PLUS:
from airflow.utils.types import DagRunTriggeredByType
@@ -91,6 +93,23 @@ class TestBaseSQLOperator:
assert operator.database == "my_database"
assert operator.hook_params == {"key": "value"}
+ def
test_when_provider_min_airflow_version_is_3_0_or_higher_remove_obsolete_get_hook_method(self):
+ """
+ Once this test starts failing due to the fact that the minimum Airflow
version is now 3.0.0 or higher
+ for this provider, you should remove the obsolete get_hook method in
the BaseSQLOperator operator
+ and remove this test. This test was added to make sure to not forget
to remove the fallback code
+ for backward compatibility with Airflow 2.8.x which isn't needed anymore
once this provider depends on
+ Airflow 3.0.0 or higher.
+ """
+ min_airflow_version =
get_provider_min_airflow_version("apache-airflow-providers-common-sql")
+
+ # Check if the current Airflow version is 3.0.0 or higher
+ if min_airflow_version[0] >= 3:
+ method_source = inspect.getsource(BaseSQLOperator.get_hook)
+ raise AirflowProviderDeprecationWarning(
+ f"Check TODO's to remove obsolete get_hook method in
BaseSQLOperator:\n\r\n\r\t\t\t{method_source}"
+ )
+
class TestSQLExecuteQueryOperator:
def _construct_operator(self, sql, **kwargs):
diff --git a/tests/operators/test_generic_transfer.py
b/providers/tests/standard/operators/test_generic_transfer.py
similarity index 69%
rename from tests/operators/test_generic_transfer.py
rename to providers/tests/standard/operators/test_generic_transfer.py
index e1281ed4b2..4ea08e4889 100644
--- a/tests/operators/test_generic_transfer.py
+++ b/providers/tests/standard/operators/test_generic_transfer.py
@@ -17,16 +17,21 @@
# under the License.
from __future__ import annotations
+import inspect
from contextlib import closing
+from datetime import datetime
from unittest import mock
import pytest
+from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.models.dag import DAG
-from airflow.operators.generic_transfer import GenericTransfer
from airflow.providers.postgres.hooks.postgres import PostgresHook
from airflow.utils import timezone
+from tests_common.test_utils.compat import GenericTransfer
+from tests_common.test_utils.providers import get_provider_min_airflow_version
+
pytestmark = pytest.mark.db_test
DEFAULT_DATE = timezone.datetime(2015, 1, 1)
@@ -151,3 +156,56 @@ class TestPostgres:
assert mock_insert.called
_, kwargs = mock_insert.call_args
assert "replace" in kwargs
+
+
+class TestGenericTransfer:
+ def test_templated_fields(self):
+ dag = DAG(
+ "test_dag",
+ schedule=None,
+ start_date=datetime(2024, 10, 10),
+ render_template_as_native_obj=True,
+ )
+ operator = GenericTransfer(
+ task_id="test_task",
+ sql="{{ sql }}",
+ destination_table="{{ destination_table }}",
+ source_conn_id="{{ source_conn_id }}",
+ destination_conn_id="{{ destination_conn_id }}",
+ preoperator="{{ preoperator }}",
+ insert_args="{{ insert_args }}",
+ dag=dag,
+ )
+ operator.render_template_fields(
+ {
+ "sql": "my_sql",
+ "destination_table": "my_destination_table",
+ "source_conn_id": "my_source_conn_id",
+ "destination_conn_id": "my_destination_conn_id",
+ "preoperator": "my_preoperator",
+ "insert_args": {"commit_every": 5000, "executemany": True,
"replace": True},
+ }
+ )
+ assert operator.sql == "my_sql"
+ assert operator.destination_table == "my_destination_table"
+ assert operator.source_conn_id == "my_source_conn_id"
+ assert operator.destination_conn_id == "my_destination_conn_id"
+ assert operator.preoperator == "my_preoperator"
+ assert operator.insert_args == {"commit_every": 5000, "executemany":
True, "replace": True}
+
+ def
test_when_provider_min_airflow_version_is_3_0_or_higher_remove_obsolete_method(self):
+ """
+ Once this test starts failing due to the fact that the minimum Airflow
version is now 3.0.0 or higher
+ for this provider, you should remove the obsolete get_hook method in
the GenericTransfer and use the
+ one from BaseHook and remove this test. This test was added to make
sure to not forget to remove the
+ fallback code for backward compatibility with Airflow 2.x which
isn't needed anymore once this
+ provider depends on Airflow 3.0.0 or higher.
+ """
+ min_airflow_version =
get_provider_min_airflow_version("apache-airflow-providers-standard")
+
+ # Check if the current Airflow version is 3.0.0 or higher
+ if min_airflow_version[0] >= 3:
+ method_source = inspect.getsource(GenericTransfer.get_hook)
+ raise AirflowProviderDeprecationWarning(
+ f"Remove obsolete get_hook method in
GenericTransfer:\n\r\n\r\t\t\t{method_source}"
+ )
diff --git a/tests_common/test_utils/compat.py
b/tests_common/test_utils/compat.py
index bc04f798e0..42c566c5d5 100644
--- a/tests_common/test_utils/compat.py
+++ b/tests_common/test_utils/compat.py
@@ -52,11 +52,13 @@ except ImportError:
try:
from airflow.providers.standard.operators.bash import BashOperator
+ from airflow.providers.standard.operators.generic_transfer import
GenericTransfer
from airflow.providers.standard.sensors.bash import BashSensor
from airflow.providers.standard.sensors.date_time import DateTimeSensor
except ImportError:
# Compatibility for Airflow < 2.10.*
from airflow.operators.bash import BashOperator # type:
ignore[no-redef,attr-defined]
+ from airflow.operators.generic_transfer import GenericTransfer # type:
ignore[no-redef,attr-defined]
from airflow.sensors.bash import BashSensor # type:
ignore[no-redef,attr-defined]
from airflow.sensors.date_time import DateTimeSensor # type:
ignore[no-redef,attr-defined]