This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 4d54cda411 Make conn id parameters templated in GenericTransfer and 
also allow passing hook parameters like in BaseSQLOperator (#42891)
4d54cda411 is described below

commit 4d54cda4114125bb671b0bfccddc73b646855a2d
Author: David Blain <[email protected]>
AuthorDate: Thu Oct 24 11:34:56 2024 +0200

    Make conn id parameters templated in GenericTransfer and also allow passing 
hook parameters like in BaseSQLOperator (#42891)
    
    * refactored: Added hook_params to get_hook BaseHook method, templated 
conn_id's for GenericTransfer and also allow passing hook params for each 
connection in GenericTransfer
    
    
    ---------
    
    Co-authored-by: David Blain <[email protected]>
---
 airflow/hooks/base.py                              |  5 +-
 generated/provider_dependencies.json               |  1 +
 .../airflow/providers/common/sql/operators/sql.py  | 16 +++++-
 .../standard}/operators/generic_transfer.py        | 35 +++++++++++--
 .../src/airflow/providers/standard/provider.yaml   |  2 +
 providers/tests/common/sql/operators/test_sql.py   | 21 +++++++-
 .../standard}/operators/test_generic_transfer.py   | 60 +++++++++++++++++++++-
 tests_common/test_utils/compat.py                  |  2 +
 8 files changed, 131 insertions(+), 11 deletions(-)

diff --git a/airflow/hooks/base.py b/airflow/hooks/base.py
index e82c838c8c..8f95d7bfe1 100644
--- a/airflow/hooks/base.py
+++ b/airflow/hooks/base.py
@@ -67,15 +67,16 @@ class BaseHook(LoggingMixin):
         return conn
 
     @classmethod
-    def get_hook(cls, conn_id: str) -> BaseHook:
+    def get_hook(cls, conn_id: str, hook_params: dict | None = None) -> 
BaseHook:
         """
         Return default hook for this connection id.
 
         :param conn_id: connection id
+        :param hook_params: hook parameters
         :return: default hook for this connection
         """
         connection = cls.get_connection(conn_id)
-        return connection.get_hook()
+        return connection.get_hook(hook_params=hook_params)
 
     def get_conn(self) -> Any:
         """Return connection for the hook."""
diff --git a/generated/provider_dependencies.json 
b/generated/provider_dependencies.json
index 2f284cc4de..dc7aedc3b9 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -1279,6 +1279,7 @@
   },
   "standard": {
     "deps": [
+      "apache-airflow-providers-common-sql>=1.18.0",
       "apache-airflow>=2.10.0"
     ],
     "devel-deps": [],
diff --git a/providers/src/airflow/providers/common/sql/operators/sql.py 
b/providers/src/airflow/providers/common/sql/operators/sql.py
index fa1539e725..dae389be02 100644
--- a/providers/src/airflow/providers/common/sql/operators/sql.py
+++ b/providers/src/airflow/providers/common/sql/operators/sql.py
@@ -145,13 +145,25 @@ class BaseSQLOperator(BaseOperator):
         self.hook_params = hook_params or {}
         self.retry_on_failure = retry_on_failure
 
+    @classmethod
+    # TODO: can be removed once Airflow min version for this provider is 3.0.0 
or higher
+    def get_hook(cls, conn_id: str, hook_params: dict | None = None) -> 
BaseHook:
+        """
+        Return default hook for this connection id.
+
+        :param conn_id: connection id
+        :param hook_params: hook parameters
+        :return: default hook for this connection
+        """
+        connection = BaseHook.get_connection(conn_id)
+        return connection.get_hook(hook_params=hook_params)
+
     @cached_property
     def _hook(self):
         """Get DB Hook based on connection type."""
         conn_id = getattr(self, self.conn_id_field)
         self.log.debug("Get connection for %s", conn_id)
-        conn = BaseHook.get_connection(conn_id)
-        hook = conn.get_hook(hook_params=self.hook_params)
+        hook = self.get_hook(conn_id=conn_id, hook_params=self.hook_params)
         if not isinstance(hook, DbApiHook):
             raise AirflowException(
                 f"You are trying to use `common-sql` with 
{hook.__class__.__name__},"
diff --git a/airflow/operators/generic_transfer.py 
b/providers/src/airflow/providers/standard/operators/generic_transfer.py
similarity index 76%
rename from airflow/operators/generic_transfer.py
rename to providers/src/airflow/providers/standard/operators/generic_transfer.py
index a808e23997..255b08c54c 100644
--- a/airflow/operators/generic_transfer.py
+++ b/providers/src/airflow/providers/standard/operators/generic_transfer.py
@@ -38,14 +38,21 @@ class GenericTransfer(BaseOperator):
 
     :param sql: SQL query to execute against the source database. (templated)
     :param destination_table: target table. (templated)
-    :param source_conn_id: source connection
-    :param destination_conn_id: destination connection
+    :param source_conn_id: source connection. (templated)
+    :param destination_conn_id: destination connection. (templated)
     :param preoperator: sql statement or list of statements to be
         executed prior to loading the data. (templated)
     :param insert_args: extra params for `insert_rows` method.
     """
 
-    template_fields: Sequence[str] = ("sql", "destination_table", 
"preoperator")
+    template_fields: Sequence[str] = (
+        "source_conn_id",
+        "destination_conn_id",
+        "sql",
+        "destination_table",
+        "preoperator",
+        "insert_args",
+    )
     template_ext: Sequence[str] = (
         ".sql",
         ".hql",
@@ -59,7 +66,9 @@ class GenericTransfer(BaseOperator):
         sql: str,
         destination_table: str,
         source_conn_id: str,
+        source_hook_params: dict | None = None,
         destination_conn_id: str,
+        destination_hook_params: dict | None = None,
         preoperator: str | list[str] | None = None,
         insert_args: dict | None = None,
         **kwargs,
@@ -68,13 +77,29 @@ class GenericTransfer(BaseOperator):
         self.sql = sql
         self.destination_table = destination_table
         self.source_conn_id = source_conn_id
+        self.source_hook_params = source_hook_params
         self.destination_conn_id = destination_conn_id
+        self.destination_hook_params = destination_hook_params
         self.preoperator = preoperator
         self.insert_args = insert_args or {}
 
+    @classmethod
+    def get_hook(cls, conn_id: str, hook_params: dict | None = None) -> 
BaseHook:
+        """
+        Return default hook for this connection id.
+
+        :param conn_id: connection id
+        :param hook_params: hook parameters
+        :return: default hook for this connection
+        """
+        connection = BaseHook.get_connection(conn_id)
+        return connection.get_hook(hook_params=hook_params)
+
     def execute(self, context: Context):
-        source_hook = BaseHook.get_hook(self.source_conn_id)
-        destination_hook = BaseHook.get_hook(self.destination_conn_id)
+        source_hook = self.get_hook(conn_id=self.source_conn_id, 
hook_params=self.source_hook_params)
+        destination_hook = self.get_hook(
+            conn_id=self.destination_conn_id, 
hook_params=self.destination_hook_params
+        )
 
         self.log.info("Extracting data from %s", self.source_conn_id)
         self.log.info("Executing: \n %s", self.sql)
diff --git a/providers/src/airflow/providers/standard/provider.yaml 
b/providers/src/airflow/providers/standard/provider.yaml
index b3111d62b1..6231e6c3d5 100644
--- a/providers/src/airflow/providers/standard/provider.yaml
+++ b/providers/src/airflow/providers/standard/provider.yaml
@@ -29,6 +29,7 @@ versions:
 
 dependencies:
   - apache-airflow>=2.10.0
+  - apache-airflow-providers-common-sql>=1.18.0
 
 integrations:
   - integration-name: Standard
@@ -43,6 +44,7 @@ operators:
       - airflow.providers.standard.operators.datetime
       - airflow.providers.standard.operators.weekday
       - airflow.providers.standard.operators.bash
+      - airflow.providers.standard.operators.generic_transfer
 
 sensors:
   - integration-name: Standard
diff --git a/providers/tests/common/sql/operators/test_sql.py 
b/providers/tests/common/sql/operators/test_sql.py
index 5026a5ec98..9ff5cebb76 100644
--- a/providers/tests/common/sql/operators/test_sql.py
+++ b/providers/tests/common/sql/operators/test_sql.py
@@ -18,13 +18,14 @@
 from __future__ import annotations
 
 import datetime
+import inspect
 from unittest import mock
 from unittest.mock import MagicMock
 
 import pytest
 
 from airflow import DAG
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, 
AirflowProviderDeprecationWarning
 from airflow.models import Connection, DagRun, TaskInstance as TI, XCom
 from airflow.operators.empty import EmptyOperator
 from airflow.providers.common.sql.hooks.sql import fetch_all_handler
@@ -45,6 +46,7 @@ from airflow.utils.session import create_session
 from airflow.utils.state import State
 
 from tests_common.test_utils.compat import AIRFLOW_V_2_8_PLUS, 
AIRFLOW_V_3_0_PLUS
+from tests_common.test_utils.providers import get_provider_min_airflow_version
 
 if AIRFLOW_V_3_0_PLUS:
     from airflow.utils.types import DagRunTriggeredByType
@@ -91,6 +93,23 @@ class TestBaseSQLOperator:
         assert operator.database == "my_database"
         assert operator.hook_params == {"key": "value"}
 
+    def 
test_when_provider_min_airflow_version_is_3_0_or_higher_remove_obsolete_get_hook_method(self):
+        """
+        Once this test starts failing due to the fact that the minimum Airflow 
version is now 3.0.0 or higher
+        for this provider, you should remove the obsolete get_hook method in 
the BaseSQLOperator operator
+        and remove this test.  This test was added to make sure to not forget 
to remove the fallback code
+        for backward compatibility with Airflow 2.8.x which isn't needed anymore 
once this provider depends on
+        Airflow 3.0.0 or higher.
+        """
+        min_airflow_version = 
get_provider_min_airflow_version("apache-airflow-providers-common-sql")
+
+        # Check if the provider's minimum Airflow version is 3.0.0 or higher
+        if min_airflow_version[0] >= 3:
+            method_source = inspect.getsource(BaseSQLOperator.get_hook)
+            raise AirflowProviderDeprecationWarning(
+                f"Check TODO's to remove obsolete get_hook method in 
BaseSQLOperator:\n\r\n\r\t\t\t{method_source}"
+            )
+
 
 class TestSQLExecuteQueryOperator:
     def _construct_operator(self, sql, **kwargs):
diff --git a/tests/operators/test_generic_transfer.py 
b/providers/tests/standard/operators/test_generic_transfer.py
similarity index 69%
rename from tests/operators/test_generic_transfer.py
rename to providers/tests/standard/operators/test_generic_transfer.py
index e1281ed4b2..4ea08e4889 100644
--- a/tests/operators/test_generic_transfer.py
+++ b/providers/tests/standard/operators/test_generic_transfer.py
@@ -17,16 +17,21 @@
 # under the License.
 from __future__ import annotations
 
+import inspect
 from contextlib import closing
+from datetime import datetime
 from unittest import mock
 
 import pytest
 
+from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.models.dag import DAG
-from airflow.operators.generic_transfer import GenericTransfer
 from airflow.providers.postgres.hooks.postgres import PostgresHook
 from airflow.utils import timezone
 
+from tests_common.test_utils.compat import GenericTransfer
+from tests_common.test_utils.providers import get_provider_min_airflow_version
+
 pytestmark = pytest.mark.db_test
 
 DEFAULT_DATE = timezone.datetime(2015, 1, 1)
@@ -151,3 +156,56 @@ class TestPostgres:
         assert mock_insert.called
         _, kwargs = mock_insert.call_args
         assert "replace" in kwargs
+
+
+class TestGenericTransfer:
+    def test_templated_fields(self):
+        dag = DAG(
+            "test_dag",
+            schedule=None,
+            start_date=datetime(2024, 10, 10),
+            render_template_as_native_obj=True,
+        )
+        operator = GenericTransfer(
+            task_id="test_task",
+            sql="{{ sql }}",
+            destination_table="{{ destination_table }}",
+            source_conn_id="{{ source_conn_id }}",
+            destination_conn_id="{{ destination_conn_id }}",
+            preoperator="{{ preoperator }}",
+            insert_args="{{ insert_args }}",
+            dag=dag,
+        )
+        operator.render_template_fields(
+            {
+                "sql": "my_sql",
+                "destination_table": "my_destination_table",
+                "source_conn_id": "my_source_conn_id",
+                "destination_conn_id": "my_destination_conn_id",
+                "preoperator": "my_preoperator",
+                "insert_args": {"commit_every": 5000, "executemany": True, 
"replace": True},
+            }
+        )
+        assert operator.sql == "my_sql"
+        assert operator.destination_table == "my_destination_table"
+        assert operator.source_conn_id == "my_source_conn_id"
+        assert operator.destination_conn_id == "my_destination_conn_id"
+        assert operator.preoperator == "my_preoperator"
+        assert operator.insert_args == {"commit_every": 5000, "executemany": 
True, "replace": True}
+
+    def 
test_when_provider_min_airflow_version_is_3_0_or_higher_remove_obsolete_method(self):
+        """
+        Once this test starts failing due to the fact that the minimum Airflow 
version is now 3.0.0 or higher
+        for this provider, you should remove the obsolete get_hook method in 
the GenericTransfer and use the
+        one from BaseHook and remove this test.  This test was added to make 
sure to not forget to remove the
+        fallback code for backward compatibility with Airflow 2.8.x which 
isn't needed anymore once this
+        provider depends on Airflow 3.0.0 or higher.
+        """
+        min_airflow_version = 
get_provider_min_airflow_version("apache-airflow-providers-standard")
+
+        # Check if the provider's minimum Airflow version is 3.0.0 or higher
+        if min_airflow_version[0] >= 3:
+            method_source = inspect.getsource(GenericTransfer.get_hook)
+            raise AirflowProviderDeprecationWarning(
+                f"Remove obsolete get_hook method in 
GenericTransfer:\n\r\n\r\t\t\t{method_source}"
+            )
diff --git a/tests_common/test_utils/compat.py 
b/tests_common/test_utils/compat.py
index bc04f798e0..42c566c5d5 100644
--- a/tests_common/test_utils/compat.py
+++ b/tests_common/test_utils/compat.py
@@ -52,11 +52,13 @@ except ImportError:
 
 try:
     from airflow.providers.standard.operators.bash import BashOperator
+    from airflow.providers.standard.operators.generic_transfer import 
GenericTransfer
     from airflow.providers.standard.sensors.bash import BashSensor
     from airflow.providers.standard.sensors.date_time import DateTimeSensor
 except ImportError:
     # Compatibility for Airflow < 2.10.*
     from airflow.operators.bash import BashOperator  # type: 
ignore[no-redef,attr-defined]
+    from airflow.operators.generic_transfer import GenericTransfer  # type: 
ignore[no-redef,attr-defined]
     from airflow.sensors.bash import BashSensor  # type: 
ignore[no-redef,attr-defined]
     from airflow.sensors.date_time import DateTimeSensor  # type: 
ignore[no-redef,attr-defined]
 

Reply via email to