This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 861d6bc5bb5 Add new hook method (#62822)
861d6bc5bb5 is described below

commit 861d6bc5bb564ab22e290413ff6bfb4d5008434f
Author: Roman Karichev <[email protected]>
AuthorDate: Thu Mar 12 01:55:05 2026 +0100

    Add new hook method (#62822)
---
 .../providers/databricks/hooks/databricks.py       | 25 ++++++++++++++++++++++
 .../providers/databricks/operators/databricks.py   |  2 +-
 .../tests/unit/databricks/hooks/test_databricks.py | 25 ++++++++++++++++++++++
 3 files changed, 51 insertions(+), 1 deletion(-)

diff --git a/providers/databricks/src/airflow/providers/databricks/hooks/databricks.py b/providers/databricks/src/airflow/providers/databricks/hooks/databricks.py
index ff7e6150994..7bba7634552 100644
--- a/providers/databricks/src/airflow/providers/databricks/hooks/databricks.py
+++ b/providers/databricks/src/airflow/providers/databricks/hooks/databricks.py
@@ -535,6 +535,31 @@ class DatabricksHook(BaseDatabricksHook):
         state = response["state"]
         return RunState(**state)
 
+    def get_run_tasks(self, run_id: int) -> list[dict[str, Any]]:
+        """
+        Retrieve the tasks of a run, following pagination until exhausted.
+
+        Each page is fetched from ``GET_RUN_ENDPOINT`` and its ``tasks`` entries
+        are concatenated in the order received.
+
+        :param run_id: id of the run
+        :return: A list of tasks collected from all pages
+        """
+        all_tasks: list[dict[str, Any]] = []
+        json: dict[str, Any] = {"run_id": run_id}
+
+        while True:
+            response = self._do_api_call(GET_RUN_ENDPOINT, json)
+            all_tasks.extend(response.get("tasks", []))
+            # A missing or empty token ends pagination; re-sending an empty token
+            # would repeat the same request forever.
+            page_token = response.get("next_page_token")
+            if not page_token:
+                break
+            json = {**json, "page_token": page_token}
+
+        return all_tasks
+
     def get_run(self, run_id: int) -> dict[str, Any]:
         """
         Retrieve run information.
diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py
index 2d70602bbd5..95943e8b6b6 100644
--- a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py
+++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py
@@ -1329,7 +1329,7 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
         """Retrieve the Databricks task corresponding to the current Airflow 
task."""
         if self.databricks_run_id is None:
             raise ValueError("Databricks job not yet launched. Please run 
launch_notebook_job first.")
-        tasks = self._hook.get_run(self.databricks_run_id)["tasks"]
+        tasks = self._hook.get_run_tasks(self.databricks_run_id)
 
         # Because the task_key remains the same across multiple runs, and the 
Databricks API does not return
         # tasks sorted by their attempts/start time, we sort the tasks by 
start time. This ensures that we
diff --git a/providers/databricks/tests/unit/databricks/hooks/test_databricks.py b/providers/databricks/tests/unit/databricks/hooks/test_databricks.py
index aab9dff3751..558e39a2129 100644
--- a/providers/databricks/tests/unit/databricks/hooks/test_databricks.py
+++ b/providers/databricks/tests/unit/databricks/hooks/test_databricks.py
@@ -62,6 +62,7 @@ CLUSTER_ID = "cluster_id"
 RUN_ID = 1
 JOB_ID = 42
 JOB_NAME = "job-name"
+TASK_KEY = "task-key"
 PIPELINE_NAME = "some pipeline name"
 PIPELINE_ID = "its-a-pipeline-id"
 STATEMENT_ID = "statement_id"
@@ -87,6 +88,7 @@ GET_RUN_RESPONSE = {
     "job_id": JOB_ID,
     "run_page_url": RUN_PAGE_URL,
     "state": {"life_cycle_state": LIFE_CYCLE_STATE, "state_message": 
STATE_MESSAGE},
+    "tasks": [{"task_key": TASK_KEY}],
 }
 GET_RUN_OUTPUT_RESPONSE = {"metadata": {}, "error": ERROR_MESSAGE, 
"notebook_output": {}}
 CLUSTER_STATE = "TERMINATED"
@@ -692,6 +694,29 @@ class TestDatabricksHook:
         state_message = self.hook.get_run_state_message(RUN_ID)
         assert state_message == STATE_MESSAGE
 
+    @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
+    def test_get_run_tasks_success_multiple_pages(self, mock_requests):
+        """Collect tasks from all pages, forwarding the page token."""
+        mock_requests.codes.ok = 200
+        first_page = {**GET_RUN_RESPONSE, "next_page_token": "PAGETOKEN"}
+        mock_requests.get.side_effect = [
+            create_successful_response_mock(first_page),
+            create_successful_response_mock(GET_RUN_RESPONSE),
+        ]
+        expected_params = [
+            {"run_id": RUN_ID},
+            {"run_id": RUN_ID, "page_token": "PAGETOKEN"},
+        ]
+
+        tasks = self.hook.get_run_tasks(RUN_ID)
+
+        assert mock_requests.get.call_count == 2
+        for call, params in zip(mock_requests.method_calls, expected_params):
+            assert call[1][0] == get_run_endpoint(HOST)
+            assert call[2]["params"] == params
+        assert len(tasks) == 2
+        assert tasks == GET_RUN_RESPONSE["tasks"] * 2
+
     @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
     def test_cancel_run(self, mock_requests):
         mock_requests.post.return_value.json.return_value = GET_RUN_RESPONSE

Reply via email to