mobuchowski commented on code in PR #50392:
URL: https://github.com/apache/airflow/pull/50392#discussion_r2084107904
##########
providers/databricks/pyproject.toml:
##########
@@ -83,14 +85,19 @@ dependencies = [
"standard" = [
"apache-airflow-providers-standard"
]
+"openlineage" = [
+ "apache-airflow-providers-openlineage"
Review Comment:
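This probably needs a minimum version bound here (e.g. `apache-airflow-providers-openlineage>=2.3.0`, matching the `require_openlineage_version(provider_min_version="2.3.0")` check used in `utils/openlineage.py`).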
##########
providers/databricks/pyproject.toml:
##########
@@ -83,14 +85,19 @@ dependencies = [
"standard" = [
"apache-airflow-providers-standard"
]
+"openlineage" = [
+ "apache-airflow-providers-openlineage"
+]
[dependency-groups]
dev = [
"apache-airflow",
"apache-airflow-task-sdk",
"apache-airflow-devel-common",
+ "apache-airflow-providers-common-compat",
"apache-airflow-providers-common-sql",
"apache-airflow-providers-fab",
+ "apache-airflow-providers-openlineage",
Review Comment:
This also needs a version bound.
##########
providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py:
##########
@@ -106,7 +107,8 @@ def __init__(
self.catalog = catalog
self.schema = schema
- def get_db_hook(self) -> DatabricksSqlHook:
+ @cached_property
Review Comment:
👍
##########
providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py:
##########
@@ -309,3 +318,83 @@ def bulk_dump(self, table, tmp_file):
def bulk_load(self, table, tmp_file):
raise NotImplementedError()
+
+ def get_openlineage_database_info(self, connection) -> DatabaseInfo:
+ from airflow.providers.openlineage.sqlparser import DatabaseInfo
+
+ return DatabaseInfo(
+ scheme=self.get_openlineage_database_dialect(connection),
+ authority=self._get_openlineage_authority(connection),
+ database=self.catalog,
+ information_schema_columns=[
+ "table_schema",
+ "table_name",
+ "column_name",
+ "ordinal_position",
+ "data_type",
+ "table_catalog",
+ ],
+ is_information_schema_cross_db=True,
+ )
+
+ def get_openlineage_database_dialect(self, _) -> str:
+ return "databricks"
+
+ def get_openlineage_default_schema(self) -> str | None:
+ return self.schema or "default"
+
+ def _get_openlineage_authority(self, _) -> str | None:
+ return self.host
+
+ def get_openlineage_database_specific_lineage(self, task_instance) ->
OperatorLineage | None:
+ """
+ Generate OpenLineage metadata for a Databricks task instance based on
executed query IDs.
+
+ If a single query ID is present, attach an `ExternalQueryRunFacet` to
the lineage metadata.
+ If multiple query IDs are present, emits separate OpenLineage events
for each query instead.
+
+ Note that `get_openlineage_database_specific_lineage` is usually
called after task's execution,
+ so if multiple query IDs are present, both START and COMPLETE event
for each query will be emitted
+ after task's execution. If we are able to query Databricks for query
execution metadata,
+ query event times will correspond to actual query's start and finish
times.
+
+ Args:
+ task_instance: The Airflow TaskInstance object for which lineage
is being collected.
+
+ Returns:
+ An `OperatorLineage` object if a single query ID is found;
otherwise `None`.
+ """
+ from airflow.providers.common.compat.openlineage.facet import
ExternalQueryRunFacet
+ from airflow.providers.databricks.utils.openlineage import (
+ emit_openlineage_events_for_databricks_queries,
+ )
+ from airflow.providers.openlineage.extractors import OperatorLineage
+ from airflow.providers.openlineage.sqlparser import SQLParser
+
+ if not self.query_ids:
+ self.log.debug("openlineage: no databricks query ids found.")
+ return None
+
+ self.log.debug("openlineage: getting connection to get database info")
+ connection = self.get_connection(self.get_conn_id())
+ namespace =
SQLParser.create_namespace(self.get_openlineage_database_info(connection))
+
+ if len(self.query_ids) == 1:
+ self.log.debug("Attaching ExternalQueryRunFacet with single
query_id to OpenLineage event.")
+ return OperatorLineage(
+ run_facets={
+ "externalQuery": ExternalQueryRunFacet(
+ externalQueryId=self.query_ids[0], source=namespace
+ )
+ }
+ )
+
+ self.log.info("Multiple query_ids found. Separate OpenLineage event
will be emitted for each query.")
+ emit_openlineage_events_for_databricks_queries(
+ query_ids=self.query_ids,
+ query_source_namespace=namespace,
+ task_instance=task_instance,
+ hook=self,
+ )
Review Comment:
TBH, I now think the logic would be much better if it were consistent: whether we have a single SQL-level event or multiple,
emit them separately. The current approach requires you to handle both paths anyway, so moving to
that model should not be a breaking change? Maybe not for this PR, though.
##########
providers/databricks/src/airflow/providers/databricks/utils/openlineage.py:
##########
@@ -0,0 +1,328 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import datetime
+import json
+import logging
+from typing import TYPE_CHECKING, Any
+
+import requests
+
+from airflow.providers.common.compat.openlineage.check import
require_openlineage_version
+from airflow.providers.databricks.version_compat import AIRFLOW_V_3_0_PLUS
+from airflow.utils import timezone
+
+if TYPE_CHECKING:
+ from openlineage.client.event_v2 import RunEvent
+ from openlineage.client.facet_v2 import JobFacet
+
+ from airflow.providers.databricks.hooks.databricks_sql import
DatabricksSqlHook
+
+
+log = logging.getLogger(__name__)
+
+
+def _get_logical_date(task_instance):
+ # todo: remove when min airflow version >= 3.0
+ if AIRFLOW_V_3_0_PLUS:
+ dagrun = task_instance.get_template_context()["dag_run"]
+ return dagrun.logical_date or dagrun.run_after
+
+ if hasattr(task_instance, "logical_date"):
+ date = task_instance.logical_date
+ else:
+ date = task_instance.execution_date
+
+ return date
+
+
+def _get_dag_run_clear_number(task_instance):
+ # todo: remove when min airflow version >= 3.0
+ if AIRFLOW_V_3_0_PLUS:
+ dagrun = task_instance.get_template_context()["dag_run"]
+ return dagrun.clear_number
+ return task_instance.dag_run.clear_number
+
+
+# todo: move this run_id logic into OpenLineage's listener to avoid differences
+def _get_ol_run_id(task_instance) -> str:
+ """
+ Get OpenLineage run_id from TaskInstance.
+
+ It's crucial that the task_instance's run_id creation logic matches
OpenLineage's listener implementation.
+ Only then can we ensure that the generated run_id aligns with the Airflow
task,
+ enabling a proper connection between events.
+ """
+ from airflow.providers.openlineage.plugins.adapter import
OpenLineageAdapter
+
+ # Generate same OL run id as is generated for current task instance
+ return OpenLineageAdapter.build_task_instance_run_id(
+ dag_id=task_instance.dag_id,
+ task_id=task_instance.task_id,
+ logical_date=_get_logical_date(task_instance),
+ try_number=task_instance.try_number,
+ map_index=task_instance.map_index,
+ )
+
+
+# todo: move this run_id logic into OpenLineage's listener to avoid differences
+def _get_ol_dag_run_id(task_instance) -> str:
+ from airflow.providers.openlineage.plugins.adapter import
OpenLineageAdapter
+
+ return OpenLineageAdapter.build_dag_run_id(
+ dag_id=task_instance.dag_id,
+ logical_date=_get_logical_date(task_instance),
+ clear_number=_get_dag_run_clear_number(task_instance),
+ )
+
+
+def _get_parent_run_facet(task_instance):
+ """
+ Retrieve the ParentRunFacet associated with a specific Airflow task
instance.
+
+ This facet helps link OpenLineage events of child jobs - such as queries
executed within
+ external systems (e.g., Databricks) by the Airflow task - to the original
Airflow task execution.
+ Establishing this connection enables better lineage tracking and
observability.
+ """
+ from openlineage.client.facet_v2 import parent_run
+
+ from airflow.providers.openlineage.conf import namespace
+
+ parent_run_id = _get_ol_run_id(task_instance)
+ root_parent_run_id = _get_ol_dag_run_id(task_instance)
+
+ return parent_run.ParentRunFacet(
+ run=parent_run.Run(runId=parent_run_id),
+ job=parent_run.Job(
+ namespace=namespace(),
+ name=f"{task_instance.dag_id}.{task_instance.task_id}",
+ ),
+ root=parent_run.Root(
+ run=parent_run.RootRun(runId=root_parent_run_id),
+ job=parent_run.RootJob(
+ name=task_instance.dag_id,
+ namespace=namespace(),
+ ),
+ ),
+ )
+
+
+def _run_api_call(hook: DatabricksSqlHook, query_ids: list[str]) -> list[dict]:
+ """Retrieve execution details for specific queries from Databricks's query
history API."""
+ if not hook._token:
+ # This has logic for token initialization
+ hook.get_conn()
+
+ # https://docs.databricks.com/api/azure/workspace/queryhistory/list
+ response = requests.get(
+ url=f"https://{hook.host}/api/2.0/sql/history/queries",
+ headers={"Authorization": f"Bearer {hook._token}"},
+ data=json.dumps({"filter_by": {"statement_ids": query_ids}}),
+ )
+ if response.status_code != 200:
+ log.warning(
+ "OpenLineage could not retrieve Databricks queries details. API
error received: `%s`: `%s`",
+ response.status_code,
+ response.text,
+ )
+ return []
+
+ return response.json()["res"]
+
+
+def _get_queries_details_from_databricks(
+ hook: DatabricksSqlHook, query_ids: list[str]
+) -> dict[str, dict[str, Any]]:
+ if not query_ids:
+ return {}
+
+ queries_info_from_api = _run_api_call(hook=hook, query_ids=query_ids)
+
+ query_details = {}
+ for query_info in queries_info_from_api:
+ if not query_info.get("query_id"):
+ log.debug("Databricks query ID not found in API response.")
+ continue
+
+ q_start_time = None
+ q_end_time = None
+ if query_info.get("query_start_time_ms") and
query_info.get("query_end_time_ms"):
+ q_start_time = datetime.datetime.fromtimestamp(
+ query_info["query_start_time_ms"] / 1000,
tz=datetime.timezone.utc
+ )
+ q_end_time = datetime.datetime.fromtimestamp(
+ query_info["query_end_time_ms"] / 1000,
tz=datetime.timezone.utc
+ )
+
+ query_details[query_info["query_id"]] = {
+ "status": query_info.get("status"),
+ "start_time": q_start_time,
+ "end_time": q_end_time,
+ "query_text": query_info.get("query_text"),
+ "error_message": query_info.get("error_message"),
+ }
+
+ return query_details
+
+
+def _create_ol_event_pair(
+ job_namespace: str,
+ job_name: str,
+ start_time: datetime.datetime,
+ end_time: datetime.datetime,
+ is_successful: bool,
+ run_facets: dict | None = None,
+ job_facets: dict | None = None,
+) -> tuple[RunEvent, RunEvent]:
+ """Create a pair of OpenLineage RunEvents representing the start and end
of a query execution."""
+ from openlineage.client.event_v2 import Job, Run, RunEvent, RunState
+ from openlineage.client.uuid import generate_new_uuid
+
+ run = Run(runId=str(generate_new_uuid()), facets=run_facets or {})
+ job = Job(namespace=job_namespace, name=job_name, facets=job_facets or {})
+
+ start = RunEvent(
+ eventType=RunState.START,
+ eventTime=start_time.isoformat(),
+ run=run,
+ job=job,
+ )
+ end = RunEvent(
+ eventType=RunState.COMPLETE if is_successful else RunState.FAIL,
+ eventTime=end_time.isoformat(),
+ run=run,
+ job=job,
+ )
+ return start, end
+
+
+@require_openlineage_version(provider_min_version="2.3.0")
+def emit_openlineage_events_for_databricks_queries(
+ query_ids: list[str],
+ query_source_namespace: str,
+ task_instance,
+ hook: DatabricksSqlHook | None = None,
+ additional_run_facets: dict | None = None,
+ additional_job_facets: dict | None = None,
+) -> None:
+ """
+ Emit OpenLineage events for executed Databricks queries.
+
+ Metadata retrieval from Databricks is attempted only if a
`DatabricksSqlHook` is provided.
+ If metadata is available, execution details such as start time, end time,
execution status,
+ error messages, and SQL text are included in the events. If no metadata is
found, the function
+ defaults to using the Airflow task instance's state and the current
timestamp.
+
+ Note that both START and COMPLETE event for each query will be emitted at
the same time.
+ If we are able to query Databricks for query execution metadata, event
times
+ will correspond to actual query execution times.
+
+ Args:
+ query_ids: A list of Databricks query IDs to emit events for.
+ query_source_namespace: The namespace to be included in
ExternalQueryRunFacet.
+ task_instance: The Airflow task instance that run these queries.
+ hook: A hook instance used to retrieve query metadata if available.
+ additional_run_facets: Additional run facets to include in OpenLineage
events.
+ additional_job_facets: Additional job facets to include in OpenLineage
events.
+ """
+ from openlineage.client.facet_v2 import job_type_job
+
+ from airflow.providers.common.compat.openlineage.facet import (
+ ErrorMessageRunFacet,
+ ExternalQueryRunFacet,
+ RunFacet,
+ SQLJobFacet,
+ )
+ from airflow.providers.openlineage.conf import namespace
+ from airflow.providers.openlineage.plugins.listener import
get_openlineage_listener
+
+ if not query_ids:
+ log.debug("No Databricks query IDs provided; skipping OpenLineage
event emission.")
+ return
+
+ query_ids = [q for q in query_ids] # Make a copy to make sure it does not
change
+
+ if hook:
+ log.debug("Retrieving metadata for %s queries from Databricks.",
len(query_ids))
+ databricks_metadata = _get_queries_details_from_databricks(hook,
query_ids)
+ else:
+ log.debug("DatabricksSqlHook not provided. No extra metadata fill be
fetched from Databricks.")
+ databricks_metadata = {}
+
+ # If real metadata is unavailable, we send events with eventTime=now
+ default_event_time = timezone.utcnow()
+ # If no query metadata is provided, we use task_instance's state when
checking for success
+ # Adjust state for DBX logic, where "finished" means "success"
+ default_state = task_instance.state.value if hasattr(task_instance,
"state") else ""
+ default_state = "finished" if default_state == "success" else default_state
+
+ log.debug("Generating OpenLineage facets")
+ common_run_facets = {"parent": _get_parent_run_facet(task_instance)}
+ common_job_facets: dict[str, JobFacet] = {
+ "jobType": job_type_job.JobTypeJobFacet(
+ jobType="QUERY",
+ integration="DATABRICKS",
Review Comment:
I still don't like that we pretend to be Databricks (or Snowflake, etc.)
here, but it's okay for now, for consistency.
##########
providers/databricks/pyproject.toml:
##########
@@ -83,14 +85,19 @@ dependencies = [
"standard" = [
"apache-airflow-providers-standard"
]
+"openlineage" = [
+ "apache-airflow-providers-openlineage"
+]
[dependency-groups]
dev = [
"apache-airflow",
"apache-airflow-task-sdk",
"apache-airflow-devel-common",
+ "apache-airflow-providers-common-compat",
Review Comment:
This also needs a version bound.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]