This is an automated email from the ASF dual-hosted git repository.

weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2d103e115c get all failed task errors when an exception is raised in DatabricksCreateJobsOperator (#39354)
2d103e115c is described below

commit 2d103e115c9951ce2bccb6b7ffa4fbd7ff269ef3
Author: gaurav7261 <[email protected]>
AuthorDate: Fri May 3 14:14:45 2024 +0530

    get all failed task errors when an exception is raised in DatabricksCreateJobsOperator (#39354)
---
 .../providers/databricks/operators/databricks.py   | 21 +++----
 .../databricks/operators/test_databricks.py        | 73 +++++++++++++++++++++-
 2 files changed, 81 insertions(+), 13 deletions(-)
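
In outline, the patch replaces the old "remember the last failed task's run_id" logic with a loop that records one {task_key, run_id, error} entry per FAILED task. A minimal standalone sketch of that aggregation, assuming run_info is the payload returned by hook.get_run and hook.get_run_output returns a dict (collect_failed_task_errors is a hypothetical helper for illustration; the committed code inlines this loop in _handle_databricks_operator_execution):

    def collect_failed_task_errors(hook, run_info, fallback_message):
        # Hypothetical helper: gather one error record per FAILED task.
        failed_tasks = []
        for task in run_info.get("tasks", []):
            if task.get("state", {}).get("result_state", "") == "FAILED":
                run_output = hook.get_run_output(task["run_id"])
                # Prefer the task's own error output; fall back to the
                # run-level state message when no "error" key is present.
                error = run_output.get("error", fallback_message)
                failed_tasks.append(
                    {"task_key": task["task_key"], "run_id": task["run_id"], "error": error}
                )
        return failed_tasks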

diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py
index c38b0683c3..0d819e1b70 100644
--- a/airflow/providers/databricks/operators/databricks.py
+++ b/airflow/providers/databricks/operators/databricks.py
@@ -67,23 +67,22 @@ def _handle_databricks_operator_execution(operator, hook, log, context) -> None:
                     log.info("%s completed successfully.", operator.task_id)
                     log.info("View run status, Spark UI, and logs at %s", run_page_url)
                     return
-
                 if run_state.result_state == "FAILED":
-                    task_run_id = None
+                    failed_tasks = []
                     for task in run_info.get("tasks", []):
                         if task.get("state", {}).get("result_state", "") == "FAILED":
                             task_run_id = task["run_id"]
-                    if task_run_id is not None:
-                        run_output = hook.get_run_output(task_run_id)
-                        if "error" in run_output:
-                            notebook_error = run_output["error"]
-                        else:
-                            notebook_error = run_state.state_message
-                    else:
-                        notebook_error = run_state.state_message
+                            task_key = task["task_key"]
+                            run_output = hook.get_run_output(task_run_id)
+                            if "error" in run_output:
+                                error = run_output["error"]
+                            else:
+                                error = run_state.state_message
+                            failed_tasks.append({"task_key": task_key, "run_id": task_run_id, "error": error})
+
                     error_message = (
                         f"{operator.task_id} failed with terminal state: {run_state} "
-                        f"and with the error {notebook_error}"
+                        f"and with the errors {failed_tasks}"
                     )
                 else:
                     error_message = (
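
For illustration, with two failed tasks the raised AirflowException message now embeds the whole list, roughly like this (task names and values hypothetical, mirroring the new test below):

    run_now_task failed with terminal state: {'result_state': 'FAILED', ...} and with the errors [{'task_key': 'first_task', 'run_id': 2, 'error': 'Exception: Something went wrong...'}, {'task_key': 'second_task', 'run_id': 3, 'error': 'Exception: Something went wrong...'}]
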
diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py
index e6cb240dfc..64b9ba985c 100644
--- a/tests/providers/databricks/operators/test_databricks.py
+++ b/tests/providers/databricks/operators/test_databricks.py
@@ -1310,6 +1310,7 @@ class TestDatabricksRunNowOperator:
                 "tasks": [
                     {
                         "run_id": 2,
+                        "task_key": "first_task",
                         "state": {
                             "life_cycle_state": "TERMINATED",
                             "result_state": "FAILED",
@@ -1321,10 +1322,76 @@ class TestDatabricksRunNowOperator:
         )
         db_mock.get_run_output = mock_dict({"error": "Exception: Something went wrong..."})
 
-        with pytest.raises(AirflowException) as exc_info:
+        with pytest.raises(AirflowException, match="Exception: Something went wrong"):
             op.execute(None)
 
-        assert exc_info.value.args[0].endswith(" Exception: Something went wrong...")
+        expected = utils.normalise_json_content(
+            {
+                "notebook_params": NOTEBOOK_PARAMS,
+                "notebook_task": NOTEBOOK_TASK,
+                "jar_params": JAR_PARAMS,
+                "job_id": JOB_ID,
+            }
+        )
+        db_mock_class.assert_called_once_with(
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit,
+            retry_delay=op.databricks_retry_delay,
+            retry_args=None,
+            caller="DatabricksRunNowOperator",
+        )
+        db_mock.run_now.assert_called_once_with(expected)
+        db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+        db_mock.get_run.assert_called_once_with(RUN_ID)
+        assert RUN_ID == op.run_id
+
+    @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+    def test_exec_multiple_failures_with_message(self, db_mock_class):
+        """
+        Test the execute function in the case where the run fails with multiple task errors.
+        """
+        run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
+        op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
+        db_mock = db_mock_class.return_value
+        db_mock.run_now.return_value = RUN_ID
+        db_mock.get_run = mock_dict(
+            {
+                "job_id": JOB_ID,
+                "run_id": 1,
+                "state": {
+                    "life_cycle_state": "TERMINATED",
+                    "result_state": "FAILED",
+                    "state_message": "failed",
+                },
+                "tasks": [
+                    {
+                        "run_id": 2,
+                        "task_key": "first_task",
+                        "state": {
+                            "life_cycle_state": "TERMINATED",
+                            "result_state": "FAILED",
+                            "state_message": "failed",
+                        },
+                    },
+                    {
+                        "run_id": 3,
+                        "task_key": "second_task",
+                        "state": {
+                            "life_cycle_state": "TERMINATED",
+                            "result_state": "FAILED",
+                            "state_message": "failed",
+                        },
+                    },
+                ],
+            }
+        )
+        db_mock.get_run_output = mock_dict({"error": "Exception: Something went wrong..."})
+
+        with pytest.raises(
+            AirflowException,
+            match="Exception: Something went wrong.*Exception: Something went wrong",
+        ):
+            op.execute(None)
 
         expected = utils.normalise_json_content(
             {
@@ -1341,6 +1408,8 @@ class TestDatabricksRunNowOperator:
             retry_args=None,
             caller="DatabricksRunNowOperator",
         )
+        db_mock.get_run_output.assert_called()
+        assert db_mock.get_run_output.call_count == 2
         db_mock.run_now.assert_called_once_with(expected)
         db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
         db_mock.get_run.assert_called_once_with(RUN_ID)
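
To exercise just the new multi-failure path, something like the following should work from the repository root (assuming a working Airflow dev environment with the Databricks provider test dependencies installed):

    pytest tests/providers/databricks/operators/test_databricks.py \
        -k "test_exec_multiple_failures_with_message"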
