This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
     new b752f7e9c0f  fix for databricks repair buttons not overriding the template parameters (#46704)
b752f7e9c0f is described below
commit b752f7e9c0f5ab1bb19d8d0d7de4806afa48a7b7
Author: Geethanadh Padavala <[email protected]>
AuthorDate: Mon Apr 7 10:16:21 2025 +0100
    fix for databricks repair buttons not overriding the template parameters (#46704)
* fix for databricks repair buttons not overriding the template parameters
* add unit test
* update unit test to validate input json
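In short: before building the repair request, the plugin now fetches the original run and merges any parameter overrides into the payload. Below is a minimal sketch of the merge semantics only, not the provider code itself (that is in the diff which follows); the "notebook_params" shape is assumed from the Databricks Jobs API get-run response.

# Sketch of the merge the fix performs (assumed run shape; real values
# come from DatabricksHook.get_run, shown in the diff below).
run_data = {"overriding_parameters": {"notebook_params": {"env": "prod"}}}

repair_json = {
    "run_id": 42,
    "latest_repair_id": 7,
    "rerun_tasks": ["task1"],
    # dict unpacking lifts the override fields to the top level of the
    # repair request, so repaired tasks reuse the overridden parameters
    **run_data.get("overriding_parameters", {}),
}

assert repair_json["notebook_params"] == {"env": "prod"}

Because the overrides are unpacked at the top level, they land where the Databricks repair endpoint accepts parameter fields, rather than being nested under a key the endpoint would ignore.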
---
.../databricks/plugins/databricks_workflow.py | 2 ++
.../databricks/plugins/test_databricks_workflow.py | 30 ++++++++++++++++++++++
2 files changed, 32 insertions(+)
diff --git a/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py b/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py
index 85a884a53e5..1b6466f7678 100644
--- a/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py
+++ b/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py
@@ -149,10 +149,12 @@ def _repair_task(
         databricks_run_id,
     )
 
+    run_data = hook.get_run(databricks_run_id)
     repair_json = {
         "run_id": databricks_run_id,
         "latest_repair_id": repair_history_id,
         "rerun_tasks": tasks_to_repair,
+        **run_data.get("overriding_parameters", {}),
     }
 
     return hook.repair_run(repair_json)
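Note that run_data.get("overriding_parameters", {}) falls back to an empty dict, so repairs of runs that were started without overrides behave exactly as before. A minimal sketch under that assumption:

# A run without overrides merges nothing extra into the repair payload.
run_data = {}  # assumed: get_run() response lacking "overriding_parameters"

repair_json = {
    "run_id": 42,
    "latest_repair_id": None,
    "rerun_tasks": ["task1"],
    **run_data.get("overriding_parameters", {}),
}

assert set(repair_json) == {"run_id", "latest_repair_id", "rerun_tasks"}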
diff --git a/providers/databricks/tests/unit/databricks/plugins/test_databricks_workflow.py b/providers/databricks/tests/unit/databricks/plugins/test_databricks_workflow.py
index c41bd9690b1..628fa43f613 100644
--- a/providers/databricks/tests/unit/databricks/plugins/test_databricks_workflow.py
+++ b/providers/databricks/tests/unit/databricks/plugins/test_databricks_workflow.py
@@ -96,6 +96,36 @@ def test_repair_task(mock_databricks_hook):
     mock_hook_instance.repair_run.assert_called_once()
 
 
+@patch("airflow.providers.databricks.plugins.databricks_workflow.DatabricksHook")
+def test_repair_task_with_params(mock_databricks_hook):
+    mock_hook_instance = mock_databricks_hook.return_value
+    mock_hook_instance.get_latest_repair_id.return_value = 100
+    mock_hook_instance.repair_run.return_value = 200
+    mock_hook_instance.get_run.return_value = {
+        "overriding_parameters": {
+            "key1": "value1",
+            "key2": "value2",
+        }
+    }
+
+    tasks_to_repair = ["task1", "task2"]
+    result = _repair_task(DATABRICKS_CONN_ID, DATABRICKS_RUN_ID, tasks_to_repair, LOG)
+
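+    # The run's overriding_parameters are unpacked into the top level of the payload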
+    expected_payload = {
+        "run_id": DATABRICKS_RUN_ID,
+        "latest_repair_id": 100,
+        "rerun_tasks": tasks_to_repair,
+        "key1": "value1",
+        "key2": "value2",
+    }
+    assert result == 200
+
+    mock_hook_instance.get_latest_repair_id.assert_called_once_with(DATABRICKS_RUN_ID)
+    mock_hook_instance.get_run.assert_called_once_with(DATABRICKS_RUN_ID)
+    mock_hook_instance.repair_run.assert_called_once_with(expected_payload)
+
+
def test_get_launch_task_id_no_launch_task():
task_group = MagicMock(get_child_by_label=MagicMock(side_effect=KeyError))
task_group.parent_group = None