This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9736143468 Add a new parameter to SQL operators to specify conn id field (#30784)
9736143468 is described below

commit 9736143468cfe034e65afb3df3031ab3626f0f6d
Author: Hussein Awala <[email protected]>
AuthorDate: Mon Aug 7 22:00:20 2023 +0200

    Add a new parameter to SQL operators to specify conn id field (#30784)
---
 airflow/providers/common/sql/operators/sql.py      | 11 +++++---
 .../databricks/operators/databricks_sql.py         |  1 +
 airflow/providers/exasol/operators/exasol.py       |  4 ++-
 .../providers/google/cloud/operators/bigquery.py   | 11 ++++++++
 airflow/providers/qubole/operators/qubole_check.py | 11 ++++++--
 airflow/providers/snowflake/operators/snowflake.py | 12 +++++++-
 tests/providers/common/sql/operators/test_sql.py   | 32 ++++++++++++++++++++++
 7 files changed, 74 insertions(+), 8 deletions(-)

diff --git a/airflow/providers/common/sql/operators/sql.py b/airflow/providers/common/sql/operators/sql.py
index 810207d1b7..709c831b5a 100644
--- a/airflow/providers/common/sql/operators/sql.py
+++ b/airflow/providers/common/sql/operators/sql.py
@@ -123,6 +123,8 @@ class BaseSQLOperator(BaseOperator):
     :param conn_id: reference to a specific database
     """
 
+    conn_id_field = "conn_id"
+
     def __init__(
         self,
         *,
@@ -141,8 +143,9 @@ class BaseSQLOperator(BaseOperator):
     @cached_property
     def _hook(self):
         """Get DB Hook based on connection type."""
-        self.log.debug("Get connection for %s", self.conn_id)
-        conn = BaseHook.get_connection(self.conn_id)
+        conn_id = getattr(self, self.conn_id_field)
+        self.log.debug("Get connection for %s", conn_id)
+        conn = BaseHook.get_connection(conn_id)
         hook = conn.get_hook(hook_params=self.hook_params)
         if not isinstance(hook, DbApiHook):
             from airflow.hooks.dbapi_hook import DbApiHook as _DbApiHook
@@ -411,7 +414,7 @@ class SQLColumnCheckOperator(BaseSQLOperator):
         :ref:`howto/operator:SQLColumnCheckOperator`
     """
 
-    template_fields = ("partition_clause", "table", "sql")
+    template_fields: Sequence[str] = ("partition_clause", "table", "sql")
     template_fields_renderers = {"sql": "sql"}
 
     sql_check_template = """
@@ -639,7 +642,7 @@ class SQLTableCheckOperator(BaseSQLOperator):
         :ref:`howto/operator:SQLTableCheckOperator`
     """
 
-    template_fields = ("partition_clause", "table", "sql", "conn_id")
+    template_fields: Sequence[str] = ("partition_clause", "table", "sql", "conn_id")
 
     template_fields_renderers = {"sql": "sql"}
 
diff --git a/airflow/providers/databricks/operators/databricks_sql.py b/airflow/providers/databricks/operators/databricks_sql.py
index 7d9fbb2885..2561b380fa 100644
--- a/airflow/providers/databricks/operators/databricks_sql.py
+++ b/airflow/providers/databricks/operators/databricks_sql.py
@@ -77,6 +77,7 @@ class DatabricksSqlOperator(SQLExecuteQueryOperator):
 
     template_ext: Sequence[str] = (".sql",)
     template_fields_renderers = {"sql": "sql"}
+    conn_id_field = "databricks_conn_id"
 
     def __init__(
         self,
diff --git a/airflow/providers/exasol/operators/exasol.py b/airflow/providers/exasol/operators/exasol.py
index ecd05442d6..407fdf6591 100644
--- a/airflow/providers/exasol/operators/exasol.py
+++ b/airflow/providers/exasol/operators/exasol.py
@@ -38,10 +38,11 @@ class ExasolOperator(SQLExecuteQueryOperator):
     :param handler: (optional) handler to process the results of the query
     """
 
-    template_fields: Sequence[str] = ("sql",)
+    template_fields: Sequence[str] = ("sql", "exasol_conn_id")
     template_ext: Sequence[str] = (".sql",)
     template_fields_renderers = {"sql": "sql"}
     ui_color = "#ededed"
+    conn_id_field = "exasol_conn_id"
 
     def __init__(
         self,
@@ -51,6 +52,7 @@ class ExasolOperator(SQLExecuteQueryOperator):
         handler=exasol_fetch_all_handler,
         **kwargs,
     ) -> None:
+        self.exasol_conn_id = exasol_conn_id
         if schema is not None:
             hook_params = kwargs.pop("hook_params", {})
             kwargs["hook_params"] = {"schema": schema, **hook_params}
diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py
index ca6f290004..d665414153 100644
--- a/airflow/providers/google/cloud/operators/bigquery.py
+++ b/airflow/providers/google/cloud/operators/bigquery.py
@@ -253,6 +253,7 @@ class BigQueryCheckOperator(_BigQueryDbHookMixin, SQLCheckOperator):
     )
     template_ext: Sequence[str] = (".sql",)
     ui_color = BigQueryUIColors.CHECK.value
+    conn_id_field = "gcp_conn_id"
 
     def __init__(
         self,
@@ -371,6 +372,7 @@ class BigQueryValueCheckOperator(_BigQueryDbHookMixin, SQLValueCheckOperator):
     )
     template_ext: Sequence[str] = (".sql",)
     ui_color = BigQueryUIColors.CHECK.value
+    conn_id_field = "gcp_conn_id"
 
     def __init__(
         self,
@@ -509,6 +511,7 @@ class BigQueryIntervalCheckOperator(_BigQueryDbHookMixin, SQLIntervalCheckOperat
         "labels",
     )
     ui_color = BigQueryUIColors.CHECK.value
+    conn_id_field = "gcp_conn_id"
 
     def __init__(
         self,
@@ -634,6 +637,10 @@ class BigQueryColumnCheckOperator(_BigQueryDbHookMixin, SQLColumnCheckOperator):
     :param labels: a dictionary containing labels for the table, passed to BigQuery
     """
 
+    template_fields: Sequence[str] = tuple(set(SQLColumnCheckOperator.template_fields) | {"gcp_conn_id"})
+
+    conn_id_field = "gcp_conn_id"
+
     def __init__(
         self,
         *,
@@ -757,6 +764,10 @@ class BigQueryTableCheckOperator(_BigQueryDbHookMixin, SQLTableCheckOperator):
     :param labels: a dictionary containing labels for the table, passed to BigQuery
     """
 
+    template_fields: Sequence[str] = tuple(set(SQLTableCheckOperator.template_fields) | {"gcp_conn_id"})
+
+    conn_id_field = "gcp_conn_id"
+
     def __init__(
         self,
         *,
diff --git a/airflow/providers/qubole/operators/qubole_check.py b/airflow/providers/qubole/operators/qubole_check.py
index 18b2667bda..8ecbfc6a57 100644
--- a/airflow/providers/qubole/operators/qubole_check.py
+++ b/airflow/providers/qubole/operators/qubole_check.py
@@ -103,8 +103,10 @@ class QuboleCheckOperator(_QuboleCheckOperatorMixin, SQLCheckOperator, QuboleOpe
 
     """
 
+    conn_id_field = "qubole_conn_id"
+
     template_fields: Sequence[str] = tuple(
-        set(QuboleOperator.template_fields) | set(SQLCheckOperator.template_fields)
+        set(QuboleOperator.template_fields) | set(SQLCheckOperator.template_fields) | {"qubole_conn_id"}
     )
     template_ext = QuboleOperator.template_ext
     ui_fgcolor = "#000"
@@ -123,6 +125,7 @@ class QuboleCheckOperator(_QuboleCheckOperatorMixin, SQLCheckOperator, QuboleOpe
         self.on_failure_callback = QuboleCheckHook.handle_failure_retry
         self.on_retry_callback = QuboleCheckHook.handle_failure_retry
         self._hook_context = None
+        self.qubole_conn_id = qubole_conn_id
 
 
 # TODO(xinbinhuang): refactor to reduce levels of inheritance
@@ -155,9 +158,12 @@ class QuboleValueCheckOperator(_QuboleCheckOperatorMixin, SQLValueCheckOperator,
             QuboleOperator and SQLValueCheckOperator are template-supported.
     """
 
-    template_fields = tuple(set(QuboleOperator.template_fields) | set(SQLValueCheckOperator.template_fields))
+    template_fields = tuple(
+        set(QuboleOperator.template_fields) | set(SQLValueCheckOperator.template_fields) | {"qubole_conn_id"}
+    )
     template_ext = QuboleOperator.template_ext
     ui_fgcolor = "#000"
+    conn_id_field = "qubole_conn_id"
 
     def __init__(
         self,
@@ -177,6 +183,7 @@ class QuboleValueCheckOperator(_QuboleCheckOperatorMixin, SQLValueCheckOperator,
         self.on_failure_callback = QuboleCheckHook.handle_failure_retry
         self.on_retry_callback = QuboleCheckHook.handle_failure_retry
         self._hook_context = None
+        self.qubole_conn_id = qubole_conn_id
 
 
 def get_sql_from_qbol_cmd(params) -> str:
diff --git a/airflow/providers/snowflake/operators/snowflake.py b/airflow/providers/snowflake/operators/snowflake.py
index 8f29eefd5b..090f9cf384 100644
--- a/airflow/providers/snowflake/operators/snowflake.py
+++ b/airflow/providers/snowflake/operators/snowflake.py
@@ -192,9 +192,10 @@ class SnowflakeCheckOperator(SQLCheckOperator):
         the time you connect to Snowflake
     """
 
-    template_fields: Sequence[str] = ("sql",)
+    template_fields: Sequence[str] = tuple(set(SQLCheckOperator.template_fields) | {"snowflake_conn_id"})
     template_ext: Sequence[str] = (".sql",)
     ui_color = "#ededed"
+    conn_id_field = "snowflake_conn_id"
 
     def __init__(
         self,
@@ -259,6 +260,10 @@ class SnowflakeValueCheckOperator(SQLValueCheckOperator):
         the time you connect to Snowflake
     """
 
+    template_fields: Sequence[str] = tuple(set(SQLValueCheckOperator.template_fields) | {"snowflake_conn_id"})
+
+    conn_id_field = "snowflake_conn_id"
+
     def __init__(
         self,
         *,
@@ -333,6 +338,11 @@ class SnowflakeIntervalCheckOperator(SQLIntervalCheckOperator):
         the time you connect to Snowflake
     """
 
+    template_fields: Sequence[str] = tuple(
+        set(SQLIntervalCheckOperator.template_fields) | {"snowflake_conn_id"}
+    )
+    conn_id_field = "snowflake_conn_id"
+
     def __init__(
         self,
         *,
diff --git a/tests/providers/common/sql/operators/test_sql.py b/tests/providers/common/sql/operators/test_sql.py
index 86608e7440..e80bad08ed 100644
--- a/tests/providers/common/sql/operators/test_sql.py
+++ b/tests/providers/common/sql/operators/test_sql.py
@@ -1278,3 +1278,35 @@ class TestSqlBranch:
                     assert ti.state == State.NONE
                 else:
                     raise ValueError(f"Invalid task id {ti.task_id} found!")
+
+
+class TestBaseSQLOperatorSubClass:
+
+    from airflow.providers.common.sql.operators.sql import BaseSQLOperator
+
+    class NewStyleBaseSQLOperatorSubClass(BaseSQLOperator):
+        """New style subclass of BaseSQLOperator"""
+
+        conn_id_field = "custom_conn_id_field"
+
+        def __init__(self, custom_conn_id_field="test_conn", **kwargs):
+            super().__init__(**kwargs)
+            self.custom_conn_id_field = custom_conn_id_field
+
+    class OldStyleBaseSQLOperatorSubClass(BaseSQLOperator):
+        """Old style subclass of BaseSQLOperator"""
+
+        def __init__(self, custom_conn_id_field="test_conn", **kwargs):
+            super().__init__(conn_id=custom_conn_id_field, **kwargs)
+
+    @pytest.mark.parametrize(
+        "operator_class", [NewStyleBaseSQLOperatorSubClass, OldStyleBaseSQLOperatorSubClass]
+    )
+    @mock.patch("airflow.hooks.base.BaseHook.get_connection")
+    def test_new_style_subclass(self, mock_get_connection, operator_class):
+        from airflow.providers.common.sql.hooks.sql import DbApiHook
+
+        op = operator_class(task_id="test_task")
+        mock_get_connection.return_value.get_hook.return_value = MagicMock(spec=DbApiHook)
+        op.get_db_hook()
+        mock_get_connection.assert_called_once_with("test_conn")

Reply via email to