This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3eae5bca9b0 fix: Map Airflow trigger_rule to Databricks run_if in `DatabricksWorkflowTaskGroup` (#63420)
3eae5bca9b0 is described below

commit 3eae5bca9b0db744dd75f0c3a14180ae920f63e5
Author: Noritaka Sekiyama <[email protected]>
AuthorDate: Sat Mar 14 16:05:04 2026 +0800

    fix: Map Airflow trigger_rule to Databricks run_if in `DatabricksWorkflowTaskGroup` (#63420)
    
    * fix: Map Airflow trigger_rule to Databricks run_if in DatabricksWorkflowTaskGroup
    
    When using DatabricksWorkflowTaskGroup, the Airflow trigger_rule on
    DatabricksTaskOperator was not being propagated to the Databricks Jobs
    API. This caused tasks to execute in Databricks even when Airflow marked
    them as skipped due to upstream skip conditions.
    
    This fix maps Airflow TriggerRule values to the corresponding Databricks
    RunIf conditions in the workflow task JSON, keeping Airflow and
    Databricks task states in sync.
    
    closes: #47024
    
    * Add logging for trigger_rule to run_if mapping decisions
    
    Addresses review feedback to log which mapping decision was made,
    making it easier to debug issues with trigger rule propagation.
---
 .../providers/databricks/operators/databricks.py   | 34 +++++++++++++
 .../unit/databricks/operators/test_databricks.py   | 57 ++++++++++++++++++++++
 2 files changed, 91 insertions(+)

diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py
index 2ec8f03b4ed..e99750ee72e 100644
--- a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py
+++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py
@@ -67,6 +67,19 @@ XCOM_JOB_ID_KEY = "job_id"
 XCOM_RUN_PAGE_URL_KEY = "run_page_url"
 XCOM_STATEMENT_ID_KEY = "statement_id"
 
+# Mapping from Airflow TriggerRule to Databricks RunIf condition.
+# Only non-default rules are included; ALL_SUCCESS is the Databricks default
+# and is omitted to keep the task JSON minimal.
+_TRIGGER_RULE_TO_DATABRICKS_RUN_IF: dict[str, str] = {
+    "all_failed": "ALL_FAILED",
+    "all_done": "ALL_DONE",
+    "one_success": "AT_LEAST_ONE_SUCCESS",
+    "one_failed": "AT_LEAST_ONE_FAILED",
+    "none_failed": "NONE_FAILED",
+    "none_failed_min_one_success": "NONE_FAILED",
+    "always": "ALL_DONE",
+}
+
 
 def _handle_databricks_operator_execution(operator, hook, log, context) -> None:
     """
@@ -1356,6 +1369,7 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
     ) -> dict[str, object]:
         """Convert the operator to a Databricks workflow task that can be a task in a workflow."""
         base_task_json = self._get_task_base_json()
+
         result = {
             "task_key": self.databricks_task_key,
             "depends_on": [
@@ -1366,6 +1380,26 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
             **base_task_json,
         }
 
+        trigger_rule_value = (
+            self.trigger_rule.value if hasattr(self.trigger_rule, "value") else str(self.trigger_rule)
+        )
+        databricks_run_if = _TRIGGER_RULE_TO_DATABRICKS_RUN_IF.get(trigger_rule_value)
+        if databricks_run_if:
+            self.log.info(
+                "Mapping Airflow trigger_rule '%s' to Databricks run_if '%s' for task '%s'",
+                trigger_rule_value,
+                databricks_run_if,
+                self.task_id,
+            )
+            result["run_if"] = databricks_run_if
+        else:
+            self.log.info(
+                "No Databricks run_if mapping for Airflow trigger_rule '%s' on task '%s'; "
+                "using Databricks default (ALL_SUCCESS)",
+                trigger_rule_value,
+                self.task_id,
+            )
+
         if self.existing_cluster_id and self.job_cluster_key:
             raise ValueError(
                 "Both existing_cluster_id and job_cluster_key are set. Only one can be set per task."
diff --git a/providers/databricks/tests/unit/databricks/operators/test_databricks.py b/providers/databricks/tests/unit/databricks/operators/test_databricks.py
index 82888fc461b..e4ee05f36ef 100644
--- a/providers/databricks/tests/unit/databricks/operators/test_databricks.py
+++ b/providers/databricks/tests/unit/databricks/operators/test_databricks.py
@@ -2561,6 +2561,63 @@ class TestDatabricksNotebookOperator:
 
         assert task_json == expected_json
 
+    @pytest.mark.parametrize(
+        ("trigger_rule", "expected_run_if"),
+        [
+            ("all_failed", "ALL_FAILED"),
+            ("all_done", "ALL_DONE"),
+            ("one_success", "AT_LEAST_ONE_SUCCESS"),
+            ("one_failed", "AT_LEAST_ONE_FAILED"),
+            ("none_failed", "NONE_FAILED"),
+            ("none_failed_min_one_success", "NONE_FAILED"),
+            ("always", "ALL_DONE"),
+        ],
+    )
+    def test_convert_to_databricks_workflow_task_with_trigger_rule(self, trigger_rule, expected_run_if):
+        """Test that trigger_rule is mapped to Databricks run_if in the workflow task JSON."""
+        dag = DAG(dag_id="example_dag", schedule=None, start_date=datetime.now())
+        operator = DatabricksNotebookOperator(
+            notebook_path="/path/to/notebook",
+            source="WORKSPACE",
+            task_id="test_task",
+            trigger_rule=trigger_rule,
+            dag=dag,
+        )
+
+        databricks_workflow_task_group = MagicMock()
+        databricks_workflow_task_group.notebook_packages = []
+        databricks_workflow_task_group.notebook_params = {}
+
+        operator.task_group = databricks_workflow_task_group
+        relevant_upstreams = []
+        task_dict = {}
+
+        task_json = operator._convert_to_databricks_workflow_task(relevant_upstreams, task_dict)
+
+        assert task_json["run_if"] == expected_run_if
+
+    def test_convert_to_databricks_workflow_task_default_trigger_rule_no_run_if(self):
+        """Test that the default trigger_rule (all_success) does not add run_if to the task JSON."""
+        dag = DAG(dag_id="example_dag", schedule=None, start_date=datetime.now())
+        operator = DatabricksNotebookOperator(
+            notebook_path="/path/to/notebook",
+            source="WORKSPACE",
+            task_id="test_task",
+            dag=dag,
+        )
+
+        databricks_workflow_task_group = MagicMock()
+        databricks_workflow_task_group.notebook_packages = []
+        databricks_workflow_task_group.notebook_params = {}
+
+        operator.task_group = databricks_workflow_task_group
+        relevant_upstreams = []
+        task_dict = {}
+
+        task_json = operator._convert_to_databricks_workflow_task(relevant_upstreams, task_dict)
+
+        assert "run_if" not in task_json
+
     def test_convert_to_databricks_workflow_task_no_task_group(self):
         """Test that an error is raised if the operator is not in a TaskGroup."""
         operator = DatabricksNotebookOperator(

Reply via email to