This is an automated email from the ASF dual-hosted git repository.
pankajkoti pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 1ee69712730 Accept task_key as an argument in
`DatabricksNotebookOperator` (#44960)
1ee69712730 is described below
commit 1ee69712730b4919c7796a32babc75d34380274e
Author: Idris Adebisi <[email protected]>
AuthorDate: Thu Dec 26 07:46:57 2024 +0000
Accept task_key as an argument in `DatabricksNotebookOperator` (#44960)
This PR lets users explicitly specify `databricks_task_key` as a parameter
of `DatabricksNotebookOperator` (the parameter is added to the shared base
class `DatabricksTaskBaseOperator`). If `databricks_task_key` is not provided,
or the provided value exceeds 100 characters, a default value is generated
from the MD5 hash of the `dag_id` and `task_id`.
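Key Changes:
- Users can now define `databricks_task_key` explicitly.
- When it is not provided, the key defaults to a deterministic MD5 hash of
  `<dag_id>__<task_id>`.
- This replaces the previous default task key, `<dag_id>__<task_id>` with `.`
  replaced by `__`, so generated task keys differ from those produced by
  earlier versions.
- The Databricks workflow plugin now reads `databricks_task_key` from each
  operator instead of recomputing the key itself.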
Fixes: #41816
Fixes: #44250
related: #43106
---
.../providers/databricks/operators/databricks.py | 43 +++++++++++++---------
.../databricks/plugins/databricks_workflow.py | 22 ++++-------
.../tests/databricks/operators/test_databricks.py | 28 +++++++++++++-
.../databricks/plugins/test_databricks_workflow.py | 22 ++---------
4 files changed, 63 insertions(+), 52 deletions(-)
diff --git a/providers/src/airflow/providers/databricks/operators/databricks.py
b/providers/src/airflow/providers/databricks/operators/databricks.py
index 1b8d45fa479..b8fde94c594 100644
--- a/providers/src/airflow/providers/databricks/operators/databricks.py
+++ b/providers/src/airflow/providers/databricks/operators/databricks.py
@@ -19,6 +19,7 @@
from __future__ import annotations
+import hashlib
import time
from abc import ABC, abstractmethod
from collections.abc import Sequence
@@ -966,6 +967,8 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
:param caller: The name of the caller operator to be used in the logs.
:param databricks_conn_id: The name of the Airflow connection to use.
+ :param databricks_task_key: An optional task_key used to refer to the task
by Databricks API. By
+ default this will be set to the hash of ``dag_id + task_id``.
:param databricks_retry_args: An optional dictionary with arguments passed
to ``tenacity.Retrying`` class.
:param databricks_retry_delay: Number of seconds to wait between retries.
:param databricks_retry_limit: Amount of times to retry if the Databricks
backend is unreachable.
@@ -986,6 +989,7 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
self,
caller: str = "DatabricksTaskBaseOperator",
databricks_conn_id: str = "databricks_default",
+ databricks_task_key: str = "",
databricks_retry_args: dict[Any, Any] | None = None,
databricks_retry_delay: int = 1,
databricks_retry_limit: int = 3,
@@ -1000,6 +1004,7 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
):
self.caller = caller
self.databricks_conn_id = databricks_conn_id
+ self._databricks_task_key = databricks_task_key
self.databricks_retry_args = databricks_retry_args
self.databricks_retry_delay = databricks_retry_delay
self.databricks_retry_limit = databricks_retry_limit
@@ -1037,17 +1042,21 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
caller=caller,
)
- def _get_databricks_task_id(self, task_id: str) -> str:
- """Get the databricks task ID using dag_id and task_id. Removes
illegal characters."""
- task_id = f"{self.dag_id}__{task_id.replace('.', '__')}"
- if len(task_id) > 100:
- self.log.warning(
- "The generated task_key '%s' exceeds 100 characters and will
be truncated by the Databricks API. "
- "This will cause failure when trying to monitor the task.
task_key is generated by ",
- "concatenating dag_id and task_id.",
- task_id,
+ @cached_property
+ def databricks_task_key(self) -> str:
+ return self._generate_databricks_task_key()
+
+ def _generate_databricks_task_key(self, task_id: str | None = None) -> str:
+ """Create a databricks task key using the hash of dag_id and
task_id."""
+ if not self._databricks_task_key or len(self._databricks_task_key) >
100:
+ self.log.info(
+ "databricks_task_key has not be provided or the provided one
exceeds 100 characters and will be truncated by the Databricks API. This will
cause failure when trying to monitor the task. A task_key will be generated
using the hash value of dag_id+task_id"
)
- return task_id
+ task_id = task_id or self.task_id
+ task_key = f"{self.dag_id}__{task_id}".encode()
+ self._databricks_task_key = hashlib.md5(task_key).hexdigest()
+ self.log.info("Generated databricks task_key: %s",
self._databricks_task_key)
+ return self._databricks_task_key
@property
def _databricks_workflow_task_group(self) -> DatabricksWorkflowTaskGroup |
None:
@@ -1077,7 +1086,7 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
def _get_run_json(self) -> dict[str, Any]:
"""Get run json to be used for task submissions."""
run_json = {
- "run_name": self._get_databricks_task_id(self.task_id),
+ "run_name": self.databricks_task_key,
**self._get_task_base_json(),
}
if self.new_cluster and self.existing_cluster_id:
@@ -1127,9 +1136,7 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
# building the {task_key: task} map below.
sorted_task_runs = sorted(tasks, key=lambda x: x["start_time"])
- return {task["task_key"]: task for task in sorted_task_runs}[
- self._get_databricks_task_id(self.task_id)
- ]
+ return {task["task_key"]: task for task in
sorted_task_runs}[self.databricks_task_key]
def _convert_to_databricks_workflow_task(
self, relevant_upstreams: list[BaseOperator], context: Context | None
= None
@@ -1137,9 +1144,9 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
"""Convert the operator to a Databricks workflow task that can be a
task in a workflow."""
base_task_json = self._get_task_base_json()
result = {
- "task_key": self._get_databricks_task_id(self.task_id),
+ "task_key": self.databricks_task_key,
"depends_on": [
- {"task_key": self._get_databricks_task_id(task_id)}
+ {"task_key": self._generate_databricks_task_key(task_id)}
for task_id in self.upstream_task_ids
if task_id in relevant_upstreams
],
@@ -1172,7 +1179,7 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
run_state = RunState(**run["state"])
self.log.info(
"Current state of the the databricks task %s is %s",
- self._get_databricks_task_id(self.task_id),
+ self.databricks_task_key,
run_state.life_cycle_state,
)
if self.deferrable and not run_state.is_terminal:
@@ -1194,7 +1201,7 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
run_state = RunState(**run["state"])
self.log.info(
"Current state of the databricks task %s is %s",
- self._get_databricks_task_id(self.task_id),
+ self.databricks_task_key,
run_state.life_cycle_state,
)
self._handle_terminal_run_state(run_state)
diff --git
a/providers/src/airflow/providers/databricks/plugins/databricks_workflow.py
b/providers/src/airflow/providers/databricks/plugins/databricks_workflow.py
index d0d8a179dd3..81cf09b0f8d 100644
--- a/providers/src/airflow/providers/databricks/plugins/databricks_workflow.py
+++ b/providers/src/airflow/providers/databricks/plugins/databricks_workflow.py
@@ -44,6 +44,8 @@ from airflow.www.views import AirflowBaseView
if TYPE_CHECKING:
from sqlalchemy.orm.session import Session
+ from airflow.providers.databricks.operators.databricks import
DatabricksTaskBaseOperator
+
REPAIR_WAIT_ATTEMPTS = os.getenv("DATABRICKS_REPAIR_WAIT_ATTEMPTS", 20)
REPAIR_WAIT_DELAY = os.getenv("DATABRICKS_REPAIR_WAIT_DELAY", 0.5)
@@ -57,18 +59,8 @@ def get_auth_decorator():
return auth.has_access_dag("POST", DagAccessEntity.RUN)
-def _get_databricks_task_id(task: BaseOperator) -> str:
- """
- Get the databricks task ID using dag_id and task_id. removes illegal
characters.
-
- :param task: The task to get the databricks task ID for.
- :return: The databricks task ID.
- """
- return f"{task.dag_id}__{task.task_id.replace('.', '__')}"
-
-
def get_databricks_task_ids(
- group_id: str, task_map: dict[str, BaseOperator], log: logging.Logger
+ group_id: str, task_map: dict[str, DatabricksTaskBaseOperator], log:
logging.Logger
) -> list[str]:
"""
Return a list of all Databricks task IDs for a dictionary of Airflow tasks.
@@ -83,7 +75,7 @@ def get_databricks_task_ids(
for task_id, task in task_map.items():
if task_id == f"{group_id}.launch":
continue
- databricks_task_id = _get_databricks_task_id(task)
+ databricks_task_id = task.databricks_task_key
log.debug("databricks task id for task %s is %s", task_id,
databricks_task_id)
task_ids.append(databricks_task_id)
return task_ids
@@ -112,7 +104,7 @@ def _clear_task_instances(
dag = airflow_app.dag_bag.get_dag(dag_id)
log.debug("task_ids %s to clear", str(task_ids))
dr: DagRun = _get_dagrun(dag, run_id, session=session)
- tis_to_clear = [ti for ti in dr.get_task_instances() if
_get_databricks_task_id(ti) in task_ids]
+ tis_to_clear = [ti for ti in dr.get_task_instances() if
ti.databricks_task_key in task_ids]
clear_task_instances(tis_to_clear, session)
@@ -327,7 +319,7 @@ class WorkflowJobRepairAllFailedLink(BaseOperatorLink,
LoggingMixin):
tasks_to_run = {ti: t for ti, t in task_group_sub_tasks if ti in
failed_and_skipped_tasks}
- return ",".join(get_databricks_task_ids(task_group.group_id,
tasks_to_run, log))
+ return ",".join(get_databricks_task_ids(task_group.group_id,
tasks_to_run, log)) # type: ignore[arg-type]
@staticmethod
def _get_failed_and_skipped_tasks(dr: DagRun) -> list[str]:
@@ -390,7 +382,7 @@ class WorkflowJobRepairSingleTaskLink(BaseOperatorLink,
LoggingMixin):
"databricks_conn_id": metadata.conn_id,
"databricks_run_id": metadata.run_id,
"run_id": ti_key.run_id,
- "tasks_to_repair": _get_databricks_task_id(task),
+ "tasks_to_repair": task.databricks_task_key,
}
return url_for("RepairDatabricksTasks.repair", **query_params)
diff --git a/providers/tests/databricks/operators/test_databricks.py
b/providers/tests/databricks/operators/test_databricks.py
index da3c697360f..51e7a765998 100644
--- a/providers/tests/databricks/operators/test_databricks.py
+++ b/providers/tests/databricks/operators/test_databricks.py
@@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations
+import hashlib
from datetime import datetime, timedelta
from typing import Any
from unittest import mock
@@ -2216,8 +2217,9 @@ class TestDatabricksNotebookOperator:
task_json =
operator._convert_to_databricks_workflow_task(relevant_upstreams)
+ task_key = hashlib.md5(b"example_dag__test_task").hexdigest()
expected_json = {
- "task_key": "example_dag__test_task",
+ "task_key": task_key,
"depends_on": [],
"timeout_seconds": 0,
"email_notifications": {},
@@ -2317,3 +2319,27 @@ class TestDatabricksTaskOperator:
assert operator.task_config == task_config
assert task_base_json == task_config
+
+ def test_generate_databricks_task_key(self):
+ task_config = {}
+ operator = DatabricksTaskOperator(
+ task_id="test_task",
+ databricks_conn_id="test_conn_id",
+ task_config=task_config,
+ )
+
+ task_key = f"{operator.dag_id}__{operator.task_id}".encode()
+ expected_task_key = hashlib.md5(task_key).hexdigest()
+ assert expected_task_key == operator.databricks_task_key
+
+ def test_user_databricks_task_key(self):
+ task_config = {}
+ operator = DatabricksTaskOperator(
+ task_id="test_task",
+ databricks_conn_id="test_conn_id",
+ databricks_task_key="test_task_key",
+ task_config=task_config,
+ )
+ expected_task_key = "test_task_key"
+
+ assert expected_task_key == operator.databricks_task_key
diff --git a/providers/tests/databricks/plugins/test_databricks_workflow.py
b/providers/tests/databricks/plugins/test_databricks_workflow.py
index 35d3496c66e..a22febb5b7f 100644
--- a/providers/tests/databricks/plugins/test_databricks_workflow.py
+++ b/providers/tests/databricks/plugins/test_databricks_workflow.py
@@ -32,7 +32,6 @@ from airflow.providers.databricks.plugins.databricks_workflow
import (
WorkflowJobRepairSingleTaskLink,
WorkflowJobRunLink,
_get_dagrun,
- _get_databricks_task_id,
_get_launch_task_key,
_repair_task,
get_databricks_task_ids,
@@ -50,30 +49,17 @@ TASK_INSTANCE_KEY = TaskInstanceKey(dag_id=DAG_ID,
task_id=TASK_ID, run_id=RUN_I
DATABRICKS_CONN_ID = "databricks_default"
DATABRICKS_RUN_ID = 12345
GROUP_ID = "test_group"
+LOG = MagicMock()
TASK_MAP = {
- "task1": MagicMock(dag_id=DAG_ID, task_id="task1"),
- "task2": MagicMock(dag_id=DAG_ID, task_id="task2"),
+ "task1": MagicMock(dag_id=DAG_ID, task_id="task1",
databricks_task_key="task_key1"),
+ "task2": MagicMock(dag_id=DAG_ID, task_id="task2",
databricks_task_key="task_key2"),
}
-LOG = MagicMock()
-
-
[email protected](
- "task, expected_id",
- [
- (MagicMock(dag_id="dag1", task_id="task.1"), "dag1__task__1"),
- (MagicMock(dag_id="dag2", task_id="task_1"), "dag2__task_1"),
- ],
-)
-def test_get_databricks_task_id(task, expected_id):
- result = _get_databricks_task_id(task)
-
- assert result == expected_id
def test_get_databricks_task_ids():
result = get_databricks_task_ids(GROUP_ID, TASK_MAP, LOG)
- expected_ids = ["test_dag__task1", "test_dag__task2"]
+ expected_ids = ["task_key1", "task_key2"]
assert result == expected_ids