This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9a14e66a64c feat: Add explicit support for DatabricksHook to Ol helper 
(#52253)
9a14e66a64c is described below

commit 9a14e66a64cdec8c366bc0722b63239bb4a84c39
Author: Kacper Muda <[email protected]>
AuthorDate: Mon Jun 30 11:59:35 2025 +0200

    feat: Add explicit support for DatabricksHook to Ol helper (#52253)
---
 .../providers/databricks/hooks/databricks_sql.py   |  24 +-
 .../providers/databricks/utils/openlineage.py      | 143 +++--
 .../unit/databricks/hooks/test_databricks_sql.py   |  17 +-
 .../unit/databricks/utils/test_openlineage.py      | 634 ++++++++++++++++++++-
 4 files changed, 732 insertions(+), 86 deletions(-)

diff --git 
a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py 
b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py
index 76aea1dde54..0ace06abc99 100644
--- 
a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py
+++ 
b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py
@@ -344,10 +344,9 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
 
     def get_openlineage_database_specific_lineage(self, task_instance) -> 
OperatorLineage | None:
         """
-        Generate OpenLineage metadata for a Databricks task instance based on 
executed query IDs.
+        Emit separate OpenLineage events for each Databricks query, based on 
executed query IDs.
 
-        If a single query ID is present, attach an `ExternalQueryRunFacet` to 
the lineage metadata.
-        If multiple query IDs are present, emits separate OpenLineage events 
for each query instead.
+        If a single query ID is present, also add an `ExternalQueryRunFacet` 
to the returned lineage metadata.
 
         Note that `get_openlineage_database_specific_lineage` is usually 
called after task's execution,
         so if multiple query IDs are present, both START and COMPLETE event 
for each query will be emitted
@@ -368,13 +367,22 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
         from airflow.providers.openlineage.sqlparser import SQLParser
 
         if not self.query_ids:
-            self.log.debug("openlineage: no databricks query ids found.")
+            self.log.info("OpenLineage could not find databricks query ids.")
             return None
 
         self.log.debug("openlineage: getting connection to get database info")
         connection = self.get_connection(self.get_conn_id())
         namespace = 
SQLParser.create_namespace(self.get_openlineage_database_info(connection))
 
+        self.log.info("Separate OpenLineage events will be emitted for each 
Databricks query_id.")
+        emit_openlineage_events_for_databricks_queries(
+            task_instance=task_instance,
+            hook=self,
+            query_ids=self.query_ids,
+            query_for_extra_metadata=True,
+            query_source_namespace=namespace,
+        )
+
         if len(self.query_ids) == 1:
             self.log.debug("Attaching ExternalQueryRunFacet with single 
query_id to OpenLineage event.")
             return OperatorLineage(
@@ -385,12 +393,4 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
                 }
             )
 
-        self.log.info("Multiple query_ids found. Separate OpenLineage event 
will be emitted for each query.")
-        emit_openlineage_events_for_databricks_queries(
-            query_ids=self.query_ids,
-            query_source_namespace=namespace,
-            task_instance=task_instance,
-            hook=self,
-        )
-
         return None
diff --git 
a/providers/databricks/src/airflow/providers/databricks/utils/openlineage.py 
b/providers/databricks/src/airflow/providers/databricks/utils/openlineage.py
index 5d2fe399703..57c964d9842 100644
--- a/providers/databricks/src/airflow/providers/databricks/utils/openlineage.py
+++ b/providers/databricks/src/airflow/providers/databricks/utils/openlineage.py
@@ -31,6 +31,7 @@ if TYPE_CHECKING:
     from openlineage.client.event_v2 import RunEvent
     from openlineage.client.facet_v2 import JobFacet
 
+    from airflow.providers.databricks.hooks.databricks import DatabricksHook
     from airflow.providers.databricks.hooks.databricks_sql import 
DatabricksSqlHook
 
 log = logging.getLogger(__name__)
@@ -121,20 +122,18 @@ def _get_parent_run_facet(task_instance):
     )
 
 
-def _run_api_call(hook: DatabricksSqlHook, query_ids: list[str]) -> list[dict]:
+def _run_api_call(hook: DatabricksSqlHook | DatabricksHook, query_ids: 
list[str]) -> list[dict]:
     """Retrieve execution details for specific queries from Databricks's query 
history API."""
-    if not hook._token:
-        # This has logic for token initialization
-        hook.get_conn()
-
-    # https://docs.databricks.com/api/azure/workspace/queryhistory/list
     try:
+        token = hook._get_token(raise_error=True)
+        # https://docs.databricks.com/api/azure/workspace/queryhistory/list
         response = requests.get(
            url=f"https://{hook.host}/api/2.0/sql/history/queries",
-            headers={"Authorization": f"Bearer {hook._token}"},
+            headers={"Authorization": f"Bearer {token}"},
             data=json.dumps({"filter_by": {"statement_ids": query_ids}}),
             timeout=2,
         )
+        response.raise_for_status()
     except Exception as e:
         log.warning(
             "OpenLineage could not retrieve Databricks queries details. Error 
received: `%s`.",
@@ -142,48 +141,42 @@ def _run_api_call(hook: DatabricksSqlHook, query_ids: 
list[str]) -> list[dict]:
         )
         return []
 
-    if response.status_code != 200:
-        log.warning(
-            "OpenLineage could not retrieve Databricks queries details. API 
error received: `%s`: `%s`",
-            response.status_code,
-            response.text,
-        )
-        return []
-
     return response.json()["res"]
 
 
+def _process_data_from_api(data: list[dict[str, Any]]) -> list[dict[str, Any]]:
+    """Convert timestamp fields to UTC datetime objects."""
+    for row in data:
+        for key in ("query_start_time_ms", "query_end_time_ms"):
+            row[key] = datetime.datetime.fromtimestamp(row[key] / 1000, 
tz=datetime.timezone.utc)
+
+    return data
+
+
 def _get_queries_details_from_databricks(
-    hook: DatabricksSqlHook, query_ids: list[str]
+    hook: DatabricksSqlHook | DatabricksHook, query_ids: list[str]
 ) -> dict[str, dict[str, Any]]:
     if not query_ids:
         return {}
 
-    queries_info_from_api = _run_api_call(hook=hook, query_ids=query_ids)
-
     query_details = {}
-    for query_info in queries_info_from_api:
-        if not query_info.get("query_id"):
-            log.debug("Databricks query ID not found in API response.")
-            continue
-
-        q_start_time = None
-        q_end_time = None
-        if query_info.get("query_start_time_ms") and 
query_info.get("query_end_time_ms"):
-            q_start_time = datetime.datetime.fromtimestamp(
-                query_info["query_start_time_ms"] / 1000, 
tz=datetime.timezone.utc
-            )
-            q_end_time = datetime.datetime.fromtimestamp(
-                query_info["query_end_time_ms"] / 1000, 
tz=datetime.timezone.utc
-            )
-
-        query_details[query_info["query_id"]] = {
-            "status": query_info.get("status"),
-            "start_time": q_start_time,
-            "end_time": q_end_time,
-            "query_text": query_info.get("query_text"),
-            "error_message": query_info.get("error_message"),
+    try:
+        queries_info_from_api = _run_api_call(hook=hook, query_ids=query_ids)
+        queries_info_from_api = _process_data_from_api(queries_info_from_api)
+
+        query_details = {
+            query_info["query_id"]: {
+                "status": query_info.get("status"),
+                "start_time": query_info.get("query_start_time_ms"),
+                "end_time": query_info.get("query_end_time_ms"),
+                "query_text": query_info.get("query_text"),
+                "error_message": query_info.get("error_message"),
+            }
+            for query_info in queries_info_from_api
+            if query_info["query_id"]
         }
+    except Exception as e:
+        log.warning("OpenLineage could not retrieve extra metadata from 
Databricks. Error encountered: %s", e)
 
     return query_details
 
@@ -221,17 +214,18 @@ def _create_ol_event_pair(
 
 @require_openlineage_version(provider_min_version="2.3.0")
 def emit_openlineage_events_for_databricks_queries(
-    query_ids: list[str],
-    query_source_namespace: str,
     task_instance,
-    hook: DatabricksSqlHook | None = None,
+    hook: DatabricksSqlHook | DatabricksHook | None = None,
+    query_ids: list[str] | None = None,
+    query_source_namespace: str | None = None,
+    query_for_extra_metadata: bool = False,
     additional_run_facets: dict | None = None,
     additional_job_facets: dict | None = None,
 ) -> None:
     """
     Emit OpenLineage events for executed Databricks queries.
 
-    Metadata retrieval from Databricks is attempted only if a 
`DatabricksSqlHook` is provided.
+    Metadata retrieval from Databricks is attempted only if 
`query_for_extra_metadata` is True and hook is provided.
     If metadata is available, execution details such as start time, end time, 
execution status,
     error messages, and SQL text are included in the events. If no metadata is 
found, the function
     defaults to using the Airflow task instance's state and the current 
timestamp.
@@ -241,10 +235,16 @@ def emit_openlineage_events_for_databricks_queries(
     will correspond to actual query execution times.
 
     Args:
-        query_ids: A list of Databricks query IDs to emit events for.
-        query_source_namespace: The namespace to be included in 
ExternalQueryRunFacet.
         task_instance: The Airflow task instance that run these queries.
-        hook: A hook instance used to retrieve query metadata if available.
+        hook: A supported Databricks hook instance used to retrieve query 
metadata if available.
+        If omitted, `query_ids` and `query_source_namespace` must be provided 
explicitly and
+        `query_for_extra_metadata` must be `False`.
+        query_ids: A list of Databricks query IDs to emit events for, can only 
be None if `hook` is provided
+        and `hook.query_ids` are present (DatabricksHook does not store 
query_ids).
+        query_source_namespace: The namespace to be included in 
ExternalQueryRunFacet,
+        can be `None` only if hook is provided.
+        query_for_extra_metadata: Whether to query Databricks for additional 
metadata about queries.
+        Must be `False` if `hook` is not provided.
         additional_run_facets: Additional run facets to include in OpenLineage 
events.
         additional_job_facets: Additional job facets to include in OpenLineage 
events.
     """
@@ -259,25 +259,52 @@ def emit_openlineage_events_for_databricks_queries(
     from airflow.providers.openlineage.conf import namespace
     from airflow.providers.openlineage.plugins.listener import 
get_openlineage_listener
 
-    if not query_ids:
-        log.debug("No Databricks query IDs provided; skipping OpenLineage 
event emission.")
-        return
-
-    query_ids = [q for q in query_ids]  # Make a copy to make sure it does not 
change
+    log.info("OpenLineage will emit events for Databricks queries.")
 
     if hook:
+        if not query_ids:
+            log.debug("No Databricks query IDs provided; Checking 
`hook.query_ids` property.")
+            query_ids = getattr(hook, "query_ids", [])
+            if not query_ids:
+                raise ValueError("No Databricks query IDs provided and 
`hook.query_ids` are not present.")
+
+        if not query_source_namespace:
+            log.debug("No Databricks query namespace provided; Creating one 
from scratch.")
+
+            if hasattr(hook, "get_openlineage_database_info") and 
hasattr(hook, "get_conn_id"):
+                from airflow.providers.openlineage.sqlparser import SQLParser
+
+                query_source_namespace = SQLParser.create_namespace(
+                    
hook.get_openlineage_database_info(hook.get_connection(hook.get_conn_id()))
+                )
+            else:
+                query_source_namespace = f"databricks://{hook.host}" if 
hook.host else "databricks"
+    else:
+        if not query_ids:
+            raise ValueError("If 'hook' is not provided, 'query_ids' must be 
set.")
+        if not query_source_namespace:
+            raise ValueError("If 'hook' is not provided, 
'query_source_namespace' must be set.")
+        if query_for_extra_metadata:
+            raise ValueError("If 'hook' is not provided, 
'query_for_extra_metadata' must be False.")
+
+    query_ids = [q for q in query_ids]  # Make a copy to make sure we do not 
change hook's attribute
+
+    if query_for_extra_metadata and hook:
         log.debug("Retrieving metadata for %s queries from Databricks.", 
len(query_ids))
         databricks_metadata = _get_queries_details_from_databricks(hook, 
query_ids)
     else:
-        log.debug("DatabricksSqlHook not provided. No extra metadata fill be 
fetched from Databricks.")
+        log.debug("`query_for_extra_metadata` is False. No extra metadata will 
be fetched from Databricks.")
         databricks_metadata = {}
 
     # If real metadata is unavailable, we send events with eventTime=now
     default_event_time = timezone.utcnow()
-    # If no query metadata is provided, we use task_instance's state when 
checking for success
+    # ti.state has no `value` attr (AF2) when the task is still running; in 
AF3 we get 'running'. In that case
+    # we assume it is a user call and the query succeeded, so we replace it 
with success.
     # Adjust state for DBX logic, where "finished" means "success"
-    default_state = task_instance.state.value if hasattr(task_instance, 
"state") else ""
-    default_state = "finished" if default_state == "success" else default_state
+    default_state = (
+        getattr(task_instance.state, "value", "running") if 
hasattr(task_instance, "state") else ""
+    )
+    default_state = "finished" if default_state in ("running", "success") else 
default_state
 
     log.debug("Generating OpenLineage facets")
     common_run_facets = {"parent": _get_parent_run_facet(task_instance)}
@@ -318,10 +345,10 @@ def emit_openlineage_events_for_databricks_queries(
         event_batch = _create_ol_event_pair(
             job_namespace=namespace(),
             
job_name=f"{task_instance.dag_id}.{task_instance.task_id}.query.{counter}",
-            start_time=query_metadata.get("start_time", default_event_time),  
# type: ignore[arg-type]
-            end_time=query_metadata.get("end_time", default_event_time),  # 
type: ignore[arg-type]
+            start_time=query_metadata.get("start_time") or default_event_time, 
 # type: ignore[arg-type]
+            end_time=query_metadata.get("end_time") or default_event_time,  # 
type: ignore[arg-type]
             # Only finished status means it completed without failures
-            is_successful=query_metadata.get("status", default_state).lower() 
== "finished",
+            is_successful=(query_metadata.get("status") or 
default_state).lower() == "finished",
             run_facets={**query_specific_run_facets, **common_run_facets, 
**additional_run_facets},
             job_facets={**query_specific_job_facets, **common_job_facets, 
**additional_job_facets},
         )
diff --git 
a/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py 
b/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py
index 7449489c77d..20e2c081f74 100644
--- a/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py
+++ b/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py
@@ -457,7 +457,8 @@ def 
test_get_openlineage_database_specific_lineage_with_no_query_id():
     assert result is None
 
 
-def test_get_openlineage_database_specific_lineage_with_single_query_id():
[email protected]("airflow.providers.databricks.utils.openlineage.emit_openlineage_events_for_databricks_queries")
+def 
test_get_openlineage_database_specific_lineage_with_single_query_id(mock_emit):
     from airflow.providers.common.compat.openlineage.facet import 
ExternalQueryRunFacet
     from airflow.providers.openlineage.extractors import OperatorLineage
 
@@ -466,7 +467,18 @@ def 
test_get_openlineage_database_specific_lineage_with_single_query_id():
     hook.get_connection = mock.MagicMock()
     hook.get_openlineage_database_info = lambda x: 
mock.MagicMock(authority="auth", scheme="scheme")
 
-    result = hook.get_openlineage_database_specific_lineage(None)
+    ti = mock.MagicMock()
+
+    result = hook.get_openlineage_database_specific_lineage(ti)
+    mock_emit.assert_called_once_with(
+        **{
+            "hook": hook,
+            "query_ids": ["query1"],
+            "query_source_namespace": "scheme://auth",
+            "task_instance": ti,
+            "query_for_extra_metadata": True,
+        }
+    )
     assert result == OperatorLineage(
         run_facets={"externalQuery": 
ExternalQueryRunFacet(externalQueryId="query1", source="scheme://auth")}
     )
@@ -488,6 +500,7 @@ def 
test_get_openlineage_database_specific_lineage_with_multiple_query_ids(mock_
             "query_ids": ["query1", "query2"],
             "query_source_namespace": "scheme://auth",
             "task_instance": ti,
+            "query_for_extra_metadata": True,
         }
     )
     assert result is None
diff --git 
a/providers/databricks/tests/unit/databricks/utils/test_openlineage.py 
b/providers/databricks/tests/unit/databricks/utils/test_openlineage.py
index 699e018785e..6d427e0ba77 100644
--- a/providers/databricks/tests/unit/databricks/utils/test_openlineage.py
+++ b/providers/databricks/tests/unit/databricks/utils/test_openlineage.py
@@ -30,11 +30,14 @@ from airflow.providers.common.compat.openlineage.facet 
import (
     ExternalQueryRunFacet,
     SQLJobFacet,
 )
+from airflow.providers.databricks.hooks.databricks import DatabricksHook
+from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook
 from airflow.providers.databricks.utils.openlineage import (
     _create_ol_event_pair,
     _get_ol_run_id,
     _get_parent_run_facet,
     _get_queries_details_from_databricks,
+    _process_data_from_api,
     _run_api_call,
     emit_openlineage_events_for_databricks_queries,
 )
@@ -96,7 +99,7 @@ def test_get_parent_run_facet():
 
 def test_run_api_call_success():
     mock_hook = mock.MagicMock()
-    mock_hook._token = "mock_token"
+    mock_hook._get_token.return_value = "mock_token"
     mock_hook.host = "mock_host"
 
     mock_response = mock.MagicMock()
@@ -109,14 +112,27 @@ def test_run_api_call_success():
     assert result == [{"query_id": "123", "status": "success"}]
 
 
-def test_run_api_call_error():
+def test_run_api_call_request_error():
     mock_hook = mock.MagicMock()
-    mock_hook._token = "mock_token"
+    mock_hook._get_token.return_value = "mock_token"
     mock_hook.host = "mock_host"
 
     mock_response = mock.MagicMock()
-    mock_response.status_code = 500
-    mock_response.text = "Internal Server Error"
+    mock_response.status_code = 200
+
+    with mock.patch("requests.get", side_effect=RuntimeError("request error")):
+        result = _run_api_call(mock_hook, ["123"])
+
+    assert result == []
+
+
+def test_run_api_call_token_error():
+    mock_hook = mock.MagicMock()
+    mock_hook._get_token.side_effect = RuntimeError("Token error")
+    mock_hook.host = "mock_host"
+
+    mock_response = mock.MagicMock()
+    mock_response.status_code = 200
 
     with mock.patch("requests.get", return_value=mock_response):
         result = _run_api_call(mock_hook, ["123"])
@@ -124,6 +140,55 @@ def test_run_api_call_error():
     assert result == []
 
 
+def test_process_data_from_api():
+    data = [
+        {
+            "query_id": "ABC",
+            "status": "FINISHED",
+            "query_start_time_ms": 1595357086200,
+            "query_end_time_ms": 1595357087200,
+            "query_text": "SELECT * FROM table1;",
+            "error_message": "Error occurred",
+        },
+        {
+            "query_id": "DEF",
+            "query_start_time_ms": 1595357086200,
+            "query_end_time_ms": 1595357087200,
+        },
+    ]
+    expected_details = [
+        {
+            "query_id": "ABC",
+            "status": "FINISHED",
+            "query_start_time_ms": datetime.datetime(
+                2020, 7, 21, 18, 44, 46, 200000, tzinfo=datetime.timezone.utc
+            ),
+            "query_end_time_ms": datetime.datetime(
+                2020, 7, 21, 18, 44, 47, 200000, tzinfo=datetime.timezone.utc
+            ),
+            "query_text": "SELECT * FROM table1;",
+            "error_message": "Error occurred",
+        },
+        {
+            "query_id": "DEF",
+            "query_start_time_ms": datetime.datetime(
+                2020, 7, 21, 18, 44, 46, 200000, tzinfo=datetime.timezone.utc
+            ),
+            "query_end_time_ms": datetime.datetime(
+                2020, 7, 21, 18, 44, 47, 200000, tzinfo=datetime.timezone.utc
+            ),
+        },
+    ]
+    result = _process_data_from_api(data=data)
+    assert len(result) == 2
+    assert result == expected_details
+
+
+def test_process_data_from_api_error():
+    with pytest.raises(KeyError):
+        _process_data_from_api(data=[{"query_start_time_ms": 1595357086200}])
+
+
 def test_get_queries_details_from_databricks_empty_query_ids():
     details = _get_queries_details_from_databricks(None, [])
     assert details == {}
@@ -131,7 +196,7 @@ def 
test_get_queries_details_from_databricks_empty_query_ids():
 
 @mock.patch("airflow.providers.databricks.utils.openlineage._run_api_call")
 def test_get_queries_details_from_databricks(mock_api_call):
-    hook = mock.MagicMock()
+    hook = DatabricksSqlHook()
     query_ids = ["ABC"]
     fake_result = [
         {
@@ -160,7 +225,7 @@ def test_get_queries_details_from_databricks(mock_api_call):
 
 @mock.patch("airflow.providers.databricks.utils.openlineage._run_api_call")
 def test_get_queries_details_from_databricks_no_data_found(mock_api_call):
-    hook = mock.MagicMock()
+    hook = DatabricksSqlHook()
     query_ids = ["ABC", "DEF"]
     mock_api_call.return_value = []
 
@@ -274,6 +339,7 @@ def 
test_emit_openlineage_events_for_databricks_queries(mock_now, mock_generate_
             query_source_namespace="databricks_ns",
             task_instance=mock_ti,
             hook=mock.MagicMock(),
+            query_for_extra_metadata=True,
             additional_run_facets=additional_run_facets,
             additional_job_facets=additional_job_facets,
         )
@@ -448,7 +514,7 @@ def 
test_emit_openlineage_events_for_databricks_queries(mock_now, mock_generate_
 @mock.patch("importlib.metadata.version", return_value="2.3.0")
 @mock.patch("openlineage.client.uuid.generate_new_uuid")
 @mock.patch("airflow.utils.timezone.utcnow")
-def test_emit_openlineage_events_for_databricks_queries_without_metadata_found(
+def test_emit_openlineage_events_for_databricks_queries_without_metadata(
     mock_now, mock_generate_uuid, mock_version
 ):
     fake_uuid = "01958e68-03a2-79e3-9ae9-26865cc40e2f"
@@ -489,7 +555,8 @@ def 
test_emit_openlineage_events_for_databricks_queries_without_metadata_found(
             query_ids=query_ids,
             query_source_namespace="databricks_ns",
             task_instance=mock_ti,
-            hook=None,  # None so metadata retrieval is not triggered
+            hook=mock.MagicMock(),
+            # query_for_extra_metadata=False,  # False by default
             additional_run_facets=additional_run_facets,
             additional_job_facets=additional_job_facets,
         )
@@ -564,9 +631,37 @@ def 
test_emit_openlineage_events_for_databricks_queries_without_metadata_found(
 
 
 @mock.patch("importlib.metadata.version", return_value="2.3.0")
-def 
test_emit_openlineage_events_for_databricks_queries_without_query_ids(mock_version):
-    query_ids = []
[email protected]("openlineage.client.uuid.generate_new_uuid")
[email protected]("airflow.utils.timezone.utcnow")
+def 
test_emit_openlineage_events_for_databricks_queries_without_explicit_query_ids(
+    mock_now, mock_generate_uuid, mock_version
+):
+    fake_uuid = "01958e68-03a2-79e3-9ae9-26865cc40e2f"
+    mock_generate_uuid.return_value = fake_uuid
+
+    default_event_time = timezone.datetime(2025, 1, 5, 0, 0, 0)
+    mock_now.return_value = default_event_time
+
+    query_ids = ["query1"]
+    hook = mock.MagicMock()
+    hook.query_ids = query_ids
     original_query_ids = copy.deepcopy(query_ids)
+    logical_date = timezone.datetime(2025, 1, 1)
+    mock_ti = mock.MagicMock(
+        dag_id="dag_id",
+        task_id="task_id",
+        map_index=1,
+        try_number=1,
+        logical_date=logical_date,
+        state=TaskInstanceState.RUNNING,  # This will be query default state 
if no metadata found
+        dag_run=mock.MagicMock(logical_date=logical_date, clear_number=0),
+    )
+    mock_ti.get_template_context.return_value = {
+        "dag_run": mock.MagicMock(logical_date=logical_date, clear_number=0)
+    }
+
+    additional_run_facets = {"custom_run": "value_run"}
+    additional_job_facets = {"custom_job": "value_job"}
 
     fake_adapter = mock.MagicMock()
     fake_adapter.emit = mock.MagicMock()
@@ -578,16 +673,527 @@ def 
test_emit_openlineage_events_for_databricks_queries_without_query_ids(mock_v
         return_value=fake_listener,
     ):
         emit_openlineage_events_for_databricks_queries(
-            query_ids=query_ids,
             query_source_namespace="databricks_ns",
-            task_instance=None,
+            task_instance=mock_ti,
+            hook=hook,
+            # query_for_extra_metadata=False,  # False by default
+            additional_run_facets=additional_run_facets,
+            additional_job_facets=additional_job_facets,
+        )
+
+        assert query_ids == original_query_ids  # Verify that the input 
query_ids list is unchanged.
+        assert fake_adapter.emit.call_count == 2  # Expect two events per 
query.
+
+        expected_common_job_facets = {
+            "jobType": job_type_job.JobTypeJobFacet(
+                jobType="QUERY",
+                processingType="BATCH",
+                integration="DATABRICKS",
+            ),
+            "custom_job": "value_job",
+        }
+        expected_common_run_facets = {
+            "parent": parent_run.ParentRunFacet(
+                
run=parent_run.Run(runId="01941f29-7c00-7087-8906-40e512c257bd"),
+                job=parent_run.Job(namespace=namespace(), 
name="dag_id.task_id"),
+                root=parent_run.Root(
+                    
run=parent_run.RootRun(runId="01941f29-7c00-743e-b109-28b18d0a19c5"),
+                    job=parent_run.RootJob(namespace=namespace(), 
name="dag_id"),
+                ),
+            ),
+            "custom_run": "value_run",
+        }
+
+        expected_calls = [
+            mock.call(  # Query1: START event (no metadata)
+                RunEvent(
+                    eventTime=default_event_time.isoformat(),
+                    eventType=RunState.START,
+                    run=Run(
+                        runId=fake_uuid,
+                        facets={
+                            "externalQuery": ExternalQueryRunFacet(
+                                externalQueryId="query1", 
source="databricks_ns"
+                            ),
+                            **expected_common_run_facets,
+                        },
+                    ),
+                    job=Job(
+                        namespace=namespace(),
+                        name="dag_id.task_id.query.1",
+                        facets=expected_common_job_facets,
+                    ),
+                )
+            ),
+            mock.call(  # Query1: COMPLETE event (no metadata)
+                RunEvent(
+                    eventTime=default_event_time.isoformat(),
+                    eventType=RunState.COMPLETE,
+                    run=Run(
+                        runId=fake_uuid,
+                        facets={
+                            "externalQuery": ExternalQueryRunFacet(
+                                externalQueryId="query1", 
source="databricks_ns"
+                            ),
+                            **expected_common_run_facets,
+                        },
+                    ),
+                    job=Job(
+                        namespace=namespace(),
+                        name="dag_id.task_id.query.1",
+                        facets=expected_common_job_facets,
+                    ),
+                )
+            ),
+        ]
+
+        assert fake_adapter.emit.call_args_list == expected_calls
+
+
[email protected](
+    "airflow.providers.openlineage.sqlparser.SQLParser.create_namespace", 
return_value="databricks_ns"
+)
[email protected]("importlib.metadata.version", return_value="2.3.0")
[email protected]("openlineage.client.uuid.generate_new_uuid")
[email protected]("airflow.utils.timezone.utcnow")
+def 
test_emit_openlineage_events_for_databricks_queries_without_explicit_query_ids_and_namespace(
+    mock_now, mock_generate_uuid, mock_version, mock_parser
+):
+    fake_uuid = "01958e68-03a2-79e3-9ae9-26865cc40e2f"
+    mock_generate_uuid.return_value = fake_uuid
+
+    default_event_time = timezone.datetime(2025, 1, 5, 0, 0, 0)
+    mock_now.return_value = default_event_time
+
+    query_ids = ["query1"]
+    hook = mock.MagicMock()
+    hook.query_ids = query_ids
+    original_query_ids = copy.deepcopy(query_ids)
+    logical_date = timezone.datetime(2025, 1, 1)
+    mock_ti = mock.MagicMock(
+        dag_id="dag_id",
+        task_id="task_id",
+        map_index=1,
+        try_number=1,
+        logical_date=logical_date,
+        state=TaskInstanceState.RUNNING,  # This will be the query's default state if no metadata is found
+        dag_run=mock.MagicMock(logical_date=logical_date, clear_number=0),
+    )
+    mock_ti.get_template_context.return_value = {
+        "dag_run": mock.MagicMock(logical_date=logical_date, clear_number=0)
+    }
+
+    additional_run_facets = {"custom_run": "value_run"}
+    additional_job_facets = {"custom_job": "value_job"}
+
+    fake_adapter = mock.MagicMock()
+    fake_adapter.emit = mock.MagicMock()
+    fake_listener = mock.MagicMock()
+    fake_listener.adapter = fake_adapter
+
+    with mock.patch(
+        
"airflow.providers.openlineage.plugins.listener.get_openlineage_listener",
+        return_value=fake_listener,
+    ):
+        emit_openlineage_events_for_databricks_queries(
+            task_instance=mock_ti,
+            hook=hook,
+            # query_for_extra_metadata=False,  # False by default
+            additional_run_facets=additional_run_facets,
+            additional_job_facets=additional_job_facets,
         )
 
+        assert query_ids == original_query_ids  # Verify that the input 
query_ids list is unchanged.
+        assert fake_adapter.emit.call_count == 2  # Expect two events per 
query.
+
+        expected_common_job_facets = {
+            "jobType": job_type_job.JobTypeJobFacet(
+                jobType="QUERY",
+                processingType="BATCH",
+                integration="DATABRICKS",
+            ),
+            "custom_job": "value_job",
+        }
+        expected_common_run_facets = {
+            "parent": parent_run.ParentRunFacet(
+                
run=parent_run.Run(runId="01941f29-7c00-7087-8906-40e512c257bd"),
+                job=parent_run.Job(namespace=namespace(), 
name="dag_id.task_id"),
+                root=parent_run.Root(
+                    
run=parent_run.RootRun(runId="01941f29-7c00-743e-b109-28b18d0a19c5"),
+                    job=parent_run.RootJob(namespace=namespace(), 
name="dag_id"),
+                ),
+            ),
+            "custom_run": "value_run",
+        }
+
+        expected_calls = [
+            mock.call(  # Query1: START event (no metadata)
+                RunEvent(
+                    eventTime=default_event_time.isoformat(),
+                    eventType=RunState.START,
+                    run=Run(
+                        runId=fake_uuid,
+                        facets={
+                            "externalQuery": ExternalQueryRunFacet(
+                                externalQueryId="query1", 
source="databricks_ns"
+                            ),
+                            **expected_common_run_facets,
+                        },
+                    ),
+                    job=Job(
+                        namespace=namespace(),
+                        name="dag_id.task_id.query.1",
+                        facets=expected_common_job_facets,
+                    ),
+                )
+            ),
+            mock.call(  # Query1: COMPLETE event (no metadata)
+                RunEvent(
+                    eventTime=default_event_time.isoformat(),
+                    eventType=RunState.COMPLETE,
+                    run=Run(
+                        runId=fake_uuid,
+                        facets={
+                            "externalQuery": ExternalQueryRunFacet(
+                                externalQueryId="query1", 
source="databricks_ns"
+                            ),
+                            **expected_common_run_facets,
+                        },
+                    ),
+                    job=Job(
+                        namespace=namespace(),
+                        name="dag_id.task_id.query.1",
+                        facets=expected_common_job_facets,
+                    ),
+                )
+            ),
+        ]
+
+        assert fake_adapter.emit.call_args_list == expected_calls
+
+
[email protected]("importlib.metadata.version", return_value="2.3.0")
[email protected]("openlineage.client.uuid.generate_new_uuid")
[email protected]("airflow.utils.timezone.utcnow")
+def 
test_emit_openlineage_events_for_databricks_queries_without_explicit_query_ids_and_namespace_raw_ns(
+    mock_now, mock_generate_uuid, mock_version
+):
+    fake_uuid = "01958e68-03a2-79e3-9ae9-26865cc40e2f"
+    mock_generate_uuid.return_value = fake_uuid
+
+    default_event_time = timezone.datetime(2025, 1, 5, 0, 0, 0)
+    mock_now.return_value = default_event_time
+
+    query_ids = ["query1"]
+    hook = DatabricksHook()
+    hook.query_ids = query_ids
+    hook.host = "some_host"
+    original_query_ids = copy.deepcopy(query_ids)
+    logical_date = timezone.datetime(2025, 1, 1)
+    mock_ti = mock.MagicMock(
+        dag_id="dag_id",
+        task_id="task_id",
+        map_index=1,
+        try_number=1,
+        logical_date=logical_date,
+        state=TaskInstanceState.RUNNING,  # This will be the query's default state if no metadata is found
+        dag_run=mock.MagicMock(logical_date=logical_date, clear_number=0),
+    )
+    mock_ti.get_template_context.return_value = {
+        "dag_run": mock.MagicMock(logical_date=logical_date, clear_number=0)
+    }
+
+    additional_run_facets = {"custom_run": "value_run"}
+    additional_job_facets = {"custom_job": "value_job"}
+
+    fake_adapter = mock.MagicMock()
+    fake_adapter.emit = mock.MagicMock()
+    fake_listener = mock.MagicMock()
+    fake_listener.adapter = fake_adapter
+
+    with mock.patch(
+        
"airflow.providers.openlineage.plugins.listener.get_openlineage_listener",
+        return_value=fake_listener,
+    ):
+        emit_openlineage_events_for_databricks_queries(
+            task_instance=mock_ti,
+            hook=hook,
+            # query_for_extra_metadata=False,  # False by default
+            additional_run_facets=additional_run_facets,
+            additional_job_facets=additional_job_facets,
+        )
+
+        assert query_ids == original_query_ids  # Verify that the input 
query_ids list is unchanged.
+        assert fake_adapter.emit.call_count == 2  # Expect two events per 
query.
+
+        expected_common_job_facets = {
+            "jobType": job_type_job.JobTypeJobFacet(
+                jobType="QUERY",
+                processingType="BATCH",
+                integration="DATABRICKS",
+            ),
+            "custom_job": "value_job",
+        }
+        expected_common_run_facets = {
+            "parent": parent_run.ParentRunFacet(
+                
run=parent_run.Run(runId="01941f29-7c00-7087-8906-40e512c257bd"),
+                job=parent_run.Job(namespace=namespace(), 
name="dag_id.task_id"),
+                root=parent_run.Root(
+                    
run=parent_run.RootRun(runId="01941f29-7c00-743e-b109-28b18d0a19c5"),
+                    job=parent_run.RootJob(namespace=namespace(), 
name="dag_id"),
+                ),
+            ),
+            "custom_run": "value_run",
+        }
+
+        expected_calls = [
+            mock.call(  # Query1: START event (no metadata)
+                RunEvent(
+                    eventTime=default_event_time.isoformat(),
+                    eventType=RunState.START,
+                    run=Run(
+                        runId=fake_uuid,
+                        facets={
+                            "externalQuery": ExternalQueryRunFacet(
+                                externalQueryId="query1", 
source="databricks://some_host"
+                            ),
+                            **expected_common_run_facets,
+                        },
+                    ),
+                    job=Job(
+                        namespace=namespace(),
+                        name="dag_id.task_id.query.1",
+                        facets=expected_common_job_facets,
+                    ),
+                )
+            ),
+            mock.call(  # Query1: COMPLETE event (no metadata)
+                RunEvent(
+                    eventTime=default_event_time.isoformat(),
+                    eventType=RunState.COMPLETE,
+                    run=Run(
+                        runId=fake_uuid,
+                        facets={
+                            "externalQuery": ExternalQueryRunFacet(
+                                externalQueryId="query1", 
source="databricks://some_host"
+                            ),
+                            **expected_common_run_facets,
+                        },
+                    ),
+                    job=Job(
+                        namespace=namespace(),
+                        name="dag_id.task_id.query.1",
+                        facets=expected_common_job_facets,
+                    ),
+                )
+            ),
+        ]
+
+        assert fake_adapter.emit.call_args_list == expected_calls
+
+
[email protected]("importlib.metadata.version", return_value="2.3.0")
[email protected]("openlineage.client.uuid.generate_new_uuid")
[email protected]("airflow.utils.timezone.utcnow")
+def 
test_emit_openlineage_events_for_databricks_queries_with_query_ids_and_hook_query_ids(
+    mock_now, mock_generate_uuid, mock_version
+):
+    fake_uuid = "01958e68-03a2-79e3-9ae9-26865cc40e2f"
+    mock_generate_uuid.return_value = fake_uuid
+
+    default_event_time = timezone.datetime(2025, 1, 5, 0, 0, 0)
+    mock_now.return_value = default_event_time
+
+    hook = DatabricksSqlHook()
+    hook.query_ids = ["query2", "query3"]
+    query_ids = ["query1"]
+    original_query_ids = copy.deepcopy(query_ids)
+    logical_date = timezone.datetime(2025, 1, 1)
+    mock_ti = mock.MagicMock(
+        dag_id="dag_id",
+        task_id="task_id",
+        map_index=1,
+        try_number=1,
+        logical_date=logical_date,
+        state=TaskInstanceState.SUCCESS,  # This will be the query's default state if no metadata is found
+        dag_run=mock.MagicMock(logical_date=logical_date, clear_number=0),
+    )
+    mock_ti.get_template_context.return_value = {
+        "dag_run": mock.MagicMock(logical_date=logical_date, clear_number=0)
+    }
+
+    additional_run_facets = {"custom_run": "value_run"}
+    additional_job_facets = {"custom_job": "value_job"}
+
+    fake_adapter = mock.MagicMock()
+    fake_adapter.emit = mock.MagicMock()
+    fake_listener = mock.MagicMock()
+    fake_listener.adapter = fake_adapter
+
+    with mock.patch(
+        
"airflow.providers.openlineage.plugins.listener.get_openlineage_listener",
+        return_value=fake_listener,
+    ):
+        emit_openlineage_events_for_databricks_queries(
+            query_ids=query_ids,
+            query_source_namespace="databricks_ns",
+            task_instance=mock_ti,
+            hook=hook,
+            # query_for_extra_metadata=False,  # False by default
+            additional_run_facets=additional_run_facets,
+            additional_job_facets=additional_job_facets,
+        )
+
+        assert query_ids == original_query_ids  # Verify that the input 
query_ids list is unchanged.
+        assert fake_adapter.emit.call_count == 2  # Expect two events per 
query.
+
+        expected_common_job_facets = {
+            "jobType": job_type_job.JobTypeJobFacet(
+                jobType="QUERY",
+                processingType="BATCH",
+                integration="DATABRICKS",
+            ),
+            "custom_job": "value_job",
+        }
+        expected_common_run_facets = {
+            "parent": parent_run.ParentRunFacet(
+                
run=parent_run.Run(runId="01941f29-7c00-7087-8906-40e512c257bd"),
+                job=parent_run.Job(namespace=namespace(), 
name="dag_id.task_id"),
+                root=parent_run.Root(
+                    
run=parent_run.RootRun(runId="01941f29-7c00-743e-b109-28b18d0a19c5"),
+                    job=parent_run.RootJob(namespace=namespace(), 
name="dag_id"),
+                ),
+            ),
+            "custom_run": "value_run",
+        }
+
+        expected_calls = [
+            mock.call(  # Query1: START event (no metadata)
+                RunEvent(
+                    eventTime=default_event_time.isoformat(),
+                    eventType=RunState.START,
+                    run=Run(
+                        runId=fake_uuid,
+                        facets={
+                            "externalQuery": ExternalQueryRunFacet(
+                                externalQueryId="query1", 
source="databricks_ns"
+                            ),
+                            **expected_common_run_facets,
+                        },
+                    ),
+                    job=Job(
+                        namespace=namespace(),
+                        name="dag_id.task_id.query.1",
+                        facets=expected_common_job_facets,
+                    ),
+                )
+            ),
+            mock.call(  # Query1: COMPLETE event (no metadata)
+                RunEvent(
+                    eventTime=default_event_time.isoformat(),
+                    eventType=RunState.COMPLETE,
+                    run=Run(
+                        runId=fake_uuid,
+                        facets={
+                            "externalQuery": ExternalQueryRunFacet(
+                                externalQueryId="query1", 
source="databricks_ns"
+                            ),
+                            **expected_common_run_facets,
+                        },
+                    ),
+                    job=Job(
+                        namespace=namespace(),
+                        name="dag_id.task_id.query.1",
+                        facets=expected_common_job_facets,
+                    ),
+                )
+            ),
+        ]
+
+        assert fake_adapter.emit.call_args_list == expected_calls
+
+
[email protected]("importlib.metadata.version", return_value="2.3.0")
+def 
test_emit_openlineage_events_for_databricks_queries_missing_query_ids_and_hook(mock_version):
+    query_ids = []
+    original_query_ids = copy.deepcopy(query_ids)
+
+    fake_adapter = mock.MagicMock()
+    fake_adapter.emit = mock.MagicMock()
+    fake_listener = mock.MagicMock()
+    fake_listener.adapter = fake_adapter
+
+    with mock.patch(
+        
"airflow.providers.openlineage.plugins.listener.get_openlineage_listener",
+        return_value=fake_listener,
+    ):
+        with pytest.raises(ValueError, match="If 'hook' is not provided, 
'query_ids' must be set."):
+            emit_openlineage_events_for_databricks_queries(
+                query_ids=query_ids,
+                query_source_namespace="databricks_ns",
+                task_instance=None,
+            )
+
+        assert query_ids == original_query_ids  # Verify that the input 
query_ids list is unchanged.
+        fake_adapter.emit.assert_not_called()  # No events should be emitted
+
+
[email protected]("importlib.metadata.version", return_value="2.3.0")
+def 
test_emit_openlineage_events_for_databricks_queries_missing_query_namespace_and_hook(mock_version):
+    query_ids = ["1", "2"]
+    original_query_ids = copy.deepcopy(query_ids)
+
+    fake_adapter = mock.MagicMock()
+    fake_adapter.emit = mock.MagicMock()
+    fake_listener = mock.MagicMock()
+    fake_listener.adapter = fake_adapter
+
+    with mock.patch(
+        
"airflow.providers.openlineage.plugins.listener.get_openlineage_listener",
+        return_value=fake_listener,
+    ):
+        with pytest.raises(
+            ValueError, match="If 'hook' is not provided, 
'query_source_namespace' must be set."
+        ):
+            emit_openlineage_events_for_databricks_queries(
+                query_ids=query_ids,
+                task_instance=None,
+            )
+
+        assert query_ids == original_query_ids  # Verify that the input 
query_ids list is unchanged.
+        fake_adapter.emit.assert_not_called()  # No events should be emitted
+
+
[email protected]("importlib.metadata.version", return_value="2.3.0")
+def 
test_emit_openlineage_events_for_databricks_queries_missing_hook_and_query_for_extra_metadata_true(
+    mock_version,
+):
+    query_ids = ["1", "2"]
+    original_query_ids = copy.deepcopy(query_ids)
+
+    fake_adapter = mock.MagicMock()
+    fake_adapter.emit = mock.MagicMock()
+    fake_listener = mock.MagicMock()
+    fake_listener.adapter = fake_adapter
+
+    with mock.patch(
+        
"airflow.providers.openlineage.plugins.listener.get_openlineage_listener",
+        return_value=fake_listener,
+    ):
+        with pytest.raises(
+            ValueError, match="If 'hook' is not provided, 
'query_for_extra_metadata' must be False."
+        ):
+            emit_openlineage_events_for_databricks_queries(
+                query_ids=query_ids,
+                query_source_namespace="databricks_ns",
+                task_instance=None,
+                query_for_extra_metadata=True,
+            )
+
         assert query_ids == original_query_ids  # Verify that the input 
query_ids list is unchanged.
         fake_adapter.emit.assert_not_called()  # No events should be emitted
 
 
-# emit_openlineage_events_for_databricks_queries requires OL provider 2.3.0
 @mock.patch("importlib.metadata.version", return_value="1.99.0")
 def test_emit_openlineage_events_with_old_openlineage_provider(mock_version):
     query_ids = ["q1", "q2"]


Reply via email to