This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 9736143468 Add a new parameter to SQL operators to specify conn id field (#30784)
9736143468 is described below
commit 9736143468cfe034e65afb3df3031ab3626f0f6d
Author: Hussein Awala <[email protected]>
AuthorDate: Mon Aug 7 22:00:20 2023 +0200
Add a new parameter to SQL operators to specify conn id field (#30784)
---
airflow/providers/common/sql/operators/sql.py | 11 +++++---
.../databricks/operators/databricks_sql.py | 1 +
airflow/providers/exasol/operators/exasol.py | 4 ++-
.../providers/google/cloud/operators/bigquery.py | 11 ++++++++
airflow/providers/qubole/operators/qubole_check.py | 11 ++++++--
airflow/providers/snowflake/operators/snowflake.py | 12 +++++++-
tests/providers/common/sql/operators/test_sql.py | 32 ++++++++++++++++++++++
7 files changed, 74 insertions(+), 8 deletions(-)
diff --git a/airflow/providers/common/sql/operators/sql.py b/airflow/providers/common/sql/operators/sql.py
index 810207d1b7..709c831b5a 100644
--- a/airflow/providers/common/sql/operators/sql.py
+++ b/airflow/providers/common/sql/operators/sql.py
@@ -123,6 +123,8 @@ class BaseSQLOperator(BaseOperator):
:param conn_id: reference to a specific database
"""
+ conn_id_field = "conn_id"
+
def __init__(
self,
*,
@@ -141,8 +143,9 @@ class BaseSQLOperator(BaseOperator):
@cached_property
def _hook(self):
"""Get DB Hook based on connection type."""
- self.log.debug("Get connection for %s", self.conn_id)
- conn = BaseHook.get_connection(self.conn_id)
+ conn_id = getattr(self, self.conn_id_field)
+ self.log.debug("Get connection for %s", conn_id)
+ conn = BaseHook.get_connection(conn_id)
hook = conn.get_hook(hook_params=self.hook_params)
if not isinstance(hook, DbApiHook):
from airflow.hooks.dbapi_hook import DbApiHook as _DbApiHook
@@ -411,7 +414,7 @@ class SQLColumnCheckOperator(BaseSQLOperator):
:ref:`howto/operator:SQLColumnCheckOperator`
"""
- template_fields = ("partition_clause", "table", "sql")
+ template_fields: Sequence[str] = ("partition_clause", "table", "sql")
template_fields_renderers = {"sql": "sql"}
sql_check_template = """
@@ -639,7 +642,7 @@ class SQLTableCheckOperator(BaseSQLOperator):
:ref:`howto/operator:SQLTableCheckOperator`
"""
- template_fields = ("partition_clause", "table", "sql", "conn_id")
+ template_fields: Sequence[str] = ("partition_clause", "table", "sql",
"conn_id")
template_fields_renderers = {"sql": "sql"}
diff --git a/airflow/providers/databricks/operators/databricks_sql.py b/airflow/providers/databricks/operators/databricks_sql.py
index 7d9fbb2885..2561b380fa 100644
--- a/airflow/providers/databricks/operators/databricks_sql.py
+++ b/airflow/providers/databricks/operators/databricks_sql.py
@@ -77,6 +77,7 @@ class DatabricksSqlOperator(SQLExecuteQueryOperator):
template_ext: Sequence[str] = (".sql",)
template_fields_renderers = {"sql": "sql"}
+ conn_id_field = "databricks_conn_id"
def __init__(
self,
diff --git a/airflow/providers/exasol/operators/exasol.py b/airflow/providers/exasol/operators/exasol.py
index ecd05442d6..407fdf6591 100644
--- a/airflow/providers/exasol/operators/exasol.py
+++ b/airflow/providers/exasol/operators/exasol.py
@@ -38,10 +38,11 @@ class ExasolOperator(SQLExecuteQueryOperator):
:param handler: (optional) handler to process the results of the query
"""
- template_fields: Sequence[str] = ("sql",)
+ template_fields: Sequence[str] = ("sql", "exasol_conn_id")
template_ext: Sequence[str] = (".sql",)
template_fields_renderers = {"sql": "sql"}
ui_color = "#ededed"
+ conn_id_field = "exasol_conn_id"
def __init__(
self,
@@ -51,6 +52,7 @@ class ExasolOperator(SQLExecuteQueryOperator):
handler=exasol_fetch_all_handler,
**kwargs,
) -> None:
+ self.exasol_conn_id = exasol_conn_id
if schema is not None:
hook_params = kwargs.pop("hook_params", {})
kwargs["hook_params"] = {"schema": schema, **hook_params}
diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py
index ca6f290004..d665414153 100644
--- a/airflow/providers/google/cloud/operators/bigquery.py
+++ b/airflow/providers/google/cloud/operators/bigquery.py
@@ -253,6 +253,7 @@ class BigQueryCheckOperator(_BigQueryDbHookMixin, SQLCheckOperator):
)
template_ext: Sequence[str] = (".sql",)
ui_color = BigQueryUIColors.CHECK.value
+ conn_id_field = "gcp_conn_id"
def __init__(
self,
@@ -371,6 +372,7 @@ class BigQueryValueCheckOperator(_BigQueryDbHookMixin, SQLValueCheckOperator):
)
template_ext: Sequence[str] = (".sql",)
ui_color = BigQueryUIColors.CHECK.value
+ conn_id_field = "gcp_conn_id"
def __init__(
self,
@@ -509,6 +511,7 @@ class BigQueryIntervalCheckOperator(_BigQueryDbHookMixin, SQLIntervalCheckOperat
"labels",
)
ui_color = BigQueryUIColors.CHECK.value
+ conn_id_field = "gcp_conn_id"
def __init__(
self,
@@ -634,6 +637,10 @@ class BigQueryColumnCheckOperator(_BigQueryDbHookMixin, SQLColumnCheckOperator):
:param labels: a dictionary containing labels for the table, passed to
BigQuery
"""
+ template_fields: Sequence[str] =
tuple(set(SQLColumnCheckOperator.template_fields) | {"gcp_conn_id"})
+
+ conn_id_field = "gcp_conn_id"
+
def __init__(
self,
*,
@@ -757,6 +764,10 @@ class BigQueryTableCheckOperator(_BigQueryDbHookMixin, SQLTableCheckOperator):
:param labels: a dictionary containing labels for the table, passed to
BigQuery
"""
+ template_fields: Sequence[str] =
tuple(set(SQLTableCheckOperator.template_fields) | {"gcp_conn_id"})
+
+ conn_id_field = "gcp_conn_id"
+
def __init__(
self,
*,
diff --git a/airflow/providers/qubole/operators/qubole_check.py b/airflow/providers/qubole/operators/qubole_check.py
index 18b2667bda..8ecbfc6a57 100644
--- a/airflow/providers/qubole/operators/qubole_check.py
+++ b/airflow/providers/qubole/operators/qubole_check.py
@@ -103,8 +103,10 @@ class QuboleCheckOperator(_QuboleCheckOperatorMixin, SQLCheckOperator, QuboleOpe
"""
+ conn_id_field = "qubole_conn_id"
+
template_fields: Sequence[str] = tuple(
- set(QuboleOperator.template_fields) |
set(SQLCheckOperator.template_fields)
+ set(QuboleOperator.template_fields) |
set(SQLCheckOperator.template_fields) | {"qubole_conn_id"}
)
template_ext = QuboleOperator.template_ext
ui_fgcolor = "#000"
@@ -123,6 +125,7 @@ class QuboleCheckOperator(_QuboleCheckOperatorMixin, SQLCheckOperator, QuboleOpe
self.on_failure_callback = QuboleCheckHook.handle_failure_retry
self.on_retry_callback = QuboleCheckHook.handle_failure_retry
self._hook_context = None
+ self.qubole_conn_id = qubole_conn_id
# TODO(xinbinhuang): refactor to reduce levels of inheritance
@@ -155,9 +158,12 @@ class QuboleValueCheckOperator(_QuboleCheckOperatorMixin, SQLValueCheckOperator,
QuboleOperator and SQLValueCheckOperator are template-supported.
"""
- template_fields = tuple(set(QuboleOperator.template_fields) |
set(SQLValueCheckOperator.template_fields))
+ template_fields = tuple(
+ set(QuboleOperator.template_fields) |
set(SQLValueCheckOperator.template_fields) | {"qubole_conn_id"}
+ )
template_ext = QuboleOperator.template_ext
ui_fgcolor = "#000"
+ conn_id_field = "qubole_conn_id"
def __init__(
self,
@@ -177,6 +183,7 @@ class QuboleValueCheckOperator(_QuboleCheckOperatorMixin, SQLValueCheckOperator,
self.on_failure_callback = QuboleCheckHook.handle_failure_retry
self.on_retry_callback = QuboleCheckHook.handle_failure_retry
self._hook_context = None
+ self.qubole_conn_id = qubole_conn_id
def get_sql_from_qbol_cmd(params) -> str:
diff --git a/airflow/providers/snowflake/operators/snowflake.py b/airflow/providers/snowflake/operators/snowflake.py
index 8f29eefd5b..090f9cf384 100644
--- a/airflow/providers/snowflake/operators/snowflake.py
+++ b/airflow/providers/snowflake/operators/snowflake.py
@@ -192,9 +192,10 @@ class SnowflakeCheckOperator(SQLCheckOperator):
the time you connect to Snowflake
"""
- template_fields: Sequence[str] = ("sql",)
+ template_fields: Sequence[str] =
tuple(set(SQLCheckOperator.template_fields) | {"snowflake_conn_id"})
template_ext: Sequence[str] = (".sql",)
ui_color = "#ededed"
+ conn_id_field = "snowflake_conn_id"
def __init__(
self,
@@ -259,6 +260,10 @@ class SnowflakeValueCheckOperator(SQLValueCheckOperator):
the time you connect to Snowflake
"""
+ template_fields: Sequence[str] =
tuple(set(SQLValueCheckOperator.template_fields) | {"snowflake_conn_id"})
+
+ conn_id_field = "snowflake_conn_id"
+
def __init__(
self,
*,
@@ -333,6 +338,11 @@ class SnowflakeIntervalCheckOperator(SQLIntervalCheckOperator):
the time you connect to Snowflake
"""
+ template_fields: Sequence[str] = tuple(
+ set(SQLIntervalCheckOperator.template_fields) | {"snowflake_conn_id"}
+ )
+ conn_id_field = "snowflake_conn_id"
+
def __init__(
self,
*,
diff --git a/tests/providers/common/sql/operators/test_sql.py b/tests/providers/common/sql/operators/test_sql.py
index 86608e7440..e80bad08ed 100644
--- a/tests/providers/common/sql/operators/test_sql.py
+++ b/tests/providers/common/sql/operators/test_sql.py
@@ -1278,3 +1278,35 @@ class TestSqlBranch:
assert ti.state == State.NONE
else:
raise ValueError(f"Invalid task id {ti.task_id} found!")
+
+
+class TestBaseSQLOperatorSubClass:
+
+ from airflow.providers.common.sql.operators.sql import BaseSQLOperator
+
+ class NewStyleBaseSQLOperatorSubClass(BaseSQLOperator):
+ """New style subclass of BaseSQLOperator"""
+
+ conn_id_field = "custom_conn_id_field"
+
+ def __init__(self, custom_conn_id_field="test_conn", **kwargs):
+ super().__init__(**kwargs)
+ self.custom_conn_id_field = custom_conn_id_field
+
+ class OldStyleBaseSQLOperatorSubClass(BaseSQLOperator):
+ """Old style subclass of BaseSQLOperator"""
+
+ def __init__(self, custom_conn_id_field="test_conn", **kwargs):
+ super().__init__(conn_id=custom_conn_id_field, **kwargs)
+
+ @pytest.mark.parametrize(
+ "operator_class", [NewStyleBaseSQLOperatorSubClass,
OldStyleBaseSQLOperatorSubClass]
+ )
+ @mock.patch("airflow.hooks.base.BaseHook.get_connection")
+ def test_new_style_subclass(self, mock_get_connection, operator_class):
+ from airflow.providers.common.sql.hooks.sql import DbApiHook
+
+ op = operator_class(task_id="test_task")
+ mock_get_connection.return_value.get_hook.return_value =
MagicMock(spec=DbApiHook)
+ op.get_db_hook()
+ mock_get_connection.assert_called_once_with("test_conn")