This is an automated email from the ASF dual-hosted git repository.

pankaj pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 347373986c check status before DatabricksSubmitRunOperator & 
DatabricksRunNowOperator executes in deferrable mode (#36862)
347373986c is described below

commit 347373986c378a3c7fd4cf85336d0c419a51991e
Author: vatsrahul1001 <[email protected]>
AuthorDate: Tue Jan 23 12:47:30 2024 +0530

    check status before DatabricksSubmitRunOperator & 
DatabricksRunNowOperator executes in deferrable mode (#36862)
    
    * check-status-before-DatabricksSubmitRunOperator-execute-in-deferrable-mode
    
    * Apply suggestions from code review
    
    Co-authored-by: Wei Lee <[email protected]>
    
    * fixing static checks
    
    ---------
    
    Co-authored-by: Wei Lee <[email protected]>
---
 .../providers/databricks/operators/databricks.py   |  30 ++--
 .../databricks/operators/test_databricks.py        | 161 ++++++++++++++++++++-
 2 files changed, 177 insertions(+), 14 deletions(-)

diff --git a/airflow/providers/databricks/operators/databricks.py 
b/airflow/providers/databricks/operators/databricks.py
index 5d8b62643f..9aa07582f4 100644
--- a/airflow/providers/databricks/operators/databricks.py
+++ b/airflow/providers/databricks/operators/databricks.py
@@ -132,18 +132,24 @@ def 
_handle_deferrable_databricks_operator_execution(operator, hook, log, contex
     log.info("View run status, Spark UI, and logs at %s", run_page_url)
 
     if operator.wait_for_termination:
-        operator.defer(
-            trigger=DatabricksExecutionTrigger(
-                run_id=operator.run_id,
-                databricks_conn_id=operator.databricks_conn_id,
-                polling_period_seconds=operator.polling_period_seconds,
-                retry_limit=operator.databricks_retry_limit,
-                retry_delay=operator.databricks_retry_delay,
-                retry_args=operator.databricks_retry_args,
-                run_page_url=run_page_url,
-            ),
-            method_name=DEFER_METHOD_NAME,
-        )
+        run_info = hook.get_run(operator.run_id)
+        run_state = RunState(**run_info["state"])
+        if not run_state.is_terminal:
+            operator.defer(
+                trigger=DatabricksExecutionTrigger(
+                    run_id=operator.run_id,
+                    databricks_conn_id=operator.databricks_conn_id,
+                    polling_period_seconds=operator.polling_period_seconds,
+                    retry_limit=operator.databricks_retry_limit,
+                    retry_delay=operator.databricks_retry_delay,
+                    retry_args=operator.databricks_retry_args,
+                    run_page_url=run_page_url,
+                ),
+                method_name=DEFER_METHOD_NAME,
+            )
+        else:
+            if run_state.is_successful:
+                log.info("%s completed successfully.", operator.task_id)
 
 
 def _handle_deferrable_databricks_operator_completion(event: dict, log: 
Logger) -> None:
diff --git a/tests/providers/databricks/operators/test_databricks.py 
b/tests/providers/databricks/operators/test_databricks.py
index c196f51ee3..d0b8199701 100644
--- a/tests/providers/databricks/operators/test_databricks.py
+++ b/tests/providers/databricks/operators/test_databricks.py
@@ -910,7 +910,7 @@ class TestDatabricksSubmitRunDeferrableOperator:
         op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
         db_mock = db_mock_class.return_value
         db_mock.submit_run.return_value = 1
-        db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
+        db_mock.get_run = make_run_with_state_mock("RUNNING", "RUNNING")
 
         with pytest.raises(TaskDeferred) as exc:
             op.execute(None)
@@ -982,6 +982,66 @@ class TestDatabricksSubmitRunDeferrableOperator:
         with pytest.raises(AirflowException):
             op.execute_complete(context=None, event=event)
 
+    
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+    
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksSubmitRunOperator.defer")
+    def 
test_databricks_submit_run_deferrable_operator_failed_before_defer(self, 
mock_defer, db_mock_class):
+        """Asserts that a task is not deferred when it has failed"""
+        run = {
+            "new_cluster": NEW_CLUSTER,
+            "notebook_task": NOTEBOOK_TASK,
+        }
+        op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
+        db_mock = db_mock_class.return_value
+        db_mock.submit_run.return_value = 1
+        db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")
+        op.execute(None)
+
+        expected = utils.normalise_json_content(
+            {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, 
"run_name": TASK_ID}
+        )
+        db_mock_class.assert_called_once_with(
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit,
+            retry_delay=op.databricks_retry_delay,
+            retry_args=None,
+            caller="DatabricksSubmitRunDeferrableOperator",
+        )
+
+        db_mock.submit_run.assert_called_once_with(expected)
+        db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+        assert op.run_id == RUN_ID
+        assert not mock_defer.called
+
+    
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+    
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksSubmitRunOperator.defer")
+    def 
test_databricks_submit_run_deferrable_operator_success_before_defer(self, 
mock_defer, db_mock_class):
+        """Asserts that a task is not deferred when it succeeds"""
+        run = {
+            "new_cluster": NEW_CLUSTER,
+            "notebook_task": NOTEBOOK_TASK,
+        }
+        op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
+        db_mock = db_mock_class.return_value
+        db_mock.submit_run.return_value = 1
+        db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
+        op.execute(None)
+
+        expected = utils.normalise_json_content(
+            {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, 
"run_name": TASK_ID}
+        )
+        db_mock_class.assert_called_once_with(
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit,
+            retry_delay=op.databricks_retry_delay,
+            retry_args=None,
+            caller="DatabricksSubmitRunDeferrableOperator",
+        )
+
+        db_mock.submit_run.assert_called_once_with(expected)
+        db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+        assert op.run_id == RUN_ID
+        assert not mock_defer.called
+
 
 class TestDatabricksRunNowOperator:
     def test_init_with_named_parameters(self):
@@ -1345,7 +1405,7 @@ class TestDatabricksRunNowDeferrableOperator:
         op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, 
job_id=JOB_ID, json=run)
         db_mock = db_mock_class.return_value
         db_mock.run_now.return_value = 1
-        db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
+        db_mock.get_run = make_run_with_state_mock("RUNNING", "RUNNING")
 
         with pytest.raises(TaskDeferred) as exc:
             op.execute(None)
@@ -1416,3 +1476,100 @@ class TestDatabricksRunNowDeferrableOperator:
         op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID)
         with pytest.raises(AirflowException):
             op.execute_complete(context=None, event=event)
+
+    
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+    
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksSubmitRunOperator.defer")
+    def test_operator_failed_before_defer(self, mock_defer, db_mock_class):
+        """Asserts that a task is not deferred when it has failed"""
+        run = {
+            "new_cluster": NEW_CLUSTER,
+            "notebook_task": NOTEBOOK_TASK,
+        }
+        op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
+        db_mock = db_mock_class.return_value
+        db_mock.submit_run.return_value = 1
+        db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")
+        op.execute(None)
+
+        expected = utils.normalise_json_content(
+            {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, 
"run_name": TASK_ID}
+        )
+        db_mock_class.assert_called_once_with(
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit,
+            retry_delay=op.databricks_retry_delay,
+            retry_args=None,
+            caller="DatabricksSubmitRunDeferrableOperator",
+        )
+
+        db_mock.submit_run.assert_called_once_with(expected)
+        db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+        assert op.run_id == RUN_ID
+        assert not mock_defer.called
+
+    
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+    
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksSubmitRunOperator.defer")
+    def test_databricks_run_now_deferrable_operator_failed_before_defer(self, 
mock_defer, db_mock_class):
+        """Asserts that a task is not deferred when it has failed"""
+        run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": 
NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
+        op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, 
job_id=JOB_ID, json=run)
+        db_mock = db_mock_class.return_value
+        db_mock.run_now.return_value = 1
+        db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")
+
+        op.execute(None)
+        expected = utils.normalise_json_content(
+            {
+                "notebook_params": NOTEBOOK_PARAMS,
+                "notebook_task": NOTEBOOK_TASK,
+                "jar_params": JAR_PARAMS,
+                "job_id": JOB_ID,
+            }
+        )
+
+        db_mock_class.assert_called_once_with(
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit,
+            retry_delay=op.databricks_retry_delay,
+            retry_args=None,
+            caller="DatabricksRunNowOperator",
+        )
+
+        db_mock.run_now.assert_called_once_with(expected)
+        db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+        assert op.run_id == RUN_ID
+        assert not mock_defer.called
+
+    
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+    
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksSubmitRunOperator.defer")
+    def test_databricks_run_now_deferrable_operator_success_before_defer(self, 
mock_defer, db_mock_class):
+        """Asserts that a task is not deferred when it succeeds"""
+        run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": 
NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
+        op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, 
job_id=JOB_ID, json=run)
+        db_mock = db_mock_class.return_value
+        db_mock.run_now.return_value = 1
+        db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
+
+        op.execute(None)
+
+        expected = utils.normalise_json_content(
+            {
+                "notebook_params": NOTEBOOK_PARAMS,
+                "notebook_task": NOTEBOOK_TASK,
+                "jar_params": JAR_PARAMS,
+                "job_id": JOB_ID,
+            }
+        )
+
+        db_mock_class.assert_called_once_with(
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit,
+            retry_delay=op.databricks_retry_delay,
+            retry_args=None,
+            caller="DatabricksRunNowOperator",
+        )
+
+        db_mock.run_now.assert_called_once_with(expected)
+        db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+        assert op.run_id == RUN_ID
+        assert not mock_defer.called

Reply via email to