potiuk commented on code in PR #40153:
URL: https://github.com/apache/airflow/pull/40153#discussion_r1663852033
########## airflow/providers/databricks/plugins/databricks_workflow.py: ##########
@@ -0,0 +1,453 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import logging
+import os
+from operator import itemgetter
+from typing import TYPE_CHECKING, Any, cast
+
+from flask import current_app, flash, redirect, request
+from flask_appbuilder.api import expose
+
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException, TaskInstanceNotFound
+from airflow.models import BaseOperator, BaseOperatorLink
+from airflow.models.dag import DAG, clear_task_instances
+from airflow.models.dagrun import DagRun
+from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
+from airflow.models.xcom import XCom
+from airflow.plugins_manager import AirflowPlugin
+from airflow.providers.databricks.hooks.databricks import DatabricksHook
+from airflow.utils.airflow_flask_app import AirflowApp
+from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.state import TaskInstanceState
+from airflow.utils.task_group import TaskGroup
+from airflow.www.views import AirflowBaseView
+
+if TYPE_CHECKING:
+    from sqlalchemy.orm.session import Session
+
+
+REPAIR_WAIT_ATTEMPTS = os.getenv("DATABRICKS_REPAIR_WAIT_ATTEMPTS", 20)
+REPAIR_WAIT_DELAY = os.getenv("DATABRICKS_REPAIR_WAIT_DELAY", 0.5)
+
+airflow_app = cast(AirflowApp, current_app)
+
+
+def _get_databricks_task_id(task: BaseOperator) -> str:
+    """
+    Get the databricks task ID using dag_id and task_id. removes illegal characters.
+
+    :param task: The task to get the databricks task ID for.
+    :return: The databricks task ID.
+    """
+    return f"{task.dag_id}__{task.task_id.replace('.', '__')}"
+
+
+def get_databricks_task_ids(
+    group_id: str, task_map: dict[str, BaseOperator], log: logging.Logger
+) -> list[str]:
+    """
+    Return a list of all Databricks task IDs for a dictionary of Airflow tasks.
+
+    :param group_id: The task group ID.
+    :param task_map: A dictionary mapping task IDs to BaseOperator instances.
+    :param log: The logger to use for logging.
+    :return: A list of Databricks task IDs for the given task group.
+    """
+    task_ids = []
+    log.debug("Getting databricks task ids for group %s", group_id)
+    for task_id, task in task_map.items():
+        if task_id == f"{group_id}.launch":
+            continue
+        databricks_task_id = _get_databricks_task_id(task)
+        log.debug("databricks task id for task %s is %s", task_id, databricks_task_id)
+        task_ids.append(databricks_task_id)
+    return task_ids
+
+
+@provide_session
+def _get_dagrun(dag: DAG, run_id: str, session: Session | None = None) -> DagRun:
+    """
+    Retrieve the DagRun object associated with the specified DAG and run_id.
+
+    :param dag: The DAG object associated with the DagRun to retrieve.
+    :param run_id: The run_id associated with the DagRun to retrieve.
+    :param session: The SQLAlchemy session to use for the query. If None, uses the default session.
+    :return: The DagRun object associated with the specified DAG and run_id.
+    """
+    if not session:
+        raise AirflowException("Session not provided.")
+
+    return session.query(DagRun).filter(DagRun.dag_id == dag.dag_id, DagRun.run_id == run_id).first()
+
+
+@provide_session
+def _clear_task_instances(
+    dag_id: str, run_id: str, task_ids: list[str], log: logging.Logger, session: Session | None = None
+) -> None:
+    dag = airflow_app.dag_bag.get_dag(dag_id)
+    log.debug("task_ids %s to clear", str(task_ids))
+    dr: DagRun = _get_dagrun(dag, run_id, session=session)
+    tis_to_clear = [ti for ti in dr.get_task_instances() if _get_databricks_task_id(ti) in task_ids]
+    clear_task_instances(tis_to_clear, session)
+
+
+def _repair_task(
+    databricks_conn_id: str,
+    databricks_run_id: int,
+    tasks_to_repair: list[str],
+    logger: logging.Logger,
+) -> int:
+    """
+    Repair a Databricks task using the Databricks API.
+
+    This function allows the Airflow retry function to create a repair job for Databricks.
+    It uses the Databricks API to get the latest repair ID before sending the repair query.
+
+    :param databricks_conn_id: The Databricks connection ID.
+    :param databricks_run_id: The Databricks run ID.
+    :param tasks_to_repair: A list of Databricks task IDs to repair.
+    :param logger: The logger to use for logging.
+    :return: None
+    """
+    hook = DatabricksHook(databricks_conn_id=databricks_conn_id)
+
+    repair_history_id = hook.get_latest_repair_id(databricks_run_id)
+    logger.debug("Latest repair ID is %s", repair_history_id)
+    logger.debug(
+        "Sending repair query for tasks %s on run %s",
+        tasks_to_repair,
+        databricks_run_id,
+    )
+
+    repair_json = {
+        "run_id": databricks_run_id,
+        "latest_repair_id": repair_history_id,
+        "rerun_tasks": tasks_to_repair,
+    }
+
+    return hook.repair_run(repair_json)
+
+
+def get_launch_task_id(task_group: TaskGroup) -> str:
+    """
+    Retrieve the launch task ID from the current task group or a parent task group, recursively.
+
+    :param task_group: Task Group to be inspected
+    :return: launch Task ID
+    """
+    try:
+        launch_task_id = task_group.get_child_by_label("launch").task_id  # type: ignore[attr-defined]
+    except KeyError as e:
+        if not task_group.parent_group:
+            raise AirflowException("No launch task can be found in the task group.") from e
+        launch_task_id = get_launch_task_id(task_group.parent_group)
+
+    return launch_task_id
+
+
+def _get_launch_task_key(current_task_key: TaskInstanceKey, task_id: str) -> TaskInstanceKey:
+    """
+    Return the task key for the launch task.
+
+    This allows us to gather databricks Metadata even if the current task has failed (since tasks only
+    create xcom values if they succeed).
+
+    :param current_task_key: The task key for the current task.
+    :param task_id: The task ID for the current task.
+    :return: The task key for the launch task.
+ """ + if task_id: + return TaskInstanceKey( + dag_id=current_task_key.dag_id, + task_id=task_id, + run_id=current_task_key.run_id, + try_number=current_task_key.try_number, + ) + + return current_task_key + + +@provide_session +def get_task_instance(operator: BaseOperator, dttm, session: Session = NEW_SESSION) -> TaskInstance: + dag_id = operator.dag.dag_id + dag_run = DagRun.find(dag_id, execution_date=dttm)[0] + ti = ( + session.query(TaskInstance) + .filter( + TaskInstance.dag_id == dag_id, + TaskInstance.run_id == dag_run.run_id, + TaskInstance.task_id == operator.task_id, + ) + .one_or_none() + ) + if not ti: + raise TaskInstanceNotFound("Task instance not found") + return ti + + +def get_xcom_result( + ti_key: TaskInstanceKey, + key: str, + ti: TaskInstance | None, +) -> Any: + result = XCom.get_value( + ti_key=ti_key, + key=key, + ) + from airflow.providers.databricks.operators.databricks_workflow import WorkflowRunMetadata + + return WorkflowRunMetadata(**result) + + +class WorkflowJobRunLink(BaseOperatorLink, LoggingMixin): + """Constructs a link to monitor a Databricks Job Run.""" Review Comment: Thanks a lot for the video @pankajkoti . It's been super helpful. Yes. This is working exactly as I'd imagined it, so it confirms my assesment - that this is precisely something we should build into airflow core. However, I do see that as an upcoming Airlfow 3.1+ feature, not something that we will be able to im[plement now, nor even as part of 3.0 effort. Adding all the interaction possibilities - launching the job behind the scenes, observing and monitoring individual tasks run in the external workflow, ability to clear and "re-run" (i.e. repair) individual tasks are all the features I'd imagined "workflow-in-workfow" case should have as a core feature and well designed API. But - we have other things as priority to implement for Airllow 3, and also such feature will only be available in Airflow 3.x (when it is implement and if we agree that such functionality is something we should implement in Airlfow). Implementing it as plugin in the way you described is more of a band-aid and it's not easily replicable to other workflows without similar complexity. But it has an undeniable advantage that we can make it works now in Airflow 2, put it in the hands of our users today and learn from it when we attempt to implement the Airflow 3 "generic" solution. Which we might not be doing for a year or so, so waiting for it does not allow us to learn from the ways our users would like to use it. So - overall - I'd be for doing it this way with few caveats: * we commit to make it Airflow 3 -plugin compliant (i.e. future version of the databricks workflow provider will have the airflow 3 - compatible plugin variant). I uderstand this is "given" * we plan to have a workstream (Airflow 3.1+) where we turn such approach (also taking patterns from Cosmos integration) where we turn it into a generic solution We do not have to do the "generic" attempts now - just plan it as a natural evolution of this approach and whenever we get to Airflow 3.1+ to implement it (or generally where we fill we would like to do simlar approach for another workflow) - we should THEN sit down and decompose it into what would it mean for Core Airflow to support this kind of "external workflow" integration in Airflow. That would by my ask at this moment. 
I think the way it is implemented now - where we use plugins and make such a workflow available for Airflow 2 users - is also a nice "teaser": it would give us feedback from users on how useful and needed it is, so that we could have more clarity on what exactly is needed, what works, and what kind of problems we need to solve when we attempt to make it an Airflow 3 "feature".
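To make the "teaser" concrete for readers who have not watched the video, below is a rough sketch of the kind of DAG this plugin is meant to enhance with run-monitoring and repair links. It assumes the companion DatabricksWorkflowTaskGroup and DatabricksNotebookOperator from the same provider; the connection ID, notebook paths, cluster spec, and job_cluster_key value are placeholders, not taken from this PR.

```python
# Illustrative only - assumes the DatabricksWorkflowTaskGroup / DatabricksNotebookOperator
# companions from the databricks provider; all IDs, paths, and cluster settings are placeholders.
from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.providers.databricks.operators.databricks import DatabricksNotebookOperator
from airflow.providers.databricks.operators.databricks_workflow import DatabricksWorkflowTaskGroup

# Placeholder job cluster definition in Databricks Jobs API format.
job_cluster_spec = [
    {
        "job_cluster_key": "workflow_cluster",
        "new_cluster": {
            "spark_version": "13.3.x-scala2.12",
            "node_type_id": "i3.xlarge",
            "num_workers": 1,
        },
    }
]

with DAG(
    dag_id="example_databricks_workflow",
    start_date=datetime(2024, 1, 1),
    schedule=None,
) as dag:
    with DatabricksWorkflowTaskGroup(
        group_id="databricks_workflow",
        databricks_conn_id="databricks_default",
        job_clusters=job_cluster_spec,
    ) as workflow:
        # Each notebook task becomes one task in the single Databricks job run launched
        # by the task group; the plugin adds links to monitor and repair that run.
        first = DatabricksNotebookOperator(
            task_id="first_notebook",
            databricks_conn_id="databricks_default",
            notebook_path="/Shared/first_notebook",  # placeholder path
            source="WORKSPACE",
            job_cluster_key="workflow_cluster",
        )
        second = DatabricksNotebookOperator(
            task_id="second_notebook",
            databricks_conn_id="databricks_default",
            notebook_path="/Shared/second_notebook",  # placeholder path
            source="WORKSPACE",
            job_cluster_key="workflow_cluster",
        )
        first >> second
```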
