This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 3eae5bca9b0 fix: Map Airflow trigger_rule to Databricks run_if in
`DatabricksWorkflowTaskGroup` (#63420)
3eae5bca9b0 is described below
commit 3eae5bca9b0db744dd75f0c3a14180ae920f63e5
Author: Noritaka Sekiyama <[email protected]>
AuthorDate: Sat Mar 14 16:05:04 2026 +0800
fix: Map Airflow trigger_rule to Databricks run_if in
`DatabricksWorkflowTaskGroup` (#63420)
* fix: Map Airflow trigger_rule to Databricks run_if in
DatabricksWorkflowTaskGroup
When using DatabricksWorkflowTaskGroup, the Airflow trigger_rule on
DatabricksTaskOperator was not being propagated to the Databricks Jobs
API. This caused tasks to execute in Databricks even when Airflow marked
them as skipped due to upstream skip conditions.
This fix maps Airflow TriggerRule values to the corresponding Databricks
RunIf conditions in the workflow task JSON, keeping Airflow and
Databricks task states in sync.
closes: #47024
* Add logging for trigger_rule to run_if mapping decisions
Addresses review feedback to log which mapping decision was made,
making it easier to debug issues with trigger rule propagation.
---
.../providers/databricks/operators/databricks.py | 34 +++++++++++++
.../unit/databricks/operators/test_databricks.py | 57 ++++++++++++++++++++++
2 files changed, 91 insertions(+)
diff --git
a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py
b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py
index 2ec8f03b4ed..e99750ee72e 100644
---
a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py
+++
b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py
@@ -67,6 +67,19 @@ XCOM_JOB_ID_KEY = "job_id"
XCOM_RUN_PAGE_URL_KEY = "run_page_url"
XCOM_STATEMENT_ID_KEY = "statement_id"
+# Mapping from Airflow TriggerRule to Databricks RunIf condition.
+# Only non-default rules are included; ALL_SUCCESS is the Databricks default
+# and is omitted to keep the task JSON minimal.
+_TRIGGER_RULE_TO_DATABRICKS_RUN_IF: dict[str, str] = {
+ "all_failed": "ALL_FAILED",
+ "all_done": "ALL_DONE",
+ "one_success": "AT_LEAST_ONE_SUCCESS",
+ "one_failed": "AT_LEAST_ONE_FAILED",
+ "none_failed": "NONE_FAILED",
+ "none_failed_min_one_success": "NONE_FAILED",
+ "always": "ALL_DONE",
+}
+
def _handle_databricks_operator_execution(operator, hook, log, context) -> None:
"""
@@ -1356,6 +1369,7 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
) -> dict[str, object]:
        """Convert the operator to a Databricks workflow task that can be a task in a workflow."""
base_task_json = self._get_task_base_json()
+
result = {
"task_key": self.databricks_task_key,
"depends_on": [
@@ -1366,6 +1380,26 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
**base_task_json,
}
+    trigger_rule_value = (
+        self.trigger_rule.value if hasattr(self.trigger_rule, "value") else str(self.trigger_rule)
+    )
+    databricks_run_if = _TRIGGER_RULE_TO_DATABRICKS_RUN_IF.get(trigger_rule_value)
+    if databricks_run_if:
+        self.log.info(
+            "Mapping Airflow trigger_rule '%s' to Databricks run_if '%s' for task '%s'",
+            trigger_rule_value,
+            databricks_run_if,
+            self.task_id,
+        )
+        result["run_if"] = databricks_run_if
+    else:
+        self.log.info(
+            "No Databricks run_if mapping for Airflow trigger_rule '%s' on task '%s'; "
+            "using Databricks default (ALL_SUCCESS)",
+            trigger_rule_value,
+            self.task_id,
+        )
+
if self.existing_cluster_id and self.job_cluster_key:
raise ValueError(
                "Both existing_cluster_id and job_cluster_key are set. Only one can be set per task."
diff --git
a/providers/databricks/tests/unit/databricks/operators/test_databricks.py
b/providers/databricks/tests/unit/databricks/operators/test_databricks.py
index 82888fc461b..e4ee05f36ef 100644
--- a/providers/databricks/tests/unit/databricks/operators/test_databricks.py
+++ b/providers/databricks/tests/unit/databricks/operators/test_databricks.py
@@ -2561,6 +2561,63 @@ class TestDatabricksNotebookOperator:
assert task_json == expected_json
+ @pytest.mark.parametrize(
+ ("trigger_rule", "expected_run_if"),
+ [
+ ("all_failed", "ALL_FAILED"),
+ ("all_done", "ALL_DONE"),
+ ("one_success", "AT_LEAST_ONE_SUCCESS"),
+ ("one_failed", "AT_LEAST_ONE_FAILED"),
+ ("none_failed", "NONE_FAILED"),
+ ("none_failed_min_one_success", "NONE_FAILED"),
+ ("always", "ALL_DONE"),
+ ],
+ )
+    def test_convert_to_databricks_workflow_task_with_trigger_rule(self, trigger_rule, expected_run_if):
+        """Test that trigger_rule is mapped to Databricks run_if in the workflow task JSON."""
+        dag = DAG(dag_id="example_dag", schedule=None, start_date=datetime.now())
+ operator = DatabricksNotebookOperator(
+ notebook_path="/path/to/notebook",
+ source="WORKSPACE",
+ task_id="test_task",
+ trigger_rule=trigger_rule,
+ dag=dag,
+ )
+
+ databricks_workflow_task_group = MagicMock()
+ databricks_workflow_task_group.notebook_packages = []
+ databricks_workflow_task_group.notebook_params = {}
+
+ operator.task_group = databricks_workflow_task_group
+ relevant_upstreams = []
+ task_dict = {}
+
+        task_json = operator._convert_to_databricks_workflow_task(relevant_upstreams, task_dict)
+
+ assert task_json["run_if"] == expected_run_if
+
+    def test_convert_to_databricks_workflow_task_default_trigger_rule_no_run_if(self):
+        """Test that the default trigger_rule (all_success) does not add run_if to the task JSON."""
+        dag = DAG(dag_id="example_dag", schedule=None, start_date=datetime.now())
+ operator = DatabricksNotebookOperator(
+ notebook_path="/path/to/notebook",
+ source="WORKSPACE",
+ task_id="test_task",
+ dag=dag,
+ )
+
+ databricks_workflow_task_group = MagicMock()
+ databricks_workflow_task_group.notebook_packages = []
+ databricks_workflow_task_group.notebook_params = {}
+
+ operator.task_group = databricks_workflow_task_group
+ relevant_upstreams = []
+ task_dict = {}
+
+        task_json = operator._convert_to_databricks_workflow_task(relevant_upstreams, task_dict)
+
+ assert "run_if" not in task_json
+
def test_convert_to_databricks_workflow_task_no_task_group(self):
        """Test that an error is raised if the operator is not in a TaskGroup."""
operator = DatabricksNotebookOperator(