This is an automated email from the ASF dual-hosted git repository.
pankaj pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 347373986c check status before DatabricksSubmitRunOperator &
DatabricksSubmitRunOperator executes in deferrable mode (#36862)
347373986c is described below
commit 347373986c378a3c7fd4cf85336d0c419a51991e
Author: vatsrahul1001 <[email protected]>
AuthorDate: Tue Jan 23 12:47:30 2024 +0530
check status before DatabricksSubmitRunOperator &
DatabricksSubmitRunOperator executes in deferrable mode (#36862)
* check-status-before-DatabricksSubmitRunOperator-execute-in-deferrable-mode
* Apply suggestions from code review
Co-authored-by: Wei Lee <[email protected]>
* fixing static checks
---------
Co-authored-by: Wei Lee <[email protected]>
---
.../providers/databricks/operators/databricks.py | 30 ++--
.../databricks/operators/test_databricks.py | 161 ++++++++++++++++++++-
2 files changed, 177 insertions(+), 14 deletions(-)
diff --git a/airflow/providers/databricks/operators/databricks.py
b/airflow/providers/databricks/operators/databricks.py
index 5d8b62643f..9aa07582f4 100644
--- a/airflow/providers/databricks/operators/databricks.py
+++ b/airflow/providers/databricks/operators/databricks.py
@@ -132,18 +132,24 @@ def
_handle_deferrable_databricks_operator_execution(operator, hook, log, contex
log.info("View run status, Spark UI, and logs at %s", run_page_url)
if operator.wait_for_termination:
- operator.defer(
- trigger=DatabricksExecutionTrigger(
- run_id=operator.run_id,
- databricks_conn_id=operator.databricks_conn_id,
- polling_period_seconds=operator.polling_period_seconds,
- retry_limit=operator.databricks_retry_limit,
- retry_delay=operator.databricks_retry_delay,
- retry_args=operator.databricks_retry_args,
- run_page_url=run_page_url,
- ),
- method_name=DEFER_METHOD_NAME,
- )
+ run_info = hook.get_run(operator.run_id)
+ run_state = RunState(**run_info["state"])
+ if not run_state.is_terminal:
+ operator.defer(
+ trigger=DatabricksExecutionTrigger(
+ run_id=operator.run_id,
+ databricks_conn_id=operator.databricks_conn_id,
+ polling_period_seconds=operator.polling_period_seconds,
+ retry_limit=operator.databricks_retry_limit,
+ retry_delay=operator.databricks_retry_delay,
+ retry_args=operator.databricks_retry_args,
+ run_page_url=run_page_url,
+ ),
+ method_name=DEFER_METHOD_NAME,
+ )
+ else:
+ if run_state.is_successful:
+ log.info("%s completed successfully.", operator.task_id)
def _handle_deferrable_databricks_operator_completion(event: dict, log:
Logger) -> None:
diff --git a/tests/providers/databricks/operators/test_databricks.py
b/tests/providers/databricks/operators/test_databricks.py
index c196f51ee3..d0b8199701 100644
--- a/tests/providers/databricks/operators/test_databricks.py
+++ b/tests/providers/databricks/operators/test_databricks.py
@@ -910,7 +910,7 @@ class TestDatabricksSubmitRunDeferrableOperator:
op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.submit_run.return_value = 1
- db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
+ db_mock.get_run = make_run_with_state_mock("RUNNING", "RUNNING")
with pytest.raises(TaskDeferred) as exc:
op.execute(None)
@@ -982,6 +982,66 @@ class TestDatabricksSubmitRunDeferrableOperator:
with pytest.raises(AirflowException):
op.execute_complete(context=None, event=event)
+
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksSubmitRunOperator.defer")
+ def
test_databricks_submit_run_deferrable_operator_failed_before_defer(self,
mock_defer, db_mock_class):
+        """Asserts that a task is not deferred when it fails"""
+ run = {
+ "new_cluster": NEW_CLUSTER,
+ "notebook_task": NOTEBOOK_TASK,
+ }
+ op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
+ db_mock = db_mock_class.return_value
+ db_mock.submit_run.return_value = 1
+ db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")
+ op.execute(None)
+
+ expected = utils.normalise_json_content(
+ {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK,
"run_name": TASK_ID}
+ )
+ db_mock_class.assert_called_once_with(
+ DEFAULT_CONN_ID,
+ retry_limit=op.databricks_retry_limit,
+ retry_delay=op.databricks_retry_delay,
+ retry_args=None,
+ caller="DatabricksSubmitRunDeferrableOperator",
+ )
+
+ db_mock.submit_run.assert_called_once_with(expected)
+ db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+ assert op.run_id == RUN_ID
+ assert not mock_defer.called
+
+
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksSubmitRunOperator.defer")
+ def
test_databricks_submit_run_deferrable_operator_success_before_defer(self,
mock_defer, db_mock_class):
+ """Asserts that a task is not deferred when it succeeds"""
+ run = {
+ "new_cluster": NEW_CLUSTER,
+ "notebook_task": NOTEBOOK_TASK,
+ }
+ op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
+ db_mock = db_mock_class.return_value
+ db_mock.submit_run.return_value = 1
+ db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
+ op.execute(None)
+
+ expected = utils.normalise_json_content(
+ {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK,
"run_name": TASK_ID}
+ )
+ db_mock_class.assert_called_once_with(
+ DEFAULT_CONN_ID,
+ retry_limit=op.databricks_retry_limit,
+ retry_delay=op.databricks_retry_delay,
+ retry_args=None,
+ caller="DatabricksSubmitRunDeferrableOperator",
+ )
+
+ db_mock.submit_run.assert_called_once_with(expected)
+ db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+ assert op.run_id == RUN_ID
+ assert not mock_defer.called
+
class TestDatabricksRunNowOperator:
def test_init_with_named_parameters(self):
@@ -1345,7 +1405,7 @@ class TestDatabricksRunNowDeferrableOperator:
op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID,
job_id=JOB_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.run_now.return_value = 1
- db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
+ db_mock.get_run = make_run_with_state_mock("RUNNING", "RUNNING")
with pytest.raises(TaskDeferred) as exc:
op.execute(None)
@@ -1416,3 +1476,100 @@ class TestDatabricksRunNowDeferrableOperator:
op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID)
with pytest.raises(AirflowException):
op.execute_complete(context=None, event=event)
+
+
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksSubmitRunOperator.defer")
+ def test_operator_failed_before_defer(self, mock_defer, db_mock_class):
+        """Asserts that a task is not deferred when it fails"""
+ run = {
+ "new_cluster": NEW_CLUSTER,
+ "notebook_task": NOTEBOOK_TASK,
+ }
+ op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
+ db_mock = db_mock_class.return_value
+ db_mock.submit_run.return_value = 1
+ db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")
+ op.execute(None)
+
+ expected = utils.normalise_json_content(
+ {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK,
"run_name": TASK_ID}
+ )
+ db_mock_class.assert_called_once_with(
+ DEFAULT_CONN_ID,
+ retry_limit=op.databricks_retry_limit,
+ retry_delay=op.databricks_retry_delay,
+ retry_args=None,
+ caller="DatabricksSubmitRunDeferrableOperator",
+ )
+
+ db_mock.submit_run.assert_called_once_with(expected)
+ db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+ assert op.run_id == RUN_ID
+ assert not mock_defer.called
+
+
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksSubmitRunOperator.defer")
+ def test_databricks_run_now_deferrable_operator_failed_before_defer(self,
mock_defer, db_mock_class):
+        """Asserts that a task is not deferred when it fails"""
+ run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task":
NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
+ op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID,
job_id=JOB_ID, json=run)
+ db_mock = db_mock_class.return_value
+ db_mock.run_now.return_value = 1
+ db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")
+
+ op.execute(None)
+ expected = utils.normalise_json_content(
+ {
+ "notebook_params": NOTEBOOK_PARAMS,
+ "notebook_task": NOTEBOOK_TASK,
+ "jar_params": JAR_PARAMS,
+ "job_id": JOB_ID,
+ }
+ )
+
+ db_mock_class.assert_called_once_with(
+ DEFAULT_CONN_ID,
+ retry_limit=op.databricks_retry_limit,
+ retry_delay=op.databricks_retry_delay,
+ retry_args=None,
+ caller="DatabricksRunNowOperator",
+ )
+
+ db_mock.run_now.assert_called_once_with(expected)
+ db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+ assert op.run_id == RUN_ID
+ assert not mock_defer.called
+
+
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksSubmitRunOperator.defer")
+ def test_databricks_run_now_deferrable_operator_success_before_defer(self,
mock_defer, db_mock_class):
+        """Asserts that a task is not deferred when it succeeds"""
+ run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task":
NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
+ op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID,
job_id=JOB_ID, json=run)
+ db_mock = db_mock_class.return_value
+ db_mock.run_now.return_value = 1
+ db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
+
+ op.execute(None)
+
+ expected = utils.normalise_json_content(
+ {
+ "notebook_params": NOTEBOOK_PARAMS,
+ "notebook_task": NOTEBOOK_TASK,
+ "jar_params": JAR_PARAMS,
+ "job_id": JOB_ID,
+ }
+ )
+
+ db_mock_class.assert_called_once_with(
+ DEFAULT_CONN_ID,
+ retry_limit=op.databricks_retry_limit,
+ retry_delay=op.databricks_retry_delay,
+ retry_args=None,
+ caller="DatabricksRunNowOperator",
+ )
+
+ db_mock.run_now.assert_called_once_with(expected)
+ db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+ assert op.run_id == RUN_ID
+ assert not mock_defer.called