kacpermuda commented on code in PR #50392: URL: https://github.com/apache/airflow/pull/50392#discussion_r2084367730
########## providers/databricks/src/airflow/providers/databricks/utils/openlineage.py: ########## @@ -0,0 +1,328 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import datetime +import json +import logging +from typing import TYPE_CHECKING, Any + +import requests + +from airflow.providers.common.compat.openlineage.check import require_openlineage_version +from airflow.providers.databricks.version_compat import AIRFLOW_V_3_0_PLUS +from airflow.utils import timezone + +if TYPE_CHECKING: + from openlineage.client.event_v2 import RunEvent + from openlineage.client.facet_v2 import JobFacet + + from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook + + +log = logging.getLogger(__name__) + + +def _get_logical_date(task_instance): + # todo: remove when min airflow version >= 3.0 + if AIRFLOW_V_3_0_PLUS: + dagrun = task_instance.get_template_context()["dag_run"] + return dagrun.logical_date or dagrun.run_after + + if hasattr(task_instance, "logical_date"): + date = task_instance.logical_date + else: + date = task_instance.execution_date + + return date + + +def _get_dag_run_clear_number(task_instance): + # todo: remove when min airflow version >= 3.0 + if AIRFLOW_V_3_0_PLUS: + dagrun = task_instance.get_template_context()["dag_run"] + return dagrun.clear_number + return task_instance.dag_run.clear_number + + +# todo: move this run_id logic into OpenLineage's listener to avoid differences +def _get_ol_run_id(task_instance) -> str: + """ + Get OpenLineage run_id from TaskInstance. + + It's crucial that the task_instance's run_id creation logic matches OpenLineage's listener implementation. + Only then can we ensure that the generated run_id aligns with the Airflow task, + enabling a proper connection between events. + """ + from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter + + # Generate same OL run id as is generated for current task instance + return OpenLineageAdapter.build_task_instance_run_id( + dag_id=task_instance.dag_id, + task_id=task_instance.task_id, + logical_date=_get_logical_date(task_instance), + try_number=task_instance.try_number, + map_index=task_instance.map_index, + ) + + +# todo: move this run_id logic into OpenLineage's listener to avoid differences +def _get_ol_dag_run_id(task_instance) -> str: + from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter + + return OpenLineageAdapter.build_dag_run_id( + dag_id=task_instance.dag_id, + logical_date=_get_logical_date(task_instance), + clear_number=_get_dag_run_clear_number(task_instance), + ) + + +def _get_parent_run_facet(task_instance): + """ + Retrieve the ParentRunFacet associated with a specific Airflow task instance. + + This facet helps link OpenLineage events of child jobs - such as queries executed within + external systems (e.g., Databricks) by the Airflow task - to the original Airflow task execution. + Establishing this connection enables better lineage tracking and observability. + """ + from openlineage.client.facet_v2 import parent_run + + from airflow.providers.openlineage.conf import namespace + + parent_run_id = _get_ol_run_id(task_instance) + root_parent_run_id = _get_ol_dag_run_id(task_instance) + + return parent_run.ParentRunFacet( + run=parent_run.Run(runId=parent_run_id), + job=parent_run.Job( + namespace=namespace(), + name=f"{task_instance.dag_id}.{task_instance.task_id}", + ), + root=parent_run.Root( + run=parent_run.RootRun(runId=root_parent_run_id), + job=parent_run.RootJob( + name=task_instance.dag_id, + namespace=namespace(), + ), + ), + ) + + +def _run_api_call(hook: DatabricksSqlHook, query_ids: list[str]) -> list[dict]: + """Retrieve execution details for specific queries from Databricks's query history API.""" + if not hook._token: + # This has logic for token initialization + hook.get_conn() + + # https://docs.databricks.com/api/azure/workspace/queryhistory/list + response = requests.get( + url=f"https://{hook.host}/api/2.0/sql/history/queries", + headers={"Authorization": f"Bearer {hook._token}"}, + data=json.dumps({"filter_by": {"statement_ids": query_ids}}), + ) + if response.status_code != 200: + log.warning( + "OpenLineage could not retrieve Databricks queries details. API error received: `%s`: `%s`", + response.status_code, + response.text, + ) + return [] + + return response.json()["res"] + + +def _get_queries_details_from_databricks( + hook: DatabricksSqlHook, query_ids: list[str] +) -> dict[str, dict[str, Any]]: + if not query_ids: + return {} + + queries_info_from_api = _run_api_call(hook=hook, query_ids=query_ids) + + query_details = {} + for query_info in queries_info_from_api: + if not query_info.get("query_id"): + log.debug("Databricks query ID not found in API response.") + continue + + q_start_time = None + q_end_time = None + if query_info.get("query_start_time_ms") and query_info.get("query_end_time_ms"): + q_start_time = datetime.datetime.fromtimestamp( + query_info["query_start_time_ms"] / 1000, tz=datetime.timezone.utc + ) + q_end_time = datetime.datetime.fromtimestamp( + query_info["query_end_time_ms"] / 1000, tz=datetime.timezone.utc + ) + + query_details[query_info["query_id"]] = { + "status": query_info.get("status"), + "start_time": q_start_time, + "end_time": q_end_time, + "query_text": query_info.get("query_text"), + "error_message": query_info.get("error_message"), + } + + return query_details + + +def _create_ol_event_pair( + job_namespace: str, + job_name: str, + start_time: datetime.datetime, + end_time: datetime.datetime, + is_successful: bool, + run_facets: dict | None = None, + job_facets: dict | None = None, +) -> tuple[RunEvent, RunEvent]: + """Create a pair of OpenLineage RunEvents representing the start and end of a query execution.""" + from openlineage.client.event_v2 import Job, Run, RunEvent, RunState + from openlineage.client.uuid import generate_new_uuid + + run = Run(runId=str(generate_new_uuid()), facets=run_facets or {}) + job = Job(namespace=job_namespace, name=job_name, facets=job_facets or {}) + + start = RunEvent( + eventType=RunState.START, + eventTime=start_time.isoformat(), + run=run, + job=job, + ) + end = RunEvent( + eventType=RunState.COMPLETE if is_successful else RunState.FAIL, + eventTime=end_time.isoformat(), + run=run, + job=job, + ) + return start, end + + +@require_openlineage_version(provider_min_version="2.3.0") +def emit_openlineage_events_for_databricks_queries( + query_ids: list[str], + query_source_namespace: str, + task_instance, + hook: DatabricksSqlHook | None = None, + additional_run_facets: dict | None = None, + additional_job_facets: dict | None = None, +) -> None: + """ + Emit OpenLineage events for executed Databricks queries. + + Metadata retrieval from Databricks is attempted only if a `DatabricksSqlHook` is provided. + If metadata is available, execution details such as start time, end time, execution status, + error messages, and SQL text are included in the events. If no metadata is found, the function + defaults to using the Airflow task instance's state and the current timestamp. + + Note that both START and COMPLETE event for each query will be emitted at the same time. + If we are able to query Databricks for query execution metadata, event times + will correspond to actual query execution times. + + Args: + query_ids: A list of Databricks query IDs to emit events for. + query_source_namespace: The namespace to be included in ExternalQueryRunFacet. + task_instance: The Airflow task instance that run these queries. + hook: A hook instance used to retrieve query metadata if available. + additional_run_facets: Additional run facets to include in OpenLineage events. + additional_job_facets: Additional job facets to include in OpenLineage events. + """ + from openlineage.client.facet_v2 import job_type_job + + from airflow.providers.common.compat.openlineage.facet import ( + ErrorMessageRunFacet, + ExternalQueryRunFacet, + RunFacet, + SQLJobFacet, + ) + from airflow.providers.openlineage.conf import namespace + from airflow.providers.openlineage.plugins.listener import get_openlineage_listener + + if not query_ids: + log.debug("No Databricks query IDs provided; skipping OpenLineage event emission.") + return + + query_ids = [q for q in query_ids] # Make a copy to make sure it does not change + + if hook: + log.debug("Retrieving metadata for %s queries from Databricks.", len(query_ids)) + databricks_metadata = _get_queries_details_from_databricks(hook, query_ids) + else: + log.debug("DatabricksSqlHook not provided. No extra metadata fill be fetched from Databricks.") + databricks_metadata = {} + + # If real metadata is unavailable, we send events with eventTime=now + default_event_time = timezone.utcnow() + # If no query metadata is provided, we use task_instance's state when checking for success + # Adjust state for DBX logic, where "finished" means "success" + default_state = task_instance.state.value if hasattr(task_instance, "state") else "" + default_state = "finished" if default_state == "success" else default_state + + log.debug("Generating OpenLineage facets") + common_run_facets = {"parent": _get_parent_run_facet(task_instance)} + common_job_facets: dict[str, JobFacet] = { + "jobType": job_type_job.JobTypeJobFacet( + jobType="QUERY", + integration="DATABRICKS", + processingType="BATCH", + ) + } + additional_run_facets = additional_run_facets or {} + additional_job_facets = additional_job_facets or {} + + events: list[RunEvent] = [] + for counter, query_id in enumerate(query_ids, 1): + query_metadata = databricks_metadata.get(query_id, {}) + log.debug( + "Metadata for query no. %s, (ID `%s`): `%s`", + counter, + query_id, + query_metadata if query_metadata else "not found", + ) + + query_specific_run_facets: dict[str, RunFacet] = { + "externalQuery": ExternalQueryRunFacet(externalQueryId=query_id, source=query_source_namespace) + } + if query_metadata.get("error_message"): + query_specific_run_facets["error"] = ErrorMessageRunFacet( + message=query_metadata["error_message"], + programmingLanguage="SQL", + ) + + query_specific_job_facets = {} + if query_metadata.get("query_text"): + query_specific_job_facets["sql"] = SQLJobFacet(query=query_metadata["query_text"]) + + log.debug("Creating OpenLineage event pair for query ID: %s", query_id) + event_batch = _create_ol_event_pair( + job_namespace=namespace(), + job_name=f"{task_instance.dag_id}.{task_instance.task_id}.query.{counter}", + start_time=query_metadata.get("start_time", default_event_time), # type: ignore[arg-type] + end_time=query_metadata.get("end_time", default_event_time), # type: ignore[arg-type] + # Only finished status means it completed without failures + is_successful=query_metadata.get("status", default_state).lower() == "finished", + run_facets={**query_specific_run_facets, **common_run_facets, **additional_run_facets}, + job_facets={**query_specific_job_facets, **common_job_facets, **additional_job_facets}, + ) + events.extend(event_batch) + + log.debug("Generated %s OpenLineage events; emitting now.", len(events)) + adapter = get_openlineage_listener().adapter + for event in events: + adapter.emit(event) Review Comment: Are you asking about the time it takes to emit these events? They are lightweight post requests, so not too long. Users can also limit the time OL has for execution with Airflow configs, or make it longer if they appear to be running a lot of queries at once. Just a reminder, this `overhead` is still applicable only for those users that enable OpenLineage. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
