This is an automated email from the ASF dual-hosted git repository.

pankajkoti pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1ee69712730 Accept task_key as an argument in `DatabricksNotebookOperator` (#44960)
1ee69712730 is described below

commit 1ee69712730b4919c7796a32babc75d34380274e
Author: Idris Adebisi <[email protected]>
AuthorDate: Thu Dec 26 07:46:57 2024 +0000

    Accept task_key as an argument in `DatabricksNotebookOperator` (#44960)
    
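    This PR lets users explicitly specify `databricks_task_key` as a
    parameter of `DatabricksTaskBaseOperator`, so it is available to
    `DatabricksNotebookOperator` and `DatabricksTaskOperator`. If
    `databricks_task_key` is not provided, or is longer than the 100-character
    limit of the Databricks API, a default value is generated as the MD5 hex
    digest of `<dag_id>__<task_id>`.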
    
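    Key Changes:
    
    - Users can now define `databricks_task_key` explicitly.
    - When not provided, the key defaults to a deterministic hash based on
      `dag_id` and `task_id`.
    - The generated default key changes from the previous
      `<dag_id>__<task_id>` format to an MD5 hex digest, so the Databricks task
      keys of existing tasks that do not set `databricks_task_key` will change.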
    Fixes: #41816
    Fixes: #44250
    related: #43106
---
 .../providers/databricks/operators/databricks.py   | 43 +++++++++++++---------
 .../databricks/plugins/databricks_workflow.py      | 22 ++++-------
 .../tests/databricks/operators/test_databricks.py  | 28 +++++++++++++-
 .../databricks/plugins/test_databricks_workflow.py | 22 ++---------
 4 files changed, 63 insertions(+), 52 deletions(-)

diff --git a/providers/src/airflow/providers/databricks/operators/databricks.py 
b/providers/src/airflow/providers/databricks/operators/databricks.py
index 1b8d45fa479..b8fde94c594 100644
--- a/providers/src/airflow/providers/databricks/operators/databricks.py
+++ b/providers/src/airflow/providers/databricks/operators/databricks.py
@@ -19,6 +19,7 @@
 
 from __future__ import annotations
 
+import hashlib
 import time
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
@@ -966,6 +967,8 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
 
     :param caller: The name of the caller operator to be used in the logs.
     :param databricks_conn_id: The name of the Airflow connection to use.
+    :param databricks_task_key: An optional task_key used to refer to the task 
by Databricks API. By
+        default this will be set to the hash of ``dag_id + task_id``.
     :param databricks_retry_args: An optional dictionary with arguments passed 
to ``tenacity.Retrying`` class.
     :param databricks_retry_delay: Number of seconds to wait between retries.
     :param databricks_retry_limit: Amount of times to retry if the Databricks 
backend is unreachable.
@@ -986,6 +989,7 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
         self,
         caller: str = "DatabricksTaskBaseOperator",
         databricks_conn_id: str = "databricks_default",
+        databricks_task_key: str = "",
         databricks_retry_args: dict[Any, Any] | None = None,
         databricks_retry_delay: int = 1,
         databricks_retry_limit: int = 3,
@@ -1000,6 +1004,7 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
     ):
         self.caller = caller
         self.databricks_conn_id = databricks_conn_id
+        self._databricks_task_key = databricks_task_key
         self.databricks_retry_args = databricks_retry_args
         self.databricks_retry_delay = databricks_retry_delay
         self.databricks_retry_limit = databricks_retry_limit
@@ -1037,17 +1042,21 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
             caller=caller,
         )
 
-    def _get_databricks_task_id(self, task_id: str) -> str:
-        """Get the databricks task ID using dag_id and task_id. Removes 
illegal characters."""
-        task_id = f"{self.dag_id}__{task_id.replace('.', '__')}"
-        if len(task_id) > 100:
-            self.log.warning(
-                "The generated task_key '%s' exceeds 100 characters and will 
be truncated by the Databricks API. "
-                "This will cause failure when trying to monitor the task. 
task_key is generated by ",
-                "concatenating dag_id and task_id.",
-                task_id,
+    @cached_property
+    def databricks_task_key(self) -> str:
+        return self._generate_databricks_task_key()
+
+    def _generate_databricks_task_key(self, task_id: str | None = None) -> str:
+        """Create a databricks task key using the hash of dag_id and 
task_id."""
+        if not self._databricks_task_key or len(self._databricks_task_key) > 
100:
+            self.log.info(
+                "databricks_task_key has not be provided or the provided one 
exceeds 100 characters and will be truncated by the Databricks API. This will 
cause failure when trying to monitor the task. A task_key will be generated 
using the hash value of dag_id+task_id"
             )
-        return task_id
+            task_id = task_id or self.task_id
+            task_key = f"{self.dag_id}__{task_id}".encode()
+            self._databricks_task_key = hashlib.md5(task_key).hexdigest()
+            self.log.info("Generated databricks task_key: %s", 
self._databricks_task_key)
+        return self._databricks_task_key
 
     @property
     def _databricks_workflow_task_group(self) -> DatabricksWorkflowTaskGroup | 
None:
@@ -1077,7 +1086,7 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
     def _get_run_json(self) -> dict[str, Any]:
         """Get run json to be used for task submissions."""
         run_json = {
-            "run_name": self._get_databricks_task_id(self.task_id),
+            "run_name": self.databricks_task_key,
             **self._get_task_base_json(),
         }
         if self.new_cluster and self.existing_cluster_id:
@@ -1127,9 +1136,7 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
         # building the {task_key: task} map below.
         sorted_task_runs = sorted(tasks, key=lambda x: x["start_time"])
 
-        return {task["task_key"]: task for task in sorted_task_runs}[
-            self._get_databricks_task_id(self.task_id)
-        ]
+        return {task["task_key"]: task for task in 
sorted_task_runs}[self.databricks_task_key]
 
     def _convert_to_databricks_workflow_task(
         self, relevant_upstreams: list[BaseOperator], context: Context | None 
= None
@@ -1137,9 +1144,9 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
         """Convert the operator to a Databricks workflow task that can be a 
task in a workflow."""
         base_task_json = self._get_task_base_json()
         result = {
-            "task_key": self._get_databricks_task_id(self.task_id),
+            "task_key": self.databricks_task_key,
             "depends_on": [
-                {"task_key": self._get_databricks_task_id(task_id)}
+                {"task_key": self._generate_databricks_task_key(task_id)}
                 for task_id in self.upstream_task_ids
                 if task_id in relevant_upstreams
             ],
@@ -1172,7 +1179,7 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
         run_state = RunState(**run["state"])
         self.log.info(
             "Current state of the the databricks task %s is %s",
-            self._get_databricks_task_id(self.task_id),
+            self.databricks_task_key,
             run_state.life_cycle_state,
         )
         if self.deferrable and not run_state.is_terminal:
@@ -1194,7 +1201,7 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
             run_state = RunState(**run["state"])
             self.log.info(
                 "Current state of the databricks task %s is %s",
-                self._get_databricks_task_id(self.task_id),
+                self.databricks_task_key,
                 run_state.life_cycle_state,
             )
         self._handle_terminal_run_state(run_state)
diff --git 
a/providers/src/airflow/providers/databricks/plugins/databricks_workflow.py 
b/providers/src/airflow/providers/databricks/plugins/databricks_workflow.py
index d0d8a179dd3..81cf09b0f8d 100644
--- a/providers/src/airflow/providers/databricks/plugins/databricks_workflow.py
+++ b/providers/src/airflow/providers/databricks/plugins/databricks_workflow.py
@@ -44,6 +44,8 @@ from airflow.www.views import AirflowBaseView
 if TYPE_CHECKING:
     from sqlalchemy.orm.session import Session
 
+    from airflow.providers.databricks.operators.databricks import 
DatabricksTaskBaseOperator
+
 
 REPAIR_WAIT_ATTEMPTS = os.getenv("DATABRICKS_REPAIR_WAIT_ATTEMPTS", 20)
 REPAIR_WAIT_DELAY = os.getenv("DATABRICKS_REPAIR_WAIT_DELAY", 0.5)
@@ -57,18 +59,8 @@ def get_auth_decorator():
     return auth.has_access_dag("POST", DagAccessEntity.RUN)
 
 
-def _get_databricks_task_id(task: BaseOperator) -> str:
-    """
-    Get the databricks task ID using dag_id and task_id. removes illegal 
characters.
-
-    :param task: The task to get the databricks task ID for.
-    :return: The databricks task ID.
-    """
-    return f"{task.dag_id}__{task.task_id.replace('.', '__')}"
-
-
 def get_databricks_task_ids(
-    group_id: str, task_map: dict[str, BaseOperator], log: logging.Logger
+    group_id: str, task_map: dict[str, DatabricksTaskBaseOperator], log: 
logging.Logger
 ) -> list[str]:
     """
     Return a list of all Databricks task IDs for a dictionary of Airflow tasks.
@@ -83,7 +75,7 @@ def get_databricks_task_ids(
     for task_id, task in task_map.items():
         if task_id == f"{group_id}.launch":
             continue
-        databricks_task_id = _get_databricks_task_id(task)
+        databricks_task_id = task.databricks_task_key
         log.debug("databricks task id for task %s is %s", task_id, 
databricks_task_id)
         task_ids.append(databricks_task_id)
     return task_ids
@@ -112,7 +104,7 @@ def _clear_task_instances(
     dag = airflow_app.dag_bag.get_dag(dag_id)
     log.debug("task_ids %s to clear", str(task_ids))
     dr: DagRun = _get_dagrun(dag, run_id, session=session)
-    tis_to_clear = [ti for ti in dr.get_task_instances() if 
_get_databricks_task_id(ti) in task_ids]
+    tis_to_clear = [ti for ti in dr.get_task_instances() if 
ti.databricks_task_key in task_ids]
     clear_task_instances(tis_to_clear, session)
 
 
@@ -327,7 +319,7 @@ class WorkflowJobRepairAllFailedLink(BaseOperatorLink, 
LoggingMixin):
 
         tasks_to_run = {ti: t for ti, t in task_group_sub_tasks if ti in 
failed_and_skipped_tasks}
 
-        return ",".join(get_databricks_task_ids(task_group.group_id, 
tasks_to_run, log))
+        return ",".join(get_databricks_task_ids(task_group.group_id, 
tasks_to_run, log))  # type: ignore[arg-type]
 
     @staticmethod
     def _get_failed_and_skipped_tasks(dr: DagRun) -> list[str]:
@@ -390,7 +382,7 @@ class WorkflowJobRepairSingleTaskLink(BaseOperatorLink, 
LoggingMixin):
             "databricks_conn_id": metadata.conn_id,
             "databricks_run_id": metadata.run_id,
             "run_id": ti_key.run_id,
-            "tasks_to_repair": _get_databricks_task_id(task),
+            "tasks_to_repair": task.databricks_task_key,
         }
         return url_for("RepairDatabricksTasks.repair", **query_params)
 
diff --git a/providers/tests/databricks/operators/test_databricks.py 
b/providers/tests/databricks/operators/test_databricks.py
index da3c697360f..51e7a765998 100644
--- a/providers/tests/databricks/operators/test_databricks.py
+++ b/providers/tests/databricks/operators/test_databricks.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+import hashlib
 from datetime import datetime, timedelta
 from typing import Any
 from unittest import mock
@@ -2216,8 +2217,9 @@ class TestDatabricksNotebookOperator:
 
         task_json = 
operator._convert_to_databricks_workflow_task(relevant_upstreams)
 
+        task_key = hashlib.md5(b"example_dag__test_task").hexdigest()
         expected_json = {
-            "task_key": "example_dag__test_task",
+            "task_key": task_key,
             "depends_on": [],
             "timeout_seconds": 0,
             "email_notifications": {},
@@ -2317,3 +2319,27 @@ class TestDatabricksTaskOperator:
 
         assert operator.task_config == task_config
         assert task_base_json == task_config
+
+    def test_generate_databricks_task_key(self):
+        task_config = {}
+        operator = DatabricksTaskOperator(
+            task_id="test_task",
+            databricks_conn_id="test_conn_id",
+            task_config=task_config,
+        )
+
+        task_key = f"{operator.dag_id}__{operator.task_id}".encode()
+        expected_task_key = hashlib.md5(task_key).hexdigest()
+        assert expected_task_key == operator.databricks_task_key
+
+    def test_user_databricks_task_key(self):
+        task_config = {}
+        operator = DatabricksTaskOperator(
+            task_id="test_task",
+            databricks_conn_id="test_conn_id",
+            databricks_task_key="test_task_key",
+            task_config=task_config,
+        )
+        expected_task_key = "test_task_key"
+
+        assert expected_task_key == operator.databricks_task_key
diff --git a/providers/tests/databricks/plugins/test_databricks_workflow.py 
b/providers/tests/databricks/plugins/test_databricks_workflow.py
index 35d3496c66e..a22febb5b7f 100644
--- a/providers/tests/databricks/plugins/test_databricks_workflow.py
+++ b/providers/tests/databricks/plugins/test_databricks_workflow.py
@@ -32,7 +32,6 @@ from airflow.providers.databricks.plugins.databricks_workflow 
import (
     WorkflowJobRepairSingleTaskLink,
     WorkflowJobRunLink,
     _get_dagrun,
-    _get_databricks_task_id,
     _get_launch_task_key,
     _repair_task,
     get_databricks_task_ids,
@@ -50,30 +49,17 @@ TASK_INSTANCE_KEY = TaskInstanceKey(dag_id=DAG_ID, 
task_id=TASK_ID, run_id=RUN_I
 DATABRICKS_CONN_ID = "databricks_default"
 DATABRICKS_RUN_ID = 12345
 GROUP_ID = "test_group"
+LOG = MagicMock()
 TASK_MAP = {
-    "task1": MagicMock(dag_id=DAG_ID, task_id="task1"),
-    "task2": MagicMock(dag_id=DAG_ID, task_id="task2"),
+    "task1": MagicMock(dag_id=DAG_ID, task_id="task1", 
databricks_task_key="task_key1"),
+    "task2": MagicMock(dag_id=DAG_ID, task_id="task2", 
databricks_task_key="task_key2"),
 }
-LOG = MagicMock()
-
-
[email protected](
-    "task, expected_id",
-    [
-        (MagicMock(dag_id="dag1", task_id="task.1"), "dag1__task__1"),
-        (MagicMock(dag_id="dag2", task_id="task_1"), "dag2__task_1"),
-    ],
-)
-def test_get_databricks_task_id(task, expected_id):
-    result = _get_databricks_task_id(task)
-
-    assert result == expected_id
 
 
 def test_get_databricks_task_ids():
     result = get_databricks_task_ids(GROUP_ID, TASK_MAP, LOG)
 
-    expected_ids = ["test_dag__task1", "test_dag__task2"]
+    expected_ids = ["task_key1", "task_key2"]
     assert result == expected_ids
 
 

Reply via email to