This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 861d6bc5bb5 Add new hook method (#62822)
861d6bc5bb5 is described below
commit 861d6bc5bb564ab22e290413ff6bfb4d5008434f
Author: Roman Karichev <[email protected]>
AuthorDate: Thu Mar 12 01:55:05 2026 +0100
Add new hook method (#62822)
---
.../providers/databricks/hooks/databricks.py | 25 ++++++++++++++++++++++
.../providers/databricks/operators/databricks.py | 2 +-
.../tests/unit/databricks/hooks/test_databricks.py | 25 ++++++++++++++++++++++
3 files changed, 51 insertions(+), 1 deletion(-)
diff --git a/providers/databricks/src/airflow/providers/databricks/hooks/databricks.py b/providers/databricks/src/airflow/providers/databricks/hooks/databricks.py
index ff7e6150994..7bba7634552 100644
--- a/providers/databricks/src/airflow/providers/databricks/hooks/databricks.py
+++ b/providers/databricks/src/airflow/providers/databricks/hooks/databricks.py
@@ -535,6 +535,31 @@ class DatabricksHook(BaseDatabricksHook):
state = response["state"]
return RunState(**state)
+ def get_run_tasks(self, run_id: int) -> list[dict[str, Any]]:
+ """
+ Retrieve list of tasks performed by the run.
+
+ :param run_id: id of the run
+ :return: A list of tasks
+ """
+ has_more = True
+ all_tasks = []
+ page_token = ""
+ json: dict[str, Any] = {"run_id": run_id}
+
+ while has_more:
+ if page_token:
+ json = {**json, "page_token": page_token}
+ response = self._do_api_call(GET_RUN_ENDPOINT, json)
+ tasks = response.get("tasks", [])
+ all_tasks += tasks
+ if "next_page_token" in response:
+ page_token = response["next_page_token"]
+ else:
+ has_more = False
+
+ return all_tasks
+
def get_run(self, run_id: int) -> dict[str, Any]:
"""
Retrieve run information.
diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py
index 2d70602bbd5..95943e8b6b6 100644
--- a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py
+++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py
@@ -1329,7 +1329,7 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
"""Retrieve the Databricks task corresponding to the current Airflow task."""
if self.databricks_run_id is None:
raise ValueError("Databricks job not yet launched. Please run launch_notebook_job first.")
- tasks = self._hook.get_run(self.databricks_run_id)["tasks"]
+ tasks = self._hook.get_run_tasks(self.databricks_run_id)
# Because the task_key remains the same across multiple runs, and the Databricks API does not return
# tasks sorted by their attempts/start time, we sort the tasks by start time. This ensures that we
diff --git a/providers/databricks/tests/unit/databricks/hooks/test_databricks.py b/providers/databricks/tests/unit/databricks/hooks/test_databricks.py
index aab9dff3751..558e39a2129 100644
--- a/providers/databricks/tests/unit/databricks/hooks/test_databricks.py
+++ b/providers/databricks/tests/unit/databricks/hooks/test_databricks.py
@@ -62,6 +62,7 @@ CLUSTER_ID = "cluster_id"
RUN_ID = 1
JOB_ID = 42
JOB_NAME = "job-name"
+TASK_KEY = "task-key"
PIPELINE_NAME = "some pipeline name"
PIPELINE_ID = "its-a-pipeline-id"
STATEMENT_ID = "statement_id"
@@ -87,6 +88,7 @@ GET_RUN_RESPONSE = {
"job_id": JOB_ID,
"run_page_url": RUN_PAGE_URL,
"state": {"life_cycle_state": LIFE_CYCLE_STATE, "state_message": STATE_MESSAGE},
+ "tasks": [{"task_key": TASK_KEY}],
}
GET_RUN_OUTPUT_RESPONSE = {"metadata": {}, "error": ERROR_MESSAGE, "notebook_output": {}}
CLUSTER_STATE = "TERMINATED"
@@ -692,6 +694,29 @@ class TestDatabricksHook:
state_message = self.hook.get_run_state_message(RUN_ID)
assert state_message == STATE_MESSAGE
+ @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
+ def test_get_run_tasks_success_multiple_pages(self, mock_requests):
+ mock_requests.codes.ok = 200
+ mock_requests.get.side_effect = [
+ create_successful_response_mock({**GET_RUN_RESPONSE, "next_page_token": "PAGETOKEN"}),
+ create_successful_response_mock(GET_RUN_RESPONSE),
+ ]
+
+ tasks = self.hook.get_run_tasks(RUN_ID)
+
+ assert mock_requests.get.call_count == 2
+
+ first_call_args = mock_requests.method_calls[0]
+ assert first_call_args[1][0] == get_run_endpoint(HOST)
+ assert first_call_args[2]["params"] == {"run_id": RUN_ID}
+
+ second_call_args = mock_requests.method_calls[1]
+ assert second_call_args[1][0] == get_run_endpoint(HOST)
+ assert second_call_args[2]["params"] == {"run_id": RUN_ID, "page_token": "PAGETOKEN"}
+
+ assert len(tasks) == 2
+ assert tasks == GET_RUN_RESPONSE["tasks"] * 2
+
@mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
def test_cancel_run(self, mock_requests):
mock_requests.post.return_value.json.return_value = GET_RUN_RESPONSE