This is an automated email from the ASF dual-hosted git repository.
mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 9a14e66a64c feat: Add explicit support for DatabricksHook to Ol helper
(#52253)
9a14e66a64c is described below
commit 9a14e66a64cdec8c366bc0722b63239bb4a84c39
Author: Kacper Muda <[email protected]>
AuthorDate: Mon Jun 30 11:59:35 2025 +0200
feat: Add explicit support for DatabricksHook to Ol helper (#52253)
---
.../providers/databricks/hooks/databricks_sql.py | 24 +-
.../providers/databricks/utils/openlineage.py | 143 +++--
.../unit/databricks/hooks/test_databricks_sql.py | 17 +-
.../unit/databricks/utils/test_openlineage.py | 634 ++++++++++++++++++++-
4 files changed, 732 insertions(+), 86 deletions(-)
diff --git
a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py
b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py
index 76aea1dde54..0ace06abc99 100644
---
a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py
+++
b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py
@@ -344,10 +344,9 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
def get_openlineage_database_specific_lineage(self, task_instance) ->
OperatorLineage | None:
"""
- Generate OpenLineage metadata for a Databricks task instance based on
executed query IDs.
+ Emit separate OpenLineage events for each Databricks query, based on
executed query IDs.
- If a single query ID is present, attach an `ExternalQueryRunFacet` to
the lineage metadata.
- If multiple query IDs are present, emits separate OpenLineage events
for each query instead.
+ If a single query ID is present, also add an `ExternalQueryRunFacet`
to the returned lineage metadata.
Note that `get_openlineage_database_specific_lineage` is usually
called after task's execution,
so if multiple query IDs are present, both START and COMPLETE event
for each query will be emitted
@@ -368,13 +367,22 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
from airflow.providers.openlineage.sqlparser import SQLParser
if not self.query_ids:
- self.log.debug("openlineage: no databricks query ids found.")
+ self.log.info("OpenLineage could not find databricks query ids.")
return None
self.log.debug("openlineage: getting connection to get database info")
connection = self.get_connection(self.get_conn_id())
namespace =
SQLParser.create_namespace(self.get_openlineage_database_info(connection))
+ self.log.info("Separate OpenLineage events will be emitted for each
Databricks query_id.")
+ emit_openlineage_events_for_databricks_queries(
+ task_instance=task_instance,
+ hook=self,
+ query_ids=self.query_ids,
+ query_for_extra_metadata=True,
+ query_source_namespace=namespace,
+ )
+
if len(self.query_ids) == 1:
self.log.debug("Attaching ExternalQueryRunFacet with single
query_id to OpenLineage event.")
return OperatorLineage(
@@ -385,12 +393,4 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
}
)
- self.log.info("Multiple query_ids found. Separate OpenLineage event
will be emitted for each query.")
- emit_openlineage_events_for_databricks_queries(
- query_ids=self.query_ids,
- query_source_namespace=namespace,
- task_instance=task_instance,
- hook=self,
- )
-
return None
diff --git
a/providers/databricks/src/airflow/providers/databricks/utils/openlineage.py
b/providers/databricks/src/airflow/providers/databricks/utils/openlineage.py
index 5d2fe399703..57c964d9842 100644
--- a/providers/databricks/src/airflow/providers/databricks/utils/openlineage.py
+++ b/providers/databricks/src/airflow/providers/databricks/utils/openlineage.py
@@ -31,6 +31,7 @@ if TYPE_CHECKING:
from openlineage.client.event_v2 import RunEvent
from openlineage.client.facet_v2 import JobFacet
+ from airflow.providers.databricks.hooks.databricks import DatabricksHook
from airflow.providers.databricks.hooks.databricks_sql import
DatabricksSqlHook
log = logging.getLogger(__name__)
@@ -121,20 +122,18 @@ def _get_parent_run_facet(task_instance):
)
-def _run_api_call(hook: DatabricksSqlHook, query_ids: list[str]) -> list[dict]:
+def _run_api_call(hook: DatabricksSqlHook | DatabricksHook, query_ids:
list[str]) -> list[dict]:
"""Retrieve execution details for specific queries from Databricks's query
history API."""
- if not hook._token:
- # This has logic for token initialization
- hook.get_conn()
-
- # https://docs.databricks.com/api/azure/workspace/queryhistory/list
try:
+ token = hook._get_token(raise_error=True)
+ # https://docs.databricks.com/api/azure/workspace/queryhistory/list
response = requests.get(
url=f"https://{hook.host}/api/2.0/sql/history/queries",
- headers={"Authorization": f"Bearer {hook._token}"},
+ headers={"Authorization": f"Bearer {token}"},
data=json.dumps({"filter_by": {"statement_ids": query_ids}}),
timeout=2,
)
+ response.raise_for_status()
except Exception as e:
log.warning(
"OpenLineage could not retrieve Databricks queries details. Error
received: `%s`.",
@@ -142,48 +141,42 @@ def _run_api_call(hook: DatabricksSqlHook, query_ids:
list[str]) -> list[dict]:
)
return []
- if response.status_code != 200:
- log.warning(
- "OpenLineage could not retrieve Databricks queries details. API
error received: `%s`: `%s`",
- response.status_code,
- response.text,
- )
- return []
-
return response.json()["res"]
+def _process_data_from_api(data: list[dict[str, Any]]) -> list[dict[str, Any]]:
+ """Convert timestamp fields to UTC datetime objects."""
+ for row in data:
+ for key in ("query_start_time_ms", "query_end_time_ms"):
+ row[key] = datetime.datetime.fromtimestamp(row[key] / 1000,
tz=datetime.timezone.utc)
+
+ return data
+
+
def _get_queries_details_from_databricks(
- hook: DatabricksSqlHook, query_ids: list[str]
+ hook: DatabricksSqlHook | DatabricksHook, query_ids: list[str]
) -> dict[str, dict[str, Any]]:
if not query_ids:
return {}
- queries_info_from_api = _run_api_call(hook=hook, query_ids=query_ids)
-
query_details = {}
- for query_info in queries_info_from_api:
- if not query_info.get("query_id"):
- log.debug("Databricks query ID not found in API response.")
- continue
-
- q_start_time = None
- q_end_time = None
- if query_info.get("query_start_time_ms") and
query_info.get("query_end_time_ms"):
- q_start_time = datetime.datetime.fromtimestamp(
- query_info["query_start_time_ms"] / 1000,
tz=datetime.timezone.utc
- )
- q_end_time = datetime.datetime.fromtimestamp(
- query_info["query_end_time_ms"] / 1000,
tz=datetime.timezone.utc
- )
-
- query_details[query_info["query_id"]] = {
- "status": query_info.get("status"),
- "start_time": q_start_time,
- "end_time": q_end_time,
- "query_text": query_info.get("query_text"),
- "error_message": query_info.get("error_message"),
+ try:
+ queries_info_from_api = _run_api_call(hook=hook, query_ids=query_ids)
+ queries_info_from_api = _process_data_from_api(queries_info_from_api)
+
+ query_details = {
+ query_info["query_id"]: {
+ "status": query_info.get("status"),
+ "start_time": query_info.get("query_start_time_ms"),
+ "end_time": query_info.get("query_end_time_ms"),
+ "query_text": query_info.get("query_text"),
+ "error_message": query_info.get("error_message"),
+ }
+ for query_info in queries_info_from_api
+ if query_info["query_id"]
}
+ except Exception as e:
+ log.warning("OpenLineage could not retrieve extra metadata from
Databricks. Error encountered: %s", e)
return query_details
@@ -221,17 +214,18 @@ def _create_ol_event_pair(
@require_openlineage_version(provider_min_version="2.3.0")
def emit_openlineage_events_for_databricks_queries(
- query_ids: list[str],
- query_source_namespace: str,
task_instance,
- hook: DatabricksSqlHook | None = None,
+ hook: DatabricksSqlHook | DatabricksHook | None = None,
+ query_ids: list[str] | None = None,
+ query_source_namespace: str | None = None,
+ query_for_extra_metadata: bool = False,
additional_run_facets: dict | None = None,
additional_job_facets: dict | None = None,
) -> None:
"""
Emit OpenLineage events for executed Databricks queries.
- Metadata retrieval from Databricks is attempted only if a
`DatabricksSqlHook` is provided.
+ Metadata retrieval from Databricks is attempted only if
`query_for_extra_metadata` is True and hook is provided.
If metadata is available, execution details such as start time, end time,
execution status,
error messages, and SQL text are included in the events. If no metadata is
found, the function
defaults to using the Airflow task instance's state and the current
timestamp.
@@ -241,10 +235,16 @@ def emit_openlineage_events_for_databricks_queries(
will correspond to actual query execution times.
Args:
- query_ids: A list of Databricks query IDs to emit events for.
- query_source_namespace: The namespace to be included in
ExternalQueryRunFacet.
task_instance: The Airflow task instance that run these queries.
- hook: A hook instance used to retrieve query metadata if available.
+ hook: A supported Databricks hook instance used to retrieve query
metadata if available.
+ If omitted, `query_ids` and `query_source_namespace` must be provided
explicitly and
+ `query_for_extra_metadata` must be `False`.
+ query_ids: A list of Databricks query IDs to emit events for, can only
be None if `hook` is provided
+ and `hook.query_ids` are present (DatabricksHook does not store
query_ids).
+ query_source_namespace: The namespace to be included in
ExternalQueryRunFacet,
+ can be `None` only if hook is provided.
+ query_for_extra_metadata: Whether to query Databricks for additional
metadata about queries.
+ Must be `False` if `hook` is not provided.
additional_run_facets: Additional run facets to include in OpenLineage
events.
additional_job_facets: Additional job facets to include in OpenLineage
events.
"""
@@ -259,25 +259,52 @@ def emit_openlineage_events_for_databricks_queries(
from airflow.providers.openlineage.conf import namespace
from airflow.providers.openlineage.plugins.listener import
get_openlineage_listener
- if not query_ids:
- log.debug("No Databricks query IDs provided; skipping OpenLineage
event emission.")
- return
-
- query_ids = [q for q in query_ids] # Make a copy to make sure it does not
change
+ log.info("OpenLineage will emit events for Databricks queries.")
if hook:
+ if not query_ids:
+ log.debug("No Databricks query IDs provided; Checking
`hook.query_ids` property.")
+ query_ids = getattr(hook, "query_ids", [])
+ if not query_ids:
+ raise ValueError("No Databricks query IDs provided and
`hook.query_ids` are not present.")
+
+ if not query_source_namespace:
+ log.debug("No Databricks query namespace provided; Creating one
from scratch.")
+
+ if hasattr(hook, "get_openlineage_database_info") and
hasattr(hook, "get_conn_id"):
+ from airflow.providers.openlineage.sqlparser import SQLParser
+
+ query_source_namespace = SQLParser.create_namespace(
+
hook.get_openlineage_database_info(hook.get_connection(hook.get_conn_id()))
+ )
+ else:
+ query_source_namespace = f"databricks://{hook.host}" if
hook.host else "databricks"
+ else:
+ if not query_ids:
+ raise ValueError("If 'hook' is not provided, 'query_ids' must be
set.")
+ if not query_source_namespace:
+ raise ValueError("If 'hook' is not provided,
'query_source_namespace' must be set.")
+ if query_for_extra_metadata:
+ raise ValueError("If 'hook' is not provided,
'query_for_extra_metadata' must be False.")
+
+ query_ids = [q for q in query_ids] # Make a copy to make sure we do not
change hook's attribute
+
+ if query_for_extra_metadata and hook:
log.debug("Retrieving metadata for %s queries from Databricks.",
len(query_ids))
databricks_metadata = _get_queries_details_from_databricks(hook,
query_ids)
else:
- log.debug("DatabricksSqlHook not provided. No extra metadata fill be
fetched from Databricks.")
+ log.debug("`query_for_extra_metadata` is False. No extra metadata will
be fetched from Databricks.")
databricks_metadata = {}
# If real metadata is unavailable, we send events with eventTime=now
default_event_time = timezone.utcnow()
- # If no query metadata is provided, we use task_instance's state when
checking for success
+ # ti.state has no `value` attr (AF2) when the task is still running; in AF3
we get 'running'. In that case
+ # we assume it's a user call and the query succeeded, so we replace it with
success.
# Adjust state for DBX logic, where "finished" means "success"
- default_state = task_instance.state.value if hasattr(task_instance,
"state") else ""
- default_state = "finished" if default_state == "success" else default_state
+ default_state = (
+ getattr(task_instance.state, "value", "running") if
hasattr(task_instance, "state") else ""
+ )
+ default_state = "finished" if default_state in ("running", "success") else
default_state
log.debug("Generating OpenLineage facets")
common_run_facets = {"parent": _get_parent_run_facet(task_instance)}
@@ -318,10 +345,10 @@ def emit_openlineage_events_for_databricks_queries(
event_batch = _create_ol_event_pair(
job_namespace=namespace(),
job_name=f"{task_instance.dag_id}.{task_instance.task_id}.query.{counter}",
- start_time=query_metadata.get("start_time", default_event_time),
# type: ignore[arg-type]
- end_time=query_metadata.get("end_time", default_event_time), #
type: ignore[arg-type]
+ start_time=query_metadata.get("start_time") or default_event_time,
# type: ignore[arg-type]
+ end_time=query_metadata.get("end_time") or default_event_time, #
type: ignore[arg-type]
# Only finished status means it completed without failures
- is_successful=query_metadata.get("status", default_state).lower()
== "finished",
+ is_successful=(query_metadata.get("status") or
default_state).lower() == "finished",
run_facets={**query_specific_run_facets, **common_run_facets,
**additional_run_facets},
job_facets={**query_specific_job_facets, **common_job_facets,
**additional_job_facets},
)
diff --git
a/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py
b/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py
index 7449489c77d..20e2c081f74 100644
--- a/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py
+++ b/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py
@@ -457,7 +457,8 @@ def
test_get_openlineage_database_specific_lineage_with_no_query_id():
assert result is None
-def test_get_openlineage_database_specific_lineage_with_single_query_id():
[email protected]("airflow.providers.databricks.utils.openlineage.emit_openlineage_events_for_databricks_queries")
+def
test_get_openlineage_database_specific_lineage_with_single_query_id(mock_emit):
from airflow.providers.common.compat.openlineage.facet import
ExternalQueryRunFacet
from airflow.providers.openlineage.extractors import OperatorLineage
@@ -466,7 +467,18 @@ def
test_get_openlineage_database_specific_lineage_with_single_query_id():
hook.get_connection = mock.MagicMock()
hook.get_openlineage_database_info = lambda x:
mock.MagicMock(authority="auth", scheme="scheme")
- result = hook.get_openlineage_database_specific_lineage(None)
+ ti = mock.MagicMock()
+
+ result = hook.get_openlineage_database_specific_lineage(ti)
+ mock_emit.assert_called_once_with(
+ **{
+ "hook": hook,
+ "query_ids": ["query1"],
+ "query_source_namespace": "scheme://auth",
+ "task_instance": ti,
+ "query_for_extra_metadata": True,
+ }
+ )
assert result == OperatorLineage(
run_facets={"externalQuery":
ExternalQueryRunFacet(externalQueryId="query1", source="scheme://auth")}
)
@@ -488,6 +500,7 @@ def
test_get_openlineage_database_specific_lineage_with_multiple_query_ids(mock_
"query_ids": ["query1", "query2"],
"query_source_namespace": "scheme://auth",
"task_instance": ti,
+ "query_for_extra_metadata": True,
}
)
assert result is None
diff --git
a/providers/databricks/tests/unit/databricks/utils/test_openlineage.py
b/providers/databricks/tests/unit/databricks/utils/test_openlineage.py
index 699e018785e..6d427e0ba77 100644
--- a/providers/databricks/tests/unit/databricks/utils/test_openlineage.py
+++ b/providers/databricks/tests/unit/databricks/utils/test_openlineage.py
@@ -30,11 +30,14 @@ from airflow.providers.common.compat.openlineage.facet
import (
ExternalQueryRunFacet,
SQLJobFacet,
)
+from airflow.providers.databricks.hooks.databricks import DatabricksHook
+from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook
from airflow.providers.databricks.utils.openlineage import (
_create_ol_event_pair,
_get_ol_run_id,
_get_parent_run_facet,
_get_queries_details_from_databricks,
+ _process_data_from_api,
_run_api_call,
emit_openlineage_events_for_databricks_queries,
)
@@ -96,7 +99,7 @@ def test_get_parent_run_facet():
def test_run_api_call_success():
mock_hook = mock.MagicMock()
- mock_hook._token = "mock_token"
+ mock_hook._get_token.return_value = "mock_token"
mock_hook.host = "mock_host"
mock_response = mock.MagicMock()
@@ -109,14 +112,27 @@ def test_run_api_call_success():
assert result == [{"query_id": "123", "status": "success"}]
-def test_run_api_call_error():
+def test_run_api_call_request_error():
mock_hook = mock.MagicMock()
- mock_hook._token = "mock_token"
+ mock_hook._get_token.return_value = "mock_token"
mock_hook.host = "mock_host"
mock_response = mock.MagicMock()
- mock_response.status_code = 500
- mock_response.text = "Internal Server Error"
+ mock_response.status_code = 200
+
+ with mock.patch("requests.get", side_effect=RuntimeError("request error")):
+ result = _run_api_call(mock_hook, ["123"])
+
+ assert result == []
+
+
+def test_run_api_call_token_error():
+ mock_hook = mock.MagicMock()
+ mock_hook._get_token.side_effect = RuntimeError("Token error")
+ mock_hook.host = "mock_host"
+
+ mock_response = mock.MagicMock()
+ mock_response.status_code = 200
with mock.patch("requests.get", return_value=mock_response):
result = _run_api_call(mock_hook, ["123"])
@@ -124,6 +140,55 @@ def test_run_api_call_error():
assert result == []
+def test_process_data_from_api():
+ data = [
+ {
+ "query_id": "ABC",
+ "status": "FINISHED",
+ "query_start_time_ms": 1595357086200,
+ "query_end_time_ms": 1595357087200,
+ "query_text": "SELECT * FROM table1;",
+ "error_message": "Error occurred",
+ },
+ {
+ "query_id": "DEF",
+ "query_start_time_ms": 1595357086200,
+ "query_end_time_ms": 1595357087200,
+ },
+ ]
+ expected_details = [
+ {
+ "query_id": "ABC",
+ "status": "FINISHED",
+ "query_start_time_ms": datetime.datetime(
+ 2020, 7, 21, 18, 44, 46, 200000, tzinfo=datetime.timezone.utc
+ ),
+ "query_end_time_ms": datetime.datetime(
+ 2020, 7, 21, 18, 44, 47, 200000, tzinfo=datetime.timezone.utc
+ ),
+ "query_text": "SELECT * FROM table1;",
+ "error_message": "Error occurred",
+ },
+ {
+ "query_id": "DEF",
+ "query_start_time_ms": datetime.datetime(
+ 2020, 7, 21, 18, 44, 46, 200000, tzinfo=datetime.timezone.utc
+ ),
+ "query_end_time_ms": datetime.datetime(
+ 2020, 7, 21, 18, 44, 47, 200000, tzinfo=datetime.timezone.utc
+ ),
+ },
+ ]
+ result = _process_data_from_api(data=data)
+ assert len(result) == 2
+ assert result == expected_details
+
+
+def test_process_data_from_api_error():
+ with pytest.raises(KeyError):
+ _process_data_from_api(data=[{"query_start_time_ms": 1595357086200}])
+
+
def test_get_queries_details_from_databricks_empty_query_ids():
details = _get_queries_details_from_databricks(None, [])
assert details == {}
@@ -131,7 +196,7 @@ def
test_get_queries_details_from_databricks_empty_query_ids():
@mock.patch("airflow.providers.databricks.utils.openlineage._run_api_call")
def test_get_queries_details_from_databricks(mock_api_call):
- hook = mock.MagicMock()
+ hook = DatabricksSqlHook()
query_ids = ["ABC"]
fake_result = [
{
@@ -160,7 +225,7 @@ def test_get_queries_details_from_databricks(mock_api_call):
@mock.patch("airflow.providers.databricks.utils.openlineage._run_api_call")
def test_get_queries_details_from_databricks_no_data_found(mock_api_call):
- hook = mock.MagicMock()
+ hook = DatabricksSqlHook()
query_ids = ["ABC", "DEF"]
mock_api_call.return_value = []
@@ -274,6 +339,7 @@ def
test_emit_openlineage_events_for_databricks_queries(mock_now, mock_generate_
query_source_namespace="databricks_ns",
task_instance=mock_ti,
hook=mock.MagicMock(),
+ query_for_extra_metadata=True,
additional_run_facets=additional_run_facets,
additional_job_facets=additional_job_facets,
)
@@ -448,7 +514,7 @@ def
test_emit_openlineage_events_for_databricks_queries(mock_now, mock_generate_
@mock.patch("importlib.metadata.version", return_value="2.3.0")
@mock.patch("openlineage.client.uuid.generate_new_uuid")
@mock.patch("airflow.utils.timezone.utcnow")
-def test_emit_openlineage_events_for_databricks_queries_without_metadata_found(
+def test_emit_openlineage_events_for_databricks_queries_without_metadata(
mock_now, mock_generate_uuid, mock_version
):
fake_uuid = "01958e68-03a2-79e3-9ae9-26865cc40e2f"
@@ -489,7 +555,8 @@ def
test_emit_openlineage_events_for_databricks_queries_without_metadata_found(
query_ids=query_ids,
query_source_namespace="databricks_ns",
task_instance=mock_ti,
- hook=None, # None so metadata retrieval is not triggered
+ hook=mock.MagicMock(),
+ # query_for_extra_metadata=False, # False by default
additional_run_facets=additional_run_facets,
additional_job_facets=additional_job_facets,
)
@@ -564,9 +631,37 @@ def
test_emit_openlineage_events_for_databricks_queries_without_metadata_found(
@mock.patch("importlib.metadata.version", return_value="2.3.0")
-def
test_emit_openlineage_events_for_databricks_queries_without_query_ids(mock_version):
- query_ids = []
[email protected]("openlineage.client.uuid.generate_new_uuid")
[email protected]("airflow.utils.timezone.utcnow")
+def
test_emit_openlineage_events_for_databricks_queries_without_explicit_query_ids(
+ mock_now, mock_generate_uuid, mock_version
+):
+ fake_uuid = "01958e68-03a2-79e3-9ae9-26865cc40e2f"
+ mock_generate_uuid.return_value = fake_uuid
+
+ default_event_time = timezone.datetime(2025, 1, 5, 0, 0, 0)
+ mock_now.return_value = default_event_time
+
+ query_ids = ["query1"]
+ hook = mock.MagicMock()
+ hook.query_ids = query_ids
original_query_ids = copy.deepcopy(query_ids)
+ logical_date = timezone.datetime(2025, 1, 1)
+ mock_ti = mock.MagicMock(
+ dag_id="dag_id",
+ task_id="task_id",
+ map_index=1,
+ try_number=1,
+ logical_date=logical_date,
+ state=TaskInstanceState.RUNNING, # This will be query default state
if no metadata found
+ dag_run=mock.MagicMock(logical_date=logical_date, clear_number=0),
+ )
+ mock_ti.get_template_context.return_value = {
+ "dag_run": mock.MagicMock(logical_date=logical_date, clear_number=0)
+ }
+
+ additional_run_facets = {"custom_run": "value_run"}
+ additional_job_facets = {"custom_job": "value_job"}
fake_adapter = mock.MagicMock()
fake_adapter.emit = mock.MagicMock()
@@ -578,16 +673,527 @@ def
test_emit_openlineage_events_for_databricks_queries_without_query_ids(mock_v
return_value=fake_listener,
):
emit_openlineage_events_for_databricks_queries(
- query_ids=query_ids,
query_source_namespace="databricks_ns",
- task_instance=None,
+ task_instance=mock_ti,
+ hook=hook,
+ # query_for_extra_metadata=False, # False by default
+ additional_run_facets=additional_run_facets,
+ additional_job_facets=additional_job_facets,
+ )
+
+ assert query_ids == original_query_ids # Verify that the input
query_ids list is unchanged.
+ assert fake_adapter.emit.call_count == 2 # Expect two events per
query.
+
+ expected_common_job_facets = {
+ "jobType": job_type_job.JobTypeJobFacet(
+ jobType="QUERY",
+ processingType="BATCH",
+ integration="DATABRICKS",
+ ),
+ "custom_job": "value_job",
+ }
+ expected_common_run_facets = {
+ "parent": parent_run.ParentRunFacet(
+
run=parent_run.Run(runId="01941f29-7c00-7087-8906-40e512c257bd"),
+ job=parent_run.Job(namespace=namespace(),
name="dag_id.task_id"),
+ root=parent_run.Root(
+
run=parent_run.RootRun(runId="01941f29-7c00-743e-b109-28b18d0a19c5"),
+ job=parent_run.RootJob(namespace=namespace(),
name="dag_id"),
+ ),
+ ),
+ "custom_run": "value_run",
+ }
+
+ expected_calls = [
+ mock.call( # Query1: START event (no metadata)
+ RunEvent(
+ eventTime=default_event_time.isoformat(),
+ eventType=RunState.START,
+ run=Run(
+ runId=fake_uuid,
+ facets={
+ "externalQuery": ExternalQueryRunFacet(
+ externalQueryId="query1",
source="databricks_ns"
+ ),
+ **expected_common_run_facets,
+ },
+ ),
+ job=Job(
+ namespace=namespace(),
+ name="dag_id.task_id.query.1",
+ facets=expected_common_job_facets,
+ ),
+ )
+ ),
+ mock.call( # Query1: COMPLETE event (no metadata)
+ RunEvent(
+ eventTime=default_event_time.isoformat(),
+ eventType=RunState.COMPLETE,
+ run=Run(
+ runId=fake_uuid,
+ facets={
+ "externalQuery": ExternalQueryRunFacet(
+ externalQueryId="query1",
source="databricks_ns"
+ ),
+ **expected_common_run_facets,
+ },
+ ),
+ job=Job(
+ namespace=namespace(),
+ name="dag_id.task_id.query.1",
+ facets=expected_common_job_facets,
+ ),
+ )
+ ),
+ ]
+
+ assert fake_adapter.emit.call_args_list == expected_calls
+
+
[email protected](
+ "airflow.providers.openlineage.sqlparser.SQLParser.create_namespace",
return_value="databricks_ns"
+)
[email protected]("importlib.metadata.version", return_value="2.3.0")
[email protected]("openlineage.client.uuid.generate_new_uuid")
[email protected]("airflow.utils.timezone.utcnow")
+def
test_emit_openlineage_events_for_databricks_queries_without_explicit_query_ids_and_namespace(
+ mock_now, mock_generate_uuid, mock_version, mock_parser
+):
+ fake_uuid = "01958e68-03a2-79e3-9ae9-26865cc40e2f"
+ mock_generate_uuid.return_value = fake_uuid
+
+ default_event_time = timezone.datetime(2025, 1, 5, 0, 0, 0)
+ mock_now.return_value = default_event_time
+
+ query_ids = ["query1"]
+ hook = mock.MagicMock()
+ hook.query_ids = query_ids
+ original_query_ids = copy.deepcopy(query_ids)
+ logical_date = timezone.datetime(2025, 1, 1)
+ mock_ti = mock.MagicMock(
+ dag_id="dag_id",
+ task_id="task_id",
+ map_index=1,
+ try_number=1,
+ logical_date=logical_date,
+ state=TaskInstanceState.RUNNING, # This will be query default state
if no metadata found
+ dag_run=mock.MagicMock(logical_date=logical_date, clear_number=0),
+ )
+ mock_ti.get_template_context.return_value = {
+ "dag_run": mock.MagicMock(logical_date=logical_date, clear_number=0)
+ }
+
+ additional_run_facets = {"custom_run": "value_run"}
+ additional_job_facets = {"custom_job": "value_job"}
+
+ fake_adapter = mock.MagicMock()
+ fake_adapter.emit = mock.MagicMock()
+ fake_listener = mock.MagicMock()
+ fake_listener.adapter = fake_adapter
+
+ with mock.patch(
+
"airflow.providers.openlineage.plugins.listener.get_openlineage_listener",
+ return_value=fake_listener,
+ ):
+ emit_openlineage_events_for_databricks_queries(
+ task_instance=mock_ti,
+ hook=hook,
+ # query_for_extra_metadata=False, # False by default
+ additional_run_facets=additional_run_facets,
+ additional_job_facets=additional_job_facets,
)
+ assert query_ids == original_query_ids # Verify that the input
query_ids list is unchanged.
+ assert fake_adapter.emit.call_count == 2 # Expect two events per
query.
+
+ expected_common_job_facets = {
+ "jobType": job_type_job.JobTypeJobFacet(
+ jobType="QUERY",
+ processingType="BATCH",
+ integration="DATABRICKS",
+ ),
+ "custom_job": "value_job",
+ }
+ expected_common_run_facets = {
+ "parent": parent_run.ParentRunFacet(
+
run=parent_run.Run(runId="01941f29-7c00-7087-8906-40e512c257bd"),
+ job=parent_run.Job(namespace=namespace(),
name="dag_id.task_id"),
+ root=parent_run.Root(
+
run=parent_run.RootRun(runId="01941f29-7c00-743e-b109-28b18d0a19c5"),
+ job=parent_run.RootJob(namespace=namespace(),
name="dag_id"),
+ ),
+ ),
+ "custom_run": "value_run",
+ }
+
+ expected_calls = [
+ mock.call( # Query1: START event (no metadata)
+ RunEvent(
+ eventTime=default_event_time.isoformat(),
+ eventType=RunState.START,
+ run=Run(
+ runId=fake_uuid,
+ facets={
+ "externalQuery": ExternalQueryRunFacet(
+ externalQueryId="query1",
source="databricks_ns"
+ ),
+ **expected_common_run_facets,
+ },
+ ),
+ job=Job(
+ namespace=namespace(),
+ name="dag_id.task_id.query.1",
+ facets=expected_common_job_facets,
+ ),
+ )
+ ),
+ mock.call( # Query1: COMPLETE event (no metadata)
+ RunEvent(
+ eventTime=default_event_time.isoformat(),
+ eventType=RunState.COMPLETE,
+ run=Run(
+ runId=fake_uuid,
+ facets={
+ "externalQuery": ExternalQueryRunFacet(
+ externalQueryId="query1",
source="databricks_ns"
+ ),
+ **expected_common_run_facets,
+ },
+ ),
+ job=Job(
+ namespace=namespace(),
+ name="dag_id.task_id.query.1",
+ facets=expected_common_job_facets,
+ ),
+ )
+ ),
+ ]
+
+ assert fake_adapter.emit.call_args_list == expected_calls
+
+
[email protected]("importlib.metadata.version", return_value="2.3.0")
[email protected]("openlineage.client.uuid.generate_new_uuid")
[email protected]("airflow.utils.timezone.utcnow")
+def
test_emit_openlineage_events_for_databricks_queries_without_explicit_query_ids_and_namespace_raw_ns(
+ mock_now, mock_generate_uuid, mock_version
+):
+ fake_uuid = "01958e68-03a2-79e3-9ae9-26865cc40e2f"
+ mock_generate_uuid.return_value = fake_uuid
+
+ default_event_time = timezone.datetime(2025, 1, 5, 0, 0, 0)
+ mock_now.return_value = default_event_time
+
+ query_ids = ["query1"]
+ hook = DatabricksHook()
+ hook.query_ids = query_ids
+ hook.host = "some_host"
+ original_query_ids = copy.deepcopy(query_ids)
+ logical_date = timezone.datetime(2025, 1, 1)
+ mock_ti = mock.MagicMock(
+ dag_id="dag_id",
+ task_id="task_id",
+ map_index=1,
+ try_number=1,
+ logical_date=logical_date,
+ state=TaskInstanceState.RUNNING, # This will be query default state
if no metadata found
+ dag_run=mock.MagicMock(logical_date=logical_date, clear_number=0),
+ )
+ mock_ti.get_template_context.return_value = {
+ "dag_run": mock.MagicMock(logical_date=logical_date, clear_number=0)
+ }
+
+ additional_run_facets = {"custom_run": "value_run"}
+ additional_job_facets = {"custom_job": "value_job"}
+
+ fake_adapter = mock.MagicMock()
+ fake_adapter.emit = mock.MagicMock()
+ fake_listener = mock.MagicMock()
+ fake_listener.adapter = fake_adapter
+
+ with mock.patch(
+
"airflow.providers.openlineage.plugins.listener.get_openlineage_listener",
+ return_value=fake_listener,
+ ):
+ emit_openlineage_events_for_databricks_queries(
+ task_instance=mock_ti,
+ hook=hook,
+ # query_for_extra_metadata=False, # False by default
+ additional_run_facets=additional_run_facets,
+ additional_job_facets=additional_job_facets,
+ )
+
+ assert query_ids == original_query_ids # Verify that the input
query_ids list is unchanged.
+ assert fake_adapter.emit.call_count == 2 # Expect two events per
query.
+
+ expected_common_job_facets = {
+ "jobType": job_type_job.JobTypeJobFacet(
+ jobType="QUERY",
+ processingType="BATCH",
+ integration="DATABRICKS",
+ ),
+ "custom_job": "value_job",
+ }
+ expected_common_run_facets = {
+ "parent": parent_run.ParentRunFacet(
+
run=parent_run.Run(runId="01941f29-7c00-7087-8906-40e512c257bd"),
+ job=parent_run.Job(namespace=namespace(),
name="dag_id.task_id"),
+ root=parent_run.Root(
+
run=parent_run.RootRun(runId="01941f29-7c00-743e-b109-28b18d0a19c5"),
+ job=parent_run.RootJob(namespace=namespace(),
name="dag_id"),
+ ),
+ ),
+ "custom_run": "value_run",
+ }
+
+ expected_calls = [
+ mock.call( # Query1: START event (no metadata)
+ RunEvent(
+ eventTime=default_event_time.isoformat(),
+ eventType=RunState.START,
+ run=Run(
+ runId=fake_uuid,
+ facets={
+ "externalQuery": ExternalQueryRunFacet(
+ externalQueryId="query1",
source="databricks://some_host"
+ ),
+ **expected_common_run_facets,
+ },
+ ),
+ job=Job(
+ namespace=namespace(),
+ name="dag_id.task_id.query.1",
+ facets=expected_common_job_facets,
+ ),
+ )
+ ),
+ mock.call( # Query1: COMPLETE event (no metadata)
+ RunEvent(
+ eventTime=default_event_time.isoformat(),
+ eventType=RunState.COMPLETE,
+ run=Run(
+ runId=fake_uuid,
+ facets={
+ "externalQuery": ExternalQueryRunFacet(
+ externalQueryId="query1",
source="databricks://some_host"
+ ),
+ **expected_common_run_facets,
+ },
+ ),
+ job=Job(
+ namespace=namespace(),
+ name="dag_id.task_id.query.1",
+ facets=expected_common_job_facets,
+ ),
+ )
+ ),
+ ]
+
+ assert fake_adapter.emit.call_args_list == expected_calls
+
+
@mock.patch("importlib.metadata.version", return_value="2.3.0")
@mock.patch("openlineage.client.uuid.generate_new_uuid")
@mock.patch("airflow.utils.timezone.utcnow")
def test_emit_openlineage_events_for_databricks_queries_with_query_ids_and_hook_query_ids(
    mock_now, mock_generate_uuid, mock_version
):
    """Explicit ``query_ids`` must take precedence over ``hook.query_ids``.

    The hook carries its own query IDs ("query2", "query3"), but since
    ``query_ids=["query1"]`` and ``query_source_namespace="databricks_ns"``
    are passed explicitly, events must be emitted only for "query1" and the
    externalQuery facet must use the explicit namespace rather than one
    derived from the hook's connection.
    """
    fake_uuid = "01958e68-03a2-79e3-9ae9-26865cc40e2f"
    mock_generate_uuid.return_value = fake_uuid

    default_event_time = timezone.datetime(2025, 1, 5, 0, 0, 0)
    mock_now.return_value = default_event_time

    hook = DatabricksSqlHook()
    hook.query_ids = ["query2", "query3"]  # Must be ignored in favor of the explicit `query_ids`.
    query_ids = ["query1"]
    original_query_ids = copy.deepcopy(query_ids)
    logical_date = timezone.datetime(2025, 1, 1)
    mock_ti = mock.MagicMock(
        dag_id="dag_id",
        task_id="task_id",
        map_index=1,
        try_number=1,
        logical_date=logical_date,
        state=TaskInstanceState.SUCCESS,  # This will be query default state if no metadata found
        dag_run=mock.MagicMock(logical_date=logical_date, clear_number=0),
    )
    mock_ti.get_template_context.return_value = {
        "dag_run": mock.MagicMock(logical_date=logical_date, clear_number=0)
    }

    additional_run_facets = {"custom_run": "value_run"}
    additional_job_facets = {"custom_job": "value_job"}

    fake_adapter = mock.MagicMock()
    fake_adapter.emit = mock.MagicMock()
    fake_listener = mock.MagicMock()
    fake_listener.adapter = fake_adapter

    with mock.patch(
        "airflow.providers.openlineage.plugins.listener.get_openlineage_listener",
        return_value=fake_listener,
    ):
        emit_openlineage_events_for_databricks_queries(
            query_ids=query_ids,
            query_source_namespace="databricks_ns",
            task_instance=mock_ti,
            hook=hook,
            # query_for_extra_metadata=False,  # False by default
            additional_run_facets=additional_run_facets,
            additional_job_facets=additional_job_facets,
        )

    assert query_ids == original_query_ids  # Verify that the input query_ids list is unchanged.
    assert fake_adapter.emit.call_count == 2  # Expect two events per query.

    expected_common_job_facets = {
        "jobType": job_type_job.JobTypeJobFacet(
            jobType="QUERY",
            processingType="BATCH",
            integration="DATABRICKS",
        ),
        "custom_job": "value_job",
    }
    expected_common_run_facets = {
        "parent": parent_run.ParentRunFacet(
            run=parent_run.Run(runId="01941f29-7c00-7087-8906-40e512c257bd"),
            job=parent_run.Job(namespace=namespace(), name="dag_id.task_id"),
            root=parent_run.Root(
                run=parent_run.RootRun(runId="01941f29-7c00-743e-b109-28b18d0a19c5"),
                job=parent_run.RootJob(namespace=namespace(), name="dag_id"),
            ),
        ),
        "custom_run": "value_run",
    }

    expected_calls = [
        mock.call(  # Query1: START event (no metadata)
            RunEvent(
                eventTime=default_event_time.isoformat(),
                eventType=RunState.START,
                run=Run(
                    runId=fake_uuid,
                    facets={
                        "externalQuery": ExternalQueryRunFacet(
                            externalQueryId="query1", source="databricks_ns"
                        ),
                        **expected_common_run_facets,
                    },
                ),
                job=Job(
                    namespace=namespace(),
                    name="dag_id.task_id.query.1",
                    facets=expected_common_job_facets,
                ),
            )
        ),
        mock.call(  # Query1: COMPLETE event (no metadata)
            RunEvent(
                eventTime=default_event_time.isoformat(),
                eventType=RunState.COMPLETE,
                run=Run(
                    runId=fake_uuid,
                    facets={
                        "externalQuery": ExternalQueryRunFacet(
                            externalQueryId="query1", source="databricks_ns"
                        ),
                        **expected_common_run_facets,
                    },
                ),
                job=Job(
                    namespace=namespace(),
                    name="dag_id.task_id.query.1",
                    facets=expected_common_job_facets,
                ),
            )
        ),
    ]

    assert fake_adapter.emit.call_args_list == expected_calls
+
+
@mock.patch("importlib.metadata.version", return_value="2.3.0")
def test_emit_openlineage_events_for_databricks_queries_missing_query_ids_and_hook(mock_version):
    """Without a hook, an empty ``query_ids`` must raise and emit no events."""
    query_ids = []
    snapshot = copy.deepcopy(query_ids)

    emitter = mock.MagicMock()
    listener_stub = mock.MagicMock()
    listener_stub.adapter = emitter

    listener_patch = mock.patch(
        "airflow.providers.openlineage.plugins.listener.get_openlineage_listener",
        return_value=listener_stub,
    )
    with listener_patch:
        with pytest.raises(ValueError, match="If 'hook' is not provided, 'query_ids' must be set."):
            emit_openlineage_events_for_databricks_queries(
                query_ids=query_ids,
                query_source_namespace="databricks_ns",
                task_instance=None,
            )

    # The caller-supplied list must stay intact and nothing may be emitted.
    assert query_ids == snapshot
    emitter.emit.assert_not_called()
+
+
@mock.patch("importlib.metadata.version", return_value="2.3.0")
def test_emit_openlineage_events_for_databricks_queries_missing_query_namespace_and_hook(mock_version):
    """Without a hook, a missing ``query_source_namespace`` must raise and emit no events."""
    query_ids = ["1", "2"]
    snapshot = copy.deepcopy(query_ids)

    emitter = mock.MagicMock()
    listener_stub = mock.MagicMock()
    listener_stub.adapter = emitter

    listener_patch = mock.patch(
        "airflow.providers.openlineage.plugins.listener.get_openlineage_listener",
        return_value=listener_stub,
    )
    with listener_patch:
        with pytest.raises(
            ValueError, match="If 'hook' is not provided, 'query_source_namespace' must be set."
        ):
            emit_openlineage_events_for_databricks_queries(
                query_ids=query_ids,
                task_instance=None,
            )

    # The caller-supplied list must stay intact and nothing may be emitted.
    assert query_ids == snapshot
    emitter.emit.assert_not_called()
+
+
@mock.patch("importlib.metadata.version", return_value="2.3.0")
def test_emit_openlineage_events_for_databricks_queries_missing_hook_and_query_for_extra_metadata_true(
    mock_version,
):
    """Requesting extra metadata without a hook must raise and emit no events."""
    query_ids = ["1", "2"]
    snapshot = copy.deepcopy(query_ids)

    emitter = mock.MagicMock()
    listener_stub = mock.MagicMock()
    listener_stub.adapter = emitter

    listener_patch = mock.patch(
        "airflow.providers.openlineage.plugins.listener.get_openlineage_listener",
        return_value=listener_stub,
    )
    with listener_patch:
        with pytest.raises(
            ValueError, match="If 'hook' is not provided, 'query_for_extra_metadata' must be False."
        ):
            emit_openlineage_events_for_databricks_queries(
                query_ids=query_ids,
                query_source_namespace="databricks_ns",
                task_instance=None,
                query_for_extra_metadata=True,
            )

    # The caller-supplied list must stay intact and nothing may be emitted.
    assert query_ids == snapshot
    emitter.emit.assert_not_called()
-# emit_openlineage_events_for_databricks_queries requires OL provider 2.3.0
@mock.patch("importlib.metadata.version", return_value="1.99.0")
def test_emit_openlineage_events_with_old_openlineage_provider(mock_version):
query_ids = ["q1", "q2"]