This is an automated email from the ASF dual-hosted git repository.
pankajkoti pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 1e4663f34c add deferrable support to `DatabricksNotebookOperator` (#39295)
1e4663f34c is described below
commit 1e4663f34c2fb42b87cf75e4776650620eb2baa4
Author: Kalyan <[email protected]>
AuthorDate: Tue May 14 19:48:17 2024 +0530
add deferrable support to `DatabricksNotebookOperator` (#39295)
related: #39178
This PR makes `DatabricksNotebookOperator` deferrable.
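
For illustration, a minimal usage sketch (hedged: the task id, notebook path,
and cluster id below are hypothetical placeholders, not part of this change):

    from airflow.providers.databricks.operators.databricks import (
        DatabricksNotebookOperator,
    )

    # Runs a workspace notebook; with deferrable=True the task releases its
    # worker slot and a trigger polls the Databricks run state instead.
    notebook_task = DatabricksNotebookOperator(
        task_id="notebook_task",  # hypothetical task id
        databricks_conn_id="databricks_default",
        notebook_path="/Shared/example_notebook",  # placeholder path
        source="WORKSPACE",
        existing_cluster_id="1234-567890-abcde123",  # placeholder cluster id
        deferrable=True,
    )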
---
.../providers/databricks/hooks/databricks_base.py | 1 +
.../providers/databricks/operators/databricks.py | 40 +++++++++++++++---
.../providers/databricks/triggers/databricks.py | 2 +
.../databricks/operators/test_databricks.py | 48 +++++++++++++++++++++-
4 files changed, 83 insertions(+), 8 deletions(-)
diff --git a/airflow/providers/databricks/hooks/databricks_base.py b/airflow/providers/databricks/hooks/databricks_base.py
index 32316d49bb..2dee924f61 100644
--- a/airflow/providers/databricks/hooks/databricks_base.py
+++ b/airflow/providers/databricks/hooks/databricks_base.py
@@ -80,6 +80,7 @@ class BaseDatabricksHook(BaseHook):
     :param retry_delay: The number of seconds to wait between retries (it
         might be a floating point number).
     :param retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
+    :param caller: The name of the operator that is calling the hook.
     """

     conn_name_attr: str = "databricks_conn_id"
diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py
index 0d819e1b70..7ae802db10 100644
--- a/airflow/providers/databricks/operators/databricks.py
+++ b/airflow/providers/databricks/operators/databricks.py
@@ -167,7 +167,7 @@ def _handle_deferrable_databricks_operator_completion(event: dict, log: Logger)
     error_message = f"Job run failed with terminal state: {run_state} and with the errors {errors}"
-    if event["repair_run"]:
+    if event.get("repair_run"):
         log.warning(
             "%s but since repair run is set, repairing the run with all failed tasks",
             error_message,
@@ -923,9 +923,11 @@ class DatabricksNotebookOperator(BaseOperator):
     :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
     :param wait_for_termination: if we should wait for termination of the job run. ``True`` by default.
     :param databricks_conn_id: The name of the Airflow connection to use.
+    :param deferrable: Run operator in the deferrable mode.
     """

     template_fields = ("notebook_params",)
+    CALLER = "DatabricksNotebookOperator"

     def __init__(
         self,
@@ -942,6 +944,7 @@ class DatabricksNotebookOperator(BaseOperator):
         databricks_retry_args: dict[Any, Any] | None = None,
         wait_for_termination: bool = True,
         databricks_conn_id: str = "databricks_default",
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs: Any,
     ):
         self.notebook_path = notebook_path
@@ -958,11 +961,12 @@ class DatabricksNotebookOperator(BaseOperator):
         self.wait_for_termination = wait_for_termination
         self.databricks_conn_id = databricks_conn_id
         self.databricks_run_id: int | None = None
+        self.deferrable = deferrable
         super().__init__(**kwargs)

     @cached_property
     def _hook(self) -> DatabricksHook:
-        return self._get_hook(caller="DatabricksNotebookOperator")
+        return self._get_hook(caller=self.CALLER)

     def _get_hook(self, caller: str) -> DatabricksHook:
         return DatabricksHook(
@@ -970,7 +974,7 @@ class DatabricksNotebookOperator(BaseOperator):
             retry_limit=self.databricks_retry_limit,
             retry_delay=self.databricks_retry_delay,
             retry_args=self.databricks_retry_args,
-            caller=caller,
+            caller=self.CALLER,
         )

     def _get_task_timeout_seconds(self) -> int:
@@ -1041,6 +1045,19 @@ class DatabricksNotebookOperator(BaseOperator):
         run = self._hook.get_run(self.databricks_run_id)
         run_state = RunState(**run["state"])
         self.log.info("Current state of the job: %s", run_state.life_cycle_state)
+        if self.deferrable and not run_state.is_terminal:
+            return self.defer(
+                trigger=DatabricksExecutionTrigger(
+                    run_id=self.databricks_run_id,
+                    databricks_conn_id=self.databricks_conn_id,
+                    polling_period_seconds=self.polling_period_seconds,
+                    retry_limit=self.databricks_retry_limit,
+                    retry_delay=self.databricks_retry_delay,
+                    retry_args=self.databricks_retry_args,
+                    caller=self.CALLER,
+                ),
+                method_name=DEFER_METHOD_NAME,
+            )
         while not run_state.is_terminal:
             time.sleep(self.polling_period_seconds)
             run = self._hook.get_run(self.databricks_run_id)
@@ -1056,9 +1073,7 @@ class DatabricksNotebookOperator(BaseOperator):
            )
         if not run_state.is_successful:
             raise AirflowException(
-                "Task failed. Final state %s. Reason: %s",
-                run_state.result_state,
-                run_state.state_message,
+                f"Task failed. Final state {run_state.result_state}. Reason: {run_state.state_message}"
             )
         self.log.info("Task succeeded. Final state %s.", run_state.result_state)
@@ -1066,3 +1081,16 @@ class DatabricksNotebookOperator(BaseOperator):
         self.launch_notebook_job()
         if self.wait_for_termination:
             self.monitor_databricks_job()
+
+    def execute_complete(self, context: dict | None, event: dict) -> None:
+        run_state = RunState.from_json(event["run_state"])
+        if run_state.life_cycle_state != "TERMINATED":
+            raise AirflowException(
+                f"Databricks job failed with state {run_state.life_cycle_state}. "
+                f"Message: {run_state.state_message}"
+            )
+        if not run_state.is_successful:
+            raise AirflowException(
+                f"Task failed. Final state {run_state.result_state}. Reason: {run_state.state_message}"
+            )
+        self.log.info("Task succeeded. Final state %s.", run_state.result_state)
diff --git a/airflow/providers/databricks/triggers/databricks.py b/airflow/providers/databricks/triggers/databricks.py
index d20202fdca..55845fc6f7 100644
--- a/airflow/providers/databricks/triggers/databricks.py
+++ b/airflow/providers/databricks/triggers/databricks.py
@@ -48,6 +48,7 @@ class DatabricksExecutionTrigger(BaseTrigger):
         retry_args: dict[Any, Any] | None = None,
         run_page_url: str | None = None,
         repair_run: bool = False,
+        caller: str = "DatabricksExecutionTrigger",
     ) -> None:
         super().__init__()
         self.run_id = run_id
@@ -63,6 +64,7 @@ class DatabricksExecutionTrigger(BaseTrigger):
             retry_limit=self.retry_limit,
             retry_delay=self.retry_delay,
             retry_args=retry_args,
+            caller=caller,
         )

     def serialize(self) -> tuple[str, dict[str, Any]]:
diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py
index 64b9ba985c..d6e7eb3892 100644
--- a/tests/providers/databricks/operators/test_databricks.py
+++ b/tests/providers/databricks/operators/test_databricks.py
@@ -1865,6 +1865,50 @@ class TestDatabricksNotebookOperator:
         operator.launch_notebook_job.assert_called_once()
         operator.monitor_databricks_job.assert_not_called()
+
+    @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+    def test_execute_with_deferrable(self, mock_databricks_hook):
+        mock_databricks_hook.return_value.get_run.return_value = {"state": {"life_cycle_state": "PENDING"}}
+        operator = DatabricksNotebookOperator(
+            task_id="test_task",
+            notebook_path="test_path",
+            source="test_source",
+            databricks_conn_id="test_conn_id",
+            wait_for_termination=True,
+            deferrable=True,
+        )
+        operator.databricks_run_id = 12345
+
+        with pytest.raises(TaskDeferred) as exec_info:
+            operator.monitor_databricks_job()
+        assert isinstance(
+            exec_info.value.trigger, DatabricksExecutionTrigger
+        ), "Trigger is not a DatabricksExecutionTrigger"
+        assert exec_info.value.method_name == "execute_complete"
+
+    @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+    def test_execute_with_deferrable_early_termination(self, mock_databricks_hook):
+        mock_databricks_hook.return_value.get_run.return_value = {
+            "state": {
+                "life_cycle_state": "TERMINATED",
+                "result_state": "FAILED",
+                "state_message": "FAILURE",
+            }
+        }
+        operator = DatabricksNotebookOperator(
+            task_id="test_task",
+            notebook_path="test_path",
+            source="test_source",
+            databricks_conn_id="test_conn_id",
+            wait_for_termination=True,
+            deferrable=True,
+        )
+        operator.databricks_run_id = 12345
+
+        with pytest.raises(AirflowException) as exec_info:
+            operator.monitor_databricks_job()
+        exception_message = "Task failed. Final state FAILED. Reason: FAILURE"
+        assert exception_message == str(exec_info.value)
+
     @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
     def test_monitor_databricks_job_successful_raises_no_exception(self, mock_databricks_hook):
         mock_databricks_hook.return_value.get_run.return_value = {
@@ -1896,10 +1940,10 @@ class TestDatabricksNotebookOperator:
         operator.databricks_run_id = 12345

-        exception_message = "'Task failed. Final state %s. Reason: %s', 'FAILED', 'FAILURE'"
         with pytest.raises(AirflowException) as exc_info:
             operator.monitor_databricks_job()
-        assert exception_message in str(exc_info.value)
+        exception_message = "Task failed. Final state FAILED. Reason: FAILURE"
+        assert exception_message == str(exc_info.value)

     @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
     def test_launch_notebook_job(self, mock_databricks_hook):
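
For reference, a hedged sketch of how the deferred path resumes: the trigger
fires an event whose "run_state" field is the JSON form of a RunState, which
execute_complete parses with RunState.from_json. The field values below are
illustrative only; the real payload is built by DatabricksExecutionTrigger.run:

    import json

    # Assumed event shape for a successful run (illustrative values).
    event = {
        "run_id": 12345,
        "run_state": json.dumps(
            {
                "life_cycle_state": "TERMINATED",
                "result_state": "SUCCESS",
                "state_message": "",
            }
        ),
    }
    # execute_complete raises AirflowException unless life_cycle_state is
    # TERMINATED and result_state is SUCCESS; otherwise it logs success.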