This is an automated email from the ASF dual-hosted git repository.
pankaj pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 574102fd29 [FEAT] adds repair run functionality for databricks (#36601)
574102fd29 is described below
commit 574102fd291930ed45262a40fb7033a122152541
Author: gaurav7261 <[email protected]>
AuthorDate: Thu Jan 11 22:24:47 2024 +0530
[FEAT] adds repair run functionality for databricks (#36601)
* [FEAT] adds repair run functionality for databricks
* [FIX] added latest repair run and test cases
* [FIX] comma typo
* [FIX] check for DatabricksRunNowOperator instance before doing repair run
* [FIX] fixed static checks
* [FIX] fixed static checks
* Update airflow/providers/databricks/hooks/databricks.py
Co-authored-by: Andrey Anshin <[email protected]>
* [FIX] type annotations
* [FIX] change from log.warn to log.warning
* Update airflow/providers/databricks/operators/databricks.py
Co-authored-by: Andrey Anshin <[email protected]>
* [FIX] CI Static check
---------
Co-authored-by: GauravM <[email protected]>
Co-authored-by: GauravM <[email protected]>
Co-authored-by: GauravM <[email protected]>
Co-authored-by: Andrey Anshin <[email protected]>
---
airflow/providers/databricks/hooks/databricks.py | 15 ++++-
.../providers/databricks/operators/databricks.py | 17 +++++
.../providers/databricks/hooks/test_databricks.py | 73 ++++++++++++++++++++++
3 files changed, 103 insertions(+), 2 deletions(-)
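
For context, a minimal sketch (not part of the commit) of how the new repair_run flag introduced here might be enabled on DatabricksRunNowOperator; the DAG id, task id, job id and connection id below are placeholders:

    from __future__ import annotations

    import pendulum

    from airflow import DAG
    from airflow.providers.databricks.operators.databricks import DatabricksRunNowOperator

    with DAG(
        dag_id="example_databricks_repair_run",       # placeholder DAG id
        start_date=pendulum.datetime(2024, 1, 1, tz="UTC"),
        schedule=None,
        catchup=False,
    ):
        run_job = DatabricksRunNowOperator(
            task_id="run_now_with_repair",            # placeholder task id
            databricks_conn_id="databricks_default",  # assumed connection id
            job_id=12345,                             # placeholder Databricks job id
            wait_for_termination=True,                # repair applies while waiting on the run
            repair_run=True,                          # new flag from this commit; per the
                                                      # docstring it does not work in deferrable mode
        )
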
diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py
index b39e3d622c..bc3bd90209 100644
--- a/airflow/providers/databricks/hooks/databricks.py
+++ b/airflow/providers/databricks/hooks/databricks.py
@@ -519,13 +519,24 @@ class DatabricksHook(BaseDatabricksHook):
json = {"run_id": run_id}
self._do_api_call(DELETE_RUN_ENDPOINT, json)
- def repair_run(self, json: dict) -> None:
+ def repair_run(self, json: dict) -> int:
"""
Re-run one or more tasks.
:param json: repair a job run.
"""
- self._do_api_call(REPAIR_RUN_ENDPOINT, json)
+ response = self._do_api_call(REPAIR_RUN_ENDPOINT, json)
+ return response["repair_id"]
+
+ def get_latest_repair_id(self, run_id: int) -> int | None:
+ """Get latest repair id if any exist for run_id else None."""
+ json = {"run_id": run_id, "include_history": True}
+ response = self._do_api_call(GET_RUN_ENDPOINT, json)
+ repair_history = response["repair_history"]
+ if len(repair_history) == 1:
+ return None
+ else:
+ return repair_history[-1]["id"]
def get_cluster_state(self, cluster_id: str) -> ClusterState:
"""
diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py
index edea8b4e59..5d8b62643f 100644
--- a/airflow/providers/databricks/operators/databricks.py
+++ b/airflow/providers/databricks/operators/databricks.py
@@ -88,6 +88,19 @@ def _handle_databricks_operator_execution(operator, hook, log, context) -> None:
f"{operator.task_id} failed with terminal state: {run_state} "
f"and with the error {run_state.state_message}"
)
+ if isinstance(operator, DatabricksRunNowOperator) and operator.repair_run:
+ operator.repair_run = False
+ log.warning(
+ "%s but since repair run is set, repairing the run
with all failed tasks",
+ error_message,
+ )
+
+ latest_repair_id = hook.get_latest_repair_id(operator.run_id)
+ repair_json = {"run_id": operator.run_id, "rerun_all_failed_tasks": True}
+ if latest_repair_id is not None:
+ repair_json["latest_repair_id"] = latest_repair_id
+ operator.json["latest_repair_id"] =
hook.repair_run(operator, repair_json)
+ _handle_databricks_operator_execution(operator, hook,
log, context)
raise AirflowException(error_message)
else:
@@ -623,6 +636,7 @@ class DatabricksRunNowOperator(BaseOperator):
- ``jar_params``
- ``spark_submit_params``
- ``idempotency_token``
+ - ``repair_run``
:param job_id: the job_id of the existing Databricks job.
This field will be templated.
@@ -711,6 +725,7 @@ class DatabricksRunNowOperator(BaseOperator):
:param do_xcom_push: Whether we should push run_id and run_page_url to xcom.
:param wait_for_termination: if we should wait for termination of the job run. ``True`` by default.
:param deferrable: Run operator in the deferrable mode.
+ :param repair_run: Repair the databricks run in case of failure, doesn't work in deferrable mode
"""
# Used in airflow.models.BaseOperator
@@ -741,6 +756,7 @@ class DatabricksRunNowOperator(BaseOperator):
do_xcom_push: bool = True,
wait_for_termination: bool = True,
deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
+ repair_run: bool = False,
**kwargs,
) -> None:
"""Create a new ``DatabricksRunNowOperator``."""
@@ -753,6 +769,7 @@ class DatabricksRunNowOperator(BaseOperator):
self.databricks_retry_args = databricks_retry_args
self.wait_for_termination = wait_for_termination
self.deferrable = deferrable
+ self.repair_run = repair_run
if job_id is not None:
self.json["job_id"] = job_id
diff --git a/tests/providers/databricks/hooks/test_databricks.py b/tests/providers/databricks/hooks/test_databricks.py
index 1baaab1fea..c9004e7175 100644
--- a/tests/providers/databricks/hooks/test_databricks.py
+++ b/tests/providers/databricks/hooks/test_databricks.py
@@ -683,6 +683,79 @@ class TestDatabricksHook:
timeout=self.hook.timeout_seconds,
)
+ @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
+ def test_negative_get_latest_repair_id(self, mock_requests):
+ mock_requests.codes.ok = 200
+ mock_requests.get.return_value.json.return_value = {
+ "job_id": JOB_ID,
+ "run_id": RUN_ID,
+ "state": {"life_cycle_state": "RUNNING", "result_state":
"RUNNING"},
+ "repair_history": [
+ {
+ "type": "ORIGINAL",
+ "start_time": 1704528798059,
+ "end_time": 1704529026679,
+ "state": {
+ "life_cycle_state": "RUNNING",
+ "result_state": "RUNNING",
+ "state_message": "dummy",
+ "user_cancelled_or_timedout": "false",
+ },
+ "task_run_ids": [396529700633015, 1111270934390307],
+ }
+ ],
+ }
+ latest_repair_id = self.hook.get_latest_repair_id(RUN_ID)
+
+ assert latest_repair_id is None
+
+ @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
+ def test_positive_get_latest_repair_id(self, mock_requests):
+ mock_requests.codes.ok = 200
+ mock_requests.get.return_value.json.return_value = {
+ "job_id": JOB_ID,
+ "run_id": RUN_ID,
+ "state": {"life_cycle_state": "RUNNING", "result_state":
"RUNNING"},
+ "repair_history": [
+ {
+ "type": "ORIGINAL",
+ "start_time": 1704528798059,
+ "end_time": 1704529026679,
+ "state": {
+ "life_cycle_state": "TERMINATED",
+ "result_state": "CANCELED",
+ "state_message": "dummy_original",
+ "user_cancelled_or_timedout": "false",
+ },
+ "task_run_ids": [396529700633015, 1111270934390307],
+ },
+ {
+ "type": "REPAIR",
+ "start_time": 1704530276423,
+ "end_time": 1704530363736,
+ "state": {
+ "life_cycle_state": "TERMINATED",
+ "result_state": "CANCELED",
+ "state_message": "dummy_repair_1",
+ "user_cancelled_or_timedout": "true",
+ },
+ "id": 108607572123234,
+ "task_run_ids": [396529700633015, 1111270934390307],
+ },
+ {
+ "type": "REPAIR",
+ "start_time": 1704531464690,
+ "end_time": 1704531481590,
+ "state": {"life_cycle_state": "RUNNING", "result_state":
"RUNNING"},
+ "id": 52532060060836,
+ "task_run_ids": [396529700633015, 1111270934390307],
+ },
+ ],
+ }
+ latest_repair_id = self.hook.get_latest_repair_id(RUN_ID)
+
+ assert latest_repair_id == 52532060060836
+
@mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
def test_get_cluster_state(self, mock_requests):
"""