This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 7acd9cb5350 Add session-level query tags to Databricks SQL operators (#66895)
7acd9cb5350 is described below

commit 7acd9cb53505e0fde250a487524bbfb3793036ba
Author: Nguyễn Ngọc Thành <[email protected]>
AuthorDate: Tue May 26 13:34:29 2026 +0700

    Add session-level query tags to Databricks SQL operators (#66895)
---
 .../providers/databricks/hooks/databricks_sql.py   |  64 +++++++-
 .../databricks/operators/databricks_sql.py         |  59 +++++++-
 .../databricks/sensors/databricks_partition.py     |   2 +-
 .../providers/databricks/sensors/databricks_sql.py |   2 +-
 .../unit/databricks/hooks/test_databricks_sql.py   | 145 +++++++++++++++++-
 .../databricks/operators/test_databricks_sql.py    | 165 ++++++++++++++++++++-
 6 files changed, 425 insertions(+), 12 deletions(-)

diff --git a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py
index 021142395b2..ffc089607c3 100644
--- a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py
+++ b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+import logging
 import threading
 from collections import namedtuple
 from collections.abc import Callable, Iterable, Mapping, Sequence
@@ -51,6 +52,8 @@ if TYPE_CHECKING:
 
 T = TypeVar("T")

+# Module logger; used by _format_query_tag_value to warn when a tag value is truncated.
+log = logging.getLogger(__name__)
+
 
 def create_timeout_thread(
     cur, execution_timeout: timedelta | None
@@ -71,6 +74,35 @@ def create_timeout_thread(
     return timer, timeout_event
 
 
+def _format_query_tag_value(value: str) -> str:
+    """
+    Escape special characters and truncate a single query tag value.
+
+    Databricks ``QUERY_TAGS`` uses ``key:value`` pairs delimited by commas, so
+    backslash, comma and colon inside *values* must be escaped.  Values are also
+    capped at 128 characters before escaping to keep the overall tag string
+    within reasonable bounds.  Because escaping happens after truncation, the
+    returned string can be longer than 128 characters.
+
+    :param value: Tag value; converted with ``str()`` before processing.
+    :return: The truncated and escaped value.
+    """
+    raw = str(value)
+    if len(raw) > 128:
+        log.warning(
+            "Query tag value truncated to 128 characters (original length %d): %r", len(raw), raw[:128]
+        )
+    value = raw[:128]
+    # Backslash is escaped first; otherwise the backslashes added for "," and ":"
+    # would themselves be doubled.
+    return value.replace("\\", "\\\\").replace(",", "\\,").replace(":", "\\:")
+
+
+def _format_query_tags(tags: dict[str, str | None]) -> str:
+    """
+    Serialize a query-tags dict to the ``key:value,key:value`` string expected by ``QUERY_TAGS``.
+
+    Entries whose value is ``None`` are omitted.  Values are escaped and truncated by
+    ``_format_query_tag_value``.  Keys are escaped the same way (backslash, comma,
+    colon) but not truncated, so a key that contains a delimiter cannot split or
+    corrupt neighbouring pairs.
+
+    :param tags: Mapping of tag name to tag value (``None`` values are skipped).
+    :return: The serialized tag string; ``""`` when no entry survives.
+    """
+
+    def _escape_key(key: str) -> str:
+        # Backslash first so the escapes added for "," and ":" are not doubled.
+        return str(key).replace("\\", "\\\\").replace(",", "\\,").replace(":", "\\:")
+
+    return ",".join(
+        f"{_escape_key(key)}:{_format_query_tag_value(value)}"
+        for key, value in tags.items()
+        if value is not None
+    )
+
+
 class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
     """
     Hook to interact with Databricks SQL.
@@ -88,6 +120,10 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
         on every request
     :param catalog: An optional initial catalog to use. Requires DBR version 
9.0+
     :param schema: An optional initial schema to use. Requires DBR version 9.0+
+    :param query_tags: An optional dict of query tags to attach to every SQL 
statement executed by
+        this hook.  Tags are injected via the ``QUERY_TAGS`` Databricks 
session parameter so they
+        appear in ``system.query.history``.  Any existing ``QUERY_TAGS`` 
already present in
+        *session_configuration* are preserved and the new tags are appended.
     :param kwargs: Additional parameters internal to Databricks SQL Connector 
parameters
     """
 
@@ -104,6 +140,7 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
         catalog: str | None = None,
         schema: str | None = None,
         caller: str = "DatabricksSqlHook",
+        query_tags: dict[str, str | None] | None = None,
         **kwargs,
     ) -> None:
         super().__init__(databricks_conn_id, caller=caller)
@@ -118,6 +155,7 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
         self.schema = schema
         self.additional_params = kwargs
         self.query_ids: list[str] = []
+        self.query_tags = query_tags
 
     def _get_extra_config(self) -> dict[str, Any | None]:
         extra_params = copy(self.databricks_conn.extra_dejson)
@@ -169,20 +207,32 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
         if not self.session_config:
             self.session_config = 
self.databricks_conn.extra_dejson.get("session_configuration")
 
+        # session_configuration (including QUERY_TAGS) is applied only when 
opening a new
+        # connection; changing query_tags after the first get_conn() call has 
no effect.
         if not self._sql_conn or prev_token != new_token:
             if self._sql_conn:  # close already existing connection
                 self._sql_conn.close()
+            session_config: dict[str, str] = dict(self.session_config) if 
self.session_config else {}
+            if self.query_tags:
+                tags_str = _format_query_tags(self.query_tags)
+                existing = session_config.get("QUERY_TAGS", "")
+                session_config["QUERY_TAGS"] = f"{existing},{tags_str}" if 
existing else tags_str
+
+            connect_kwargs = {
+                "schema": self.schema,
+                "catalog": self.catalog,
+                "session_configuration": session_config or None,
+                "http_headers": self.http_headers,
+                "_user_agent_entry": self.user_agent_value,
+                **self._get_extra_config(),
+                **self.additional_params,
+            }
+
             self._sql_conn = sql.connect(
                 self.host,
                 self._http_path,
                 self._token,
-                schema=self.schema,
-                catalog=self.catalog,
-                session_configuration=self.session_config,
-                http_headers=self.http_headers,
-                _user_agent_entry=self.user_agent_value,
-                **self._get_extra_config(),
-                **self.additional_params,
+                **connect_kwargs,
             )
 
         if self._sql_conn is None:
diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py
index b50c434d04c..f72514f2488 100644
--- a/providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py
+++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py
@@ -46,6 +46,24 @@ _IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
 _DISALLOWED_SQL_TOKENS = (";", "--", "/*", "*/")
 
 
+def _get_airflow_query_tags(context: Context) -> dict[str, str | None]:
+    """
+    Build query tags describing the running Airflow task instance.
+
+    Reads ``dag_id``, ``task_id``, ``run_id``, ``try_number`` and ``map_index`` from the
+    ``ti`` entry of *context* and returns them under ``airflow_``-prefixed keys, each
+    converted with ``str()``; attributes that are ``None`` stay ``None``.
+    Returns an empty dict when the context carries no task instance.
+    """
+    ti = context.get("ti")
+    if ti is None:
+        return {}
+
+    tags: dict[str, str | None] = {}
+    for attr in ("dag_id", "task_id", "run_id", "try_number", "map_index"):
+        raw = getattr(ti, attr)
+        tags[f"airflow_{attr}"] = None if raw is None else str(raw)
+    return tags
+
+
 class DatabricksSqlOperator(SQLExecuteQueryOperator):
     """
     Executes SQL code in a Databricks SQL endpoint or a Databricks cluster.
@@ -68,6 +86,8 @@ class DatabricksSqlOperator(SQLExecuteQueryOperator):
     :param session_configuration: An optional dictionary of Spark session 
parameters. Defaults to None.
         If not specified, it could be specified in the Databricks connection's 
extra parameters.
     :param client_parameters: Additional parameters internal to Databricks SQL 
Connector parameters
+    :param query_tags: Optional dictionary of query tags to attach to 
Databricks SQL queries.
+    :param include_airflow_query_tags: If True, add Airflow DAG/task/run 
metadata as query tags.
     :param http_headers: An optional list of (k, v) pairs that will be set as 
HTTP headers on every request.
          (templated)
     :param catalog: An optional initial catalog to use. Requires DBR version 
9.0+ (templated)
@@ -93,6 +113,7 @@ class DatabricksSqlOperator(SQLExecuteQueryOperator):
             "http_headers",
             "databricks_conn_id",
             "_gcs_impersonation_chain",
+            "query_tags",
         }
         | set(SQLExecuteQueryOperator.template_fields)
     )
@@ -115,6 +136,8 @@ class DatabricksSqlOperator(SQLExecuteQueryOperator):
         output_format: str = "csv",
         csv_params: dict[str, Any] | None = None,
         client_parameters: dict[str, Any] | None = None,
+        query_tags: dict[str, str | None] | None = None,
+        include_airflow_query_tags: bool = True,
         gcp_conn_id: str = "google_cloud_default",
         gcs_impersonation_chain: str | Sequence[str] | None = None,
         **kwargs,
@@ -132,6 +155,8 @@ class DatabricksSqlOperator(SQLExecuteQueryOperator):
         self.http_headers = http_headers
         self.catalog = catalog
         self.schema = schema
+        self.query_tags = query_tags or {}
+        self.include_airflow_query_tags = include_airflow_query_tags
         self._gcp_conn_id = gcp_conn_id
         self._gcs_impersonation_chain = gcs_impersonation_chain
 
@@ -303,6 +328,20 @@ class DatabricksSqlOperator(SQLExecuteQueryOperator):
 
         return list(zip(descriptions, results))
 
+    def _get_query_tags(self, context: Context) -> dict[str, str | None] | None:
+        """
+        Merge Airflow-derived tags with the user-supplied ``query_tags``.
+
+        Airflow tags are included only when ``include_airflow_query_tags`` is set and a
+        context is given.  User tags are applied last, so they win on key collisions.
+        Returns ``None`` rather than an empty dict when there is nothing to tag.
+        """
+        merged: dict[str, str | None] = dict(
+            _get_airflow_query_tags(context)
+            if self.include_airflow_query_tags and context is not None
+            else {}
+        )
+        merged.update(self.query_tags)
+        return merged or None
+
+    def execute(self, context: Context) -> Any:
+        """Attach this task's query tags to the hook, then run the SQL via ``SQLExecuteQueryOperator``."""
+        tags = self._get_query_tags(context)
+        self.get_db_hook().query_tags = tags
+        return super().execute(context)
+
 
 COPY_INTO_APPROVED_FORMATS = ["CSV", "JSON", "AVRO", "ORC", "PARQUET", "TEXT", "BINARYFILE"]
 
@@ -335,6 +374,8 @@ class DatabricksCopyIntoOperator(BaseOperator):
     :param catalog: An optional initial catalog to use. Requires DBR version 
9.0+
     :param schema: An optional initial schema to use. Requires DBR version 9.0+
     :param client_parameters: Additional parameters internal to Databricks SQL 
Connector parameters
+    :param query_tags: Optional dictionary of query tags to attach to 
Databricks SQL queries.
+    :param include_airflow_query_tags: If True, add Airflow DAG/task/run 
metadata as query tags.
     :param files: optional list of files to import. Can't be specified 
together with ``pattern``. (templated)
     :param pattern: optional regex string to match file names to import.
         Can't be specified together with ``files``.
@@ -355,6 +396,7 @@ class DatabricksCopyIntoOperator(BaseOperator):
         "files",
         "table_name",
         "databricks_conn_id",
+        "query_tags",
     )
 
     def __init__(
@@ -381,9 +423,11 @@ class DatabricksCopyIntoOperator(BaseOperator):
         force_copy: bool | None = None,
         copy_options: dict[str, str] | None = None,
         validate: bool | int | None = None,
+        query_tags: dict[str, str | None] | None = None,
+        include_airflow_query_tags: bool = True,
         **kwargs,
     ) -> None:
-        """Create a new ``DatabricksSqlOperator``."""
+        """Create a new ``DatabricksCopyIntoOperator``."""
         super().__init__(**kwargs)
         if files is not None and pattern is not None:
             raise AirflowException("Only one of 'pattern' or 'files' should be 
specified")
@@ -413,6 +457,8 @@ class DatabricksCopyIntoOperator(BaseOperator):
         self._validate = validate
         self._http_headers = http_headers
         self._client_parameters = client_parameters or {}
+        self.query_tags = query_tags or {}
+        self.include_airflow_query_tags = include_airflow_query_tags
         if force_copy is not None:
             self._copy_options["force"] = "true" if force_copy else "false"
         self._sql: str | None = None
@@ -514,10 +560,21 @@ FILEFORMAT = {self._file_format}
 """
         return sql.strip()
 
+    def _get_query_tags(self, context: Context) -> dict[str, str | None] | None:
+        """
+        Combine Airflow metadata tags with the operator's own ``query_tags``.
+
+        The operator's tags are layered on top, so they override Airflow tags that share
+        a key.  Yields ``None`` when the combined mapping is empty.
+        """
+        merged: dict[str, str | None] = {}
+        wants_airflow_tags = self.include_airflow_query_tags and context is not None
+        if wants_airflow_tags:
+            merged.update(_get_airflow_query_tags(context))
+        merged.update(self.query_tags)
+        return None if not merged else merged
+
     def execute(self, context: Context) -> Any:
+        """Create the COPY INTO SQL, log it, and run it through the Databricks SQL hook with query tags attached."""
         self._sql = self._create_sql_query()
         self.log.info("Executing: %s", self._sql)
         hook = self._get_hook()
+        # QUERY_TAGS are applied through session_configuration when the hook opens its
+        # connection, so the tags must be assigned before hook.run().
+        hook.query_tags = self._get_query_tags(context)
         hook.run(self._sql)
 
     def on_kill(self) -> None:
diff --git a/providers/databricks/src/airflow/providers/databricks/sensors/databricks_partition.py b/providers/databricks/src/airflow/providers/databricks/sensors/databricks_partition.py
index b4501ef1d43..2036dca97c3 100644
--- a/providers/databricks/src/airflow/providers/databricks/sensors/databricks_partition.py
+++ b/providers/databricks/src/airflow/providers/databricks/sensors/databricks_partition.py
@@ -130,7 +130,7 @@ class DatabricksPartitionSensor(BaseSensorOperator):
             self.http_headers,
             self.catalog,
             self.schema,
-            self.caller,
+            caller=self.caller,
             **self.client_parameters,
             **self.hook_params,
         )
diff --git a/providers/databricks/src/airflow/providers/databricks/sensors/databricks_sql.py b/providers/databricks/src/airflow/providers/databricks/sensors/databricks_sql.py
index 68a44a20fec..6ab52df67b5 100644
--- a/providers/databricks/src/airflow/providers/databricks/sensors/databricks_sql.py
+++ b/providers/databricks/src/airflow/providers/databricks/sensors/databricks_sql.py
@@ -109,7 +109,7 @@ class DatabricksSqlSensor(BaseSensorOperator):
             self.http_headers,
             self.catalog,
             self.schema,
-            self.caller,
+            caller=self.caller,
             **self.client_parameters,
             **self.hook_params,
         )
diff --git a/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py b/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py
index f3f053d443c..cd3c00e2839 100644
--- a/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py
+++ b/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py
@@ -32,7 +32,12 @@ from databricks.sql.types import Row
 from airflow.models import Connection
 from airflow.providers.common.compat.sdk import AirflowException, 
AirflowOptionalProviderFeatureException
 from airflow.providers.common.sql.hooks.handlers import fetch_all_handler
-from airflow.providers.databricks.hooks.databricks_sql import 
DatabricksSqlHook, create_timeout_thread
+from airflow.providers.databricks.hooks.databricks_sql import (
+    DatabricksSqlHook,
+    _format_query_tag_value,
+    _format_query_tags,
+    create_timeout_thread,
+)
 
 TASK_ID = "databricks-sql-operator"
 DEFAULT_CONN_ID = "databricks_default"
@@ -792,3 +797,141 @@ class TestGetSqlEndpointByName:
         hook = DatabricksSqlHook(sql_endpoint_name="Test")
         with pytest.raises(RuntimeError, match="Can't list Databricks SQL 
warehouses"):
             hook._get_sql_endpoint_by_name("Test")
+
+
+@mock.patch("airflow.providers.databricks.hooks.databricks_sql.sql.connect")
+def test_get_conn_passes_query_tags_via_session_configuration(mock_connect, mock_get_requests):
+    """query_tags must be injected into session_configuration['QUERY_TAGS'], not sql.connect(query_tags=)."""
+    hook = DatabricksSqlHook(
+        databricks_conn_id=DEFAULT_CONN_ID,
+        http_path=HTTP_PATH,
+        query_tags={"airflow_dag_id": "dag_1", "airflow_task_id": "task_1"},
+    )
+
+    hook.get_conn()
+
+    mock_connect.assert_called_once()
+    session_cfg = mock_connect.call_args.kwargs["session_configuration"]
+    assert session_cfg is not None
+    assert "QUERY_TAGS" in session_cfg
+    query_tags_str = session_cfg["QUERY_TAGS"]
+    assert "airflow_dag_id:dag_1" in query_tags_str
+    assert "airflow_task_id:task_1" in query_tags_str
+    # The tags must travel only through the session configuration, never as a connect() kwarg.
+    assert "query_tags" not in mock_connect.call_args.kwargs
+
+
+@mock.patch("airflow.providers.databricks.hooks.databricks_sql.sql.connect")
+def test_get_conn_merges_query_tags_with_existing_session_configuration(mock_connect, mock_get_requests):
+    """Existing QUERY_TAGS in session_configuration must be preserved and new tags appended."""
+    hook = DatabricksSqlHook(
+        databricks_conn_id=DEFAULT_CONN_ID,
+        http_path=HTTP_PATH,
+        session_configuration={"QUERY_TAGS": "existing_tag:existing_value"},
+        query_tags={"airflow_dag_id": "dag_1"},
+    )
+
+    hook.get_conn()
+
+    mock_connect.assert_called_once()
+    session_cfg = mock_connect.call_args.kwargs["session_configuration"]
+    # Pin the exact string: existing tags first, new tags appended after a comma.
+    assert session_cfg["QUERY_TAGS"] == "existing_tag:existing_value,airflow_dag_id:dag_1"
+
+
+@mock.patch("airflow.providers.databricks.hooks.databricks_sql.sql.connect")
+def test_get_conn_no_query_tags(mock_connect, mock_get_requests):
+    """When no query_tags are provided, session_configuration should not gain a QUERY_TAGS key."""
+    hook = DatabricksSqlHook(
+        databricks_conn_id=DEFAULT_CONN_ID,
+        http_path=HTTP_PATH,
+    )
+
+    hook.get_conn()
+
+    mock_connect.assert_called_once()
+    session_cfg = mock_connect.call_args.kwargs.get("session_configuration")
+    assert session_cfg is None or "QUERY_TAGS" not in session_cfg
+
+
+class TestFormatQueryTags:
+    """Unit tests for the QUERY_TAGS serialization helpers."""
+
+    def test_simple_values(self):
+        serialized = _format_query_tags({"dag_id": "my_dag", "task_id": "my_task"})
+        for expected in ("dag_id:my_dag", "task_id:my_task"):
+            assert expected in serialized
+
+    def test_none_values_omitted(self):
+        serialized = _format_query_tags({"dag_id": "my_dag", "map_index": None})
+        assert "dag_id:my_dag" in serialized
+        assert "map_index" not in serialized
+
+    def test_empty_dict_returns_empty_string(self):
+        assert _format_query_tags({}) == ""
+
+    def test_value_escaping_comma(self):
+        assert _format_query_tag_value("a,b") == r"a\,b"
+
+    def test_value_escaping_colon(self):
+        assert _format_query_tag_value("a:b") == r"a\:b"
+
+    def test_value_escaping_backslash(self):
+        assert _format_query_tag_value(r"a\b") == r"a\\b"
+
+    def test_value_truncated_at_128_chars(self):
+        assert len(_format_query_tag_value("x" * 200)) == 128
+
+    def test_format_query_tags_roundtrip(self):
+        serialized = _format_query_tags({"airflow_dag_id": "dag:1", "airflow_run_id": "run,2"})
+        assert r"airflow_dag_id:dag\:1" in serialized
+        assert r"airflow_run_id:run\,2" in serialized
+
+
+class TestDatabricksSqlHookQueryTagsParamOrder:
+    """Ensure moving query_tags after caller preserves positional backward compatibility."""
+
+    # BaseDatabricksHook.__init__ is patched out so only DatabricksSqlHook.__init__'s own logic runs.
+    _BASE_INIT_TARGET = "airflow.providers.databricks.hooks.databricks_sql.BaseDatabricksHook.__init__"
+
+    def test_query_tags_keyword_sets_field(self):
+        """query_tags kwarg must be stored on the instance."""
+        with patch(self._BASE_INIT_TARGET, return_value=None) as mock_base_init:
+            hook = DatabricksSqlHook.__new__(DatabricksSqlHook)
+            DatabricksSqlHook.__init__(
+                hook,
+                DEFAULT_CONN_ID,
+                query_tags={"key": "val"},
+            )
+            assert hook.query_tags == {"key": "val"}
+            # caller is forwarded to BaseDatabricksHook.__init__; verify the default was passed
+            assert mock_base_init.call_args.kwargs.get("caller") == "DatabricksSqlHook"
+
+    def test_caller_positional_not_confused_with_query_tags(self):
+        """Passing caller as the 8th positional arg must not end up in query_tags."""
+        with patch(self._BASE_INIT_TARGET, return_value=None) as mock_base_init:
+            hook = DatabricksSqlHook.__new__(DatabricksSqlHook)
+            # positional order: conn_id, http_path, sql_endpoint, session_cfg,
+            #                   http_headers, catalog, schema, caller
+            DatabricksSqlHook.__init__(
+                hook,
+                DEFAULT_CONN_ID,  # databricks_conn_id
+                None,  # http_path
+                None,  # sql_endpoint_name
+                None,  # session_configuration
+                None,  # http_headers
+                None,  # catalog
+                None,  # schema
+                "CustomCaller",  # caller (8th positional)
+            )
+            # caller is forwarded to BaseDatabricksHook.__init__; verify it was not
+            # confused with query_tags (which comes after caller)
+            assert mock_base_init.call_args.kwargs.get("caller") == "CustomCaller"
+            assert hook.query_tags is None
diff --git a/providers/databricks/tests/unit/databricks/operators/test_databricks_sql.py b/providers/databricks/tests/unit/databricks/operators/test_databricks_sql.py
index e216c56bea2..bfd6d89437b 100644
--- a/providers/databricks/tests/unit/databricks/operators/test_databricks_sql.py
+++ b/providers/databricks/tests/unit/databricks/operators/test_databricks_sql.py
@@ -20,13 +20,17 @@ from __future__ import annotations
 import json
 import os
 from collections import namedtuple
+from unittest import mock
 from unittest.mock import patch
 
 import pytest
 from databricks.sql.types import Row
 
 from airflow.providers.common.sql.hooks.handlers import fetch_all_handler
-from airflow.providers.databricks.operators.databricks_sql import 
DatabricksSqlOperator
+from airflow.providers.databricks.operators.databricks_sql import (
+    DatabricksSqlOperator,
+    _get_airflow_query_tags,
+)
 
 DATE = "2017-04-20"
 TASK_ID = "databricks-sql-operator"
@@ -453,3 +457,162 @@ def test_parse_gcs_path():
     bucket, object_name = 
op._parse_gcs_path("gs://my-bucket/path/to/file.parquet")
     assert bucket == "my-bucket"
     assert object_name == "path/to/file.parquet"
+
+
+class TestDatabricksSqlOperatorQueryTags:
+    """Tests for query tags support in DatabricksSqlOperator."""
+
+    @staticmethod
+    def _mock_ti(dag_id, task_id="task", run_id="run", try_number=1, map_index=-1):
+        """Return a task-instance stand-in exposing only the attributes read for query tags."""
+        ti = mock.MagicMock(spec=["dag_id", "task_id", "run_id", "try_number", "map_index"])
+        ti.dag_id = dag_id
+        ti.task_id = task_id
+        ti.run_id = run_id
+        ti.try_number = try_number
+        ti.map_index = map_index
+        return ti
+
+    def test_get_airflow_query_tags_returns_empty_dict_without_task_instance(self):
+        """_get_airflow_query_tags must return {} when context has no 'ti' key."""
+        assert _get_airflow_query_tags({}) == {}
+
+    def test_get_query_tags_with_none_context_returns_custom_tags_only(self):
+        """When context is None, only custom tags are returned (no Airflow tags)."""
+        op = DatabricksSqlOperator(
+            task_id=TASK_ID,
+            sql="SELECT 1",
+            query_tags={"custom_tag": "custom_value"},
+        )
+        assert op._get_query_tags(None) == {"custom_tag": "custom_value"}
+
+    def test_get_query_tags_with_none_context_and_no_custom_tags_returns_none(self):
+        """When context is None and no custom tags, None is returned."""
+        op = DatabricksSqlOperator(task_id=TASK_ID, sql="SELECT 1")
+        assert op._get_query_tags(None) is None
+
+    def test_get_query_tags_with_disabled_airflow_tags(self):
+        """When include_airflow_query_tags=False, only custom tags are returned."""
+        op = DatabricksSqlOperator(
+            task_id=TASK_ID,
+            sql="SELECT 1",
+            query_tags={"custom_tag": "val"},
+            include_airflow_query_tags=False,
+        )
+        # A bare object() has none of the TI attributes, so reading them would raise.
+        mock_context = {"ti": object()}
+        assert op._get_query_tags(mock_context) == {"custom_tag": "val"}
+
+    def test_get_query_tags_with_airflow_context(self):
+        """When context is provided and include_airflow_query_tags=True, Airflow tags are included."""
+        op = DatabricksSqlOperator(
+            task_id=TASK_ID,
+            sql="SELECT 1",
+            query_tags={"custom_tag": "custom_value"},
+        )
+        mock_context = {"ti": self._mock_ti("test_dag", task_id="test_task", run_id="test_run")}
+
+        result = op._get_query_tags(mock_context)
+
+        assert result == {
+            "airflow_dag_id": "test_dag",
+            "airflow_task_id": "test_task",
+            "airflow_run_id": "test_run",
+            "airflow_try_number": "1",
+            "airflow_map_index": "-1",
+            "custom_tag": "custom_value",
+        }
+
+    def test_execute_sets_query_tags_on_hook(self):
+        """execute() sets query_tags on the hook before delegating to SQLExecuteQueryOperator."""
+        with patch("airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook") as mock_cls:
+            mock_hook = mock_cls.return_value
+            mock_hook.run.return_value = []
+            mock_hook.descriptions = [[]]
+
+            op = DatabricksSqlOperator(
+                task_id=TASK_ID,
+                sql="SELECT 1",
+                query_tags={"env": "test"},
+                include_airflow_query_tags=False,
+            )
+
+            op.execute(None)
+
+            assert mock_hook.query_tags == {"env": "test"}
+
+    def test_custom_tags_override_airflow_tags_on_key_collision(self):
+        """Custom query_tags override Airflow tags when the same key is used."""
+        op = DatabricksSqlOperator(
+            task_id=TASK_ID,
+            sql="SELECT 1",
+            query_tags={"airflow_dag_id": "overridden"},
+        )
+        mock_context = {"ti": self._mock_ti("original_dag")}
+
+        result = op._get_query_tags(mock_context)
+
+        assert result is not None
+        assert result["airflow_dag_id"] == "overridden"
+
+
+class TestDatabricksCopyIntoOperatorQueryTags:
+    """Tests for query tags support in DatabricksCopyIntoOperator."""
+
+    def _make_op(self, **kwargs):
+        """Build a minimal CSV ``DatabricksCopyIntoOperator``; extra kwargs are forwarded."""
+        from airflow.providers.databricks.operators.databricks_sql import DatabricksCopyIntoOperator
+
+        return DatabricksCopyIntoOperator(
+            task_id=TASK_ID,
+            table_name="test_table",
+            file_location="s3://bucket/path",
+            file_format="CSV",
+            **kwargs,
+        )
+
+    def test_get_query_tags_with_none_context_returns_custom_tags_only(self):
+        op = self._make_op(query_tags={"custom": "value"})
+        assert op._get_query_tags(None) == {"custom": "value"}
+
+    def test_get_query_tags_with_none_context_and_no_custom_tags_returns_none(self):
+        op = self._make_op()
+        assert op._get_query_tags(None) is None
+
+    def test_get_query_tags_with_airflow_context(self):
+        op = self._make_op(query_tags={"env": "staging"})
+        mock_ti = mock.MagicMock(spec=["dag_id", "task_id", "run_id", "try_number", "map_index"])
+        mock_ti.dag_id = "copy_dag"
+        mock_ti.task_id = "copy_task"
+        mock_ti.run_id = "run_1"
+        mock_ti.try_number = 2
+        mock_ti.map_index = 0
+        mock_context = {"ti": mock_ti}
+
+        result = op._get_query_tags(mock_context)
+
+        assert result is not None
+        assert result["airflow_dag_id"] == "copy_dag"
+        assert result["airflow_task_id"] == "copy_task"
+        assert result["env"] == "staging"
+
+    def test_execute_sets_query_tags_on_hook(self):
+        with patch("airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook") as mock_cls:
+            mock_hook = mock_cls.return_value
+            op = self._make_op(query_tags={"env": "prod"}, include_airflow_query_tags=False)
+
+            op.execute(None)
+
+            assert mock_hook.query_tags == {"env": "prod"}
+            # The generated COPY INTO statement must actually be executed on the tagged hook.
+            mock_hook.run.assert_called_once_with(op._sql)

Reply via email to