Copilot commented on code in PR #64545:
URL: https://github.com/apache/airflow/pull/64545#discussion_r3025328220


##########
providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_condition.py:
##########
@@ -0,0 +1,223 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock
+
+import pytest
+
+from airflow.providers.amazon.aws.operators.sagemaker import 
SageMakerConditionOperator
+
+from unit.amazon.aws.utils.test_template_fields import validate_template_fields
+
+
+def _choose(conditions, if_ids="if_task", else_ids="else_task"):
+    """Instantiate with conditions list and call choose_branch."""
+    op = SageMakerConditionOperator(
+        task_id="test",
+        conditions=conditions,
+        if_task_ids=if_ids,
+        else_task_ids=else_ids,
+    )
+    return op.choose_branch(context=MagicMock(spec=dict))
+
+
+def _choose_flat(condition_type, left, right, if_ids="if_task", 
else_ids="else_task"):
+    """Instantiate with flat params and call choose_branch."""
+    op = SageMakerConditionOperator(
+        task_id="test",
+        condition_type=condition_type,
+        left_value=left,
+        right_value=right,
+        if_task_ids=if_ids,
+        else_task_ids=else_ids,
+    )
+    return op.choose_branch(context=MagicMock(spec=dict))
+
+
+def test_template_fields():
+    op = SageMakerConditionOperator(
+        task_id="test",
+        conditions=[{"type": "Equals", "left_value": 1, "right_value": 1}],
+        if_task_ids=["a"],
+        else_task_ids=["b"],
+    )
+    validate_template_fields(op)
+
+
+class TestConditionTypes:
+    """One true + one false per condition type, plus logical combinators."""
+
+    @pytest.mark.parametrize(
+        ("cond_type", "left", "right", "expected"),
+        [
+            ("Equals", 1, 1, "if"),
+            ("Equals", 1, 2, "else"),
+            ("GreaterThan", 5, 3, "if"),
+            ("GreaterThan", 3, 5, "else"),
+            ("GreaterThanOrEqualTo", 3, 3, "if"),
+            ("GreaterThanOrEqualTo", 2, 3, "else"),
+            ("LessThan", 3, 5, "if"),
+            ("LessThan", 5, 3, "else"),
+            ("LessThanOrEqualTo", 3, 3, "if"),
+            ("LessThanOrEqualTo", 5, 3, "else"),
+        ],
+    )
+    def test_comparison(self, cond_type, left, right, expected):
+        result = _choose([{"type": cond_type, "left_value": left, 
"right_value": right}])
+        assert result == (["if_task"] if expected == "if" else ["else_task"])
+
+    def test_in_true(self):
+        assert _choose([{"type": "In", "value": 1, "in_values": [1, 2]}]) == 
["if_task"]
+
+    def test_in_false(self):
+        assert _choose([{"type": "In", "value": 4, "in_values": [1, 2]}]) == 
["else_task"]
+
+    def test_not_negates(self):
+        cond = [{"type": "Not", "condition": {"type": "Equals", "left_value": 
1, "right_value": 1}}]
+        assert _choose(cond) == ["else_task"]
+
+    def test_or_any_true(self):
+        cond = [
+            {
+                "type": "Or",
+                "conditions": [
+                    {"type": "Equals", "left_value": 1, "right_value": 2},
+                    {"type": "Equals", "left_value": 1, "right_value": 1},
+                ],
+            }
+        ]
+        assert _choose(cond) == ["if_task"]
+
+
+class TestAndSemantics:
+    def test_multiple_conditions_and(self):
+        """All true -> if, one false -> else."""
+        assert _choose(
+            [
+                {"type": "Equals", "left_value": 1, "right_value": 1},
+                {"type": "GreaterThan", "left_value": 5, "right_value": 3},
+            ]
+        ) == ["if_task"]
+        assert _choose(
+            [
+                {"type": "Equals", "left_value": 1, "right_value": 1},
+                {"type": "GreaterThan", "left_value": 2, "right_value": 10},
+            ]
+        ) == ["else_task"]
+
+
+class TestValueCasting:
+    def test_cast_numeric_and_passthrough(self):
+        """int string, float string, bool string, non-numeric string, 
non-string type."""
+        assert SageMakerConditionOperator._cast("42") == 42
+        assert SageMakerConditionOperator._cast("0.9") == 0.9
+        assert SageMakerConditionOperator._cast("true") is True
+        assert SageMakerConditionOperator._cast("us-east-1") == "us-east-1"
+        assert SageMakerConditionOperator._cast(42) == 42
+
+
+class TestValidation:
+    def test_empty_conditions_raises(self):
+        with pytest.raises(ValueError, match="At least 1 condition is 
required"):
+            SageMakerConditionOperator(task_id="t", conditions=[], 
if_task_ids="a", else_task_ids="b")
+
+    def test_unknown_type_raises(self):
+        with pytest.raises(ValueError, match="Unknown condition type"):
+            _choose([{"type": "FooBar", "left_value": 1, "right_value": 2}])
+
+    def test_type_mismatch_raises(self):
+        with pytest.raises(TypeError, match="Cannot compare"):
+            _choose([{"type": "GreaterThanOrEqualTo", "left_value": "hello", 
"right_value": 0.9}])
+
+    def test_none_operand_raises(self):
+        """Covers both Python None and Jinja-rendered string 'None'."""
+        with pytest.raises(TypeError, match="received None"):
+            _choose([{"type": "Equals", "left_value": None, "right_value": 1}])
+        with pytest.raises(TypeError, match="received None"):
+            _choose([{"type": "GreaterThanOrEqualTo", "left_value": "None", 
"right_value": 0.9}])
+
+    @pytest.mark.parametrize(
+        ("condition", "match_pattern"),
+        [
+            ({}, "missing required key 'type'"),
+            ({"type": "Equals", "left_value": 1}, "missing required key"),
+            ({"type": "Not"}, "missing required key"),
+            ({"type": "Or"}, "missing required key"),
+        ],
+    )
+    def test_missing_key_raises(self, condition, match_pattern):
+        with pytest.raises(ValueError, match=match_pattern):
+            _choose([condition])
+
+
+class TestFlatInterface:
+    def test_flat_condition(self):
+        """Flat params work for comparison and In types."""
+        assert _choose_flat("Equals", 1, 1) == ["if_task"]
+        assert _choose_flat("Equals", 1, 2) == ["else_task"]
+        assert _choose_flat("In", 1, [1, 2, 3]) == ["if_task"]
+        assert _choose_flat("In", 99, [1, 2, 3]) == ["else_task"]
+
+    def test_invalid_condition_type_raises(self):
+        with pytest.raises(ValueError, match="Unknown condition_type"):
+            SageMakerConditionOperator(
+                task_id="t",
+                condition_type="NotEquals",
+                left_value=1,
+                right_value=1,
+                if_task_ids="a",
+                else_task_ids="b",
+            )
+
+    def test_mutual_exclusion_raises(self):
+        with pytest.raises(ValueError, match="Cannot use 'condition_type' and 
'conditions' together"):
+            SageMakerConditionOperator(
+                task_id="t",
+                condition_type="Equals",
+                left_value=1,
+                right_value=1,
+                conditions=[{"type": "Equals", "left_value": 1, "right_value": 
1}],
+                if_task_ids="a",
+                else_task_ids="b",
+            )
+
+    def test_neither_provided_raises(self):
+        with pytest.raises(ValueError, match="Missing condition"):
+            SageMakerConditionOperator(task_id="t", if_task_ids="a", 
else_task_ids="b")
+
+    def test_optional_else_task_ids(self):
+        """else_task_ids defaults to empty list when omitted."""
+        op = SageMakerConditionOperator(
+            task_id="t",
+            conditions=[{"type": "Equals", "left_value": 1, "right_value": 1}],
+            if_task_ids=["deploy"],
+        )
+        assert op.else_task_ids == []
+
+    def test_no_else_branch_raises_on_false(self):
+        """When else_task_ids is empty and conditions are false, raises 
AirflowFailException."""
+        from airflow.providers.common.compat.sdk import AirflowFailException
+

Review Comment:
   Importing `AirflowFailException` inside the test function is inconsistent 
with the top-of-file import style used elsewhere in these unit tests, and it 
hides the dependency from linters and type checkers that scan module-level 
imports. Consider moving this import to the module import section.



##########
providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -1991,3 +1996,262 @@ def execute(self, context):
             self.hook.conn.get_waiter("notebook_instance_in_service").wait(
                 NotebookInstanceName=self.instance_name
             )
+
+
+class SageMakerConditionOperator(BaseBranchOperator):
+    """
+    Evaluates a single condition or a list of conditions, and routes tasks 
based on the result.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the 
guide:
+        :ref:`howto/operator:SageMakerConditionOperator`
+
+    :param condition_type: Condition type for the simple (flat) interface.
+        Valid types: ``Equals``, ``GreaterThan``, ``GreaterThanOrEqualTo``,
+        ``LessThan``, ``LessThanOrEqualTo``, ``In``.
+        Mutually exclusive with ``conditions``.
+    :param left_value: Left operand for the flat interface. For ``In`` 
conditions
+        this is the value to check membership of.
+    :param right_value: Right operand for the flat interface. For ``In`` 
conditions
+        this is the list of allowed values.
+    :param conditions: List of condition dicts to evaluate (AND-ed together).
+        Each dict must have a ``type`` key. Minimum 1, maximum 200.
+        Mutually exclusive with 
``condition_type``/``left_value``/``right_value``.
+    :param if_task_ids: Task ID(s) to execute when all conditions are True.
+    :param else_task_ids: Task ID(s) to execute when any condition is False.
+        If omitted, the task fails with ``AirflowFailException`` when 
conditions are not met.
+    """
+
+    _VALID_FLAT_TYPES: ClassVar[set[str]] = {
+        "Equals",
+        "GreaterThan",
+        "GreaterThanOrEqualTo",
+        "LessThan",
+        "LessThanOrEqualTo",
+        "In",
+    }
+
+    template_fields: Sequence[str] = (
+        "condition_type",
+        "left_value",
+        "right_value",
+        "conditions",
+        "if_task_ids",
+        "else_task_ids",
+    )
+
+    def __init__(
+        self,
+        *,
+        condition_type: str | None = None,
+        left_value: Any = None,
+        right_value: Any = None,
+        conditions: list[dict] | None = None,
+        if_task_ids: str | list[str],
+        else_task_ids: str | list[str] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        has_flat = condition_type is not None
+        has_list = conditions is not None
+
+        if has_flat and has_list:
+            raise ValueError(
+                "Cannot use 'condition_type' and 'conditions' together. "
+                "Use 'condition_type' with 'left_value'/'right_value' for a 
single condition, "
+                "or 'conditions' for multiple/nested conditions."
+            )
+        if not has_flat and not has_list:
+            raise ValueError(
+                "Missing condition: provide 'condition_type' with 
'left_value'/'right_value' "
+                "for a single condition, or 'conditions' for multiple/nested 
conditions."
+            )
+
+        if has_flat:
+            if condition_type not in self._VALID_FLAT_TYPES:
+                raise ValueError(
+                    f"Unknown condition_type '{condition_type}'. "
+                    f"Expected one of: {', 
'.join(sorted(self._VALID_FLAT_TYPES))}."
+                )
+            self.condition_type: str | None = condition_type
+            self.left_value = left_value
+            self.right_value = right_value
+            if condition_type == "In":
+                self.conditions: list[dict[str, Any]] = [
+                    {"type": "In", "value": left_value, "in_values": 
right_value}
+                ]
+            else:
+                self.conditions = [
+                    {"type": condition_type, "left_value": left_value, 
"right_value": right_value}
+                ]
+        else:
+            self.condition_type = None
+            self.left_value = None
+            self.right_value = None
+            self.conditions = conditions  # type: ignore[assignment]
+
+        if not self.conditions:
+            raise ValueError("At least 1 condition is required, but got an 
empty list.")
+        self.if_task_ids = [if_task_ids] if isinstance(if_task_ids, str) else 
if_task_ids
+        self.else_task_ids = [else_task_ids] if isinstance(else_task_ids, str) 
else (else_task_ids or [])
+
+    @staticmethod
+    def _cast(value: Any) -> Any:
+        """
+        Cast Jinja-rendered string values to appropriate Python types.
+
+        This is a compatibility shim for environments where
+        ``render_template_as_native_obj=True`` is not available at the DAG or
+        task level (e.g., YAML DAGs). Once task-level native rendering
+        is widely supported, this method can be removed in favor of letting
+        Airflow handle the casting natively.
+
+        - Numeric strings become int or float.
+        - ``"true"``/``"false"`` become booleans.
+        - ``"None"`` becomes ``None`` (common when ``xcom_pull`` returns 
nothing).
+        - Other strings are returned unchanged.
+        - Non-string types pass through as-is.
+        """
+        if not isinstance(value, str):
+            return value
+        if value == "None":
+            return None
+        try:
+            return int(value)
+        except (ValueError, TypeError):
+            pass
+        try:
+            return float(value)
+        except (ValueError, TypeError):
+            pass
+        if value.lower() == "true":
+            return True
+        if value.lower() == "false":
+            return False
+        return value
+
+    _COMPARISON_OPERATORS: ClassVar[dict[str, Callable[[Any, Any], bool]]] = {
+        "Equals": lambda left, right: left == right,
+        "GreaterThan": lambda left, right: left > right,
+        "GreaterThanOrEqualTo": lambda left, right: left >= right,
+        "LessThan": lambda left, right: left < right,
+        "LessThanOrEqualTo": lambda left, right: left <= right,
+    }
+
+    def _evaluate(self, condition: dict, depth: int = 0) -> bool:
+        """
+        Recursively evaluate a single condition dict.
+
+        :param condition: A condition dictionary with a ``type`` key.
+        :param depth: Current nesting depth (used for log indentation only).
+        :returns: Boolean result of the condition evaluation.
+        """
+        log_indent = "  " * depth
+        try:
+            condition_type = condition["type"]
+        except KeyError:
+            raise ValueError("Condition dict is missing required key 'type'.")
+
+        if condition_type in self._COMPARISON_OPERATORS:
+            try:
+                left = self._cast(condition["left_value"])
+                right = self._cast(condition["right_value"])
+            except KeyError as e:
+                raise ValueError(f"Condition '{condition_type}' missing 
required key {e}.") from None
+
+            # None check — likely an XCom that was not pushed
+            if left is None or right is None:
+                raise TypeError(
+                    f"Condition '{condition_type}' received None: 
left={left!r}, right={right!r}. "
+                    "This usually means the upstream task did not run or did 
not push a value to XCom."
+                )
+
+            # Type compatibility check
+            left_type = type(left)
+            right_type = type(right)
+            numeric_types = (int, float)
+            left_is_numeric = isinstance(left, numeric_types) and not 
isinstance(left, bool)
+            right_is_numeric = isinstance(right, numeric_types) and not 
isinstance(right, bool)
+
+            if not (left_is_numeric and right_is_numeric) and left_type is not 
right_type:
+                raise TypeError(
+                    f"Cannot compare {left_type.__name__} ({left!r}) with 
{right_type.__name__} ({right!r}) "
+                    f"in condition '{condition_type}'. Both values must be the 
same type."
+                )
+
+            comparison_result = 
self._COMPARISON_OPERATORS[condition_type](left, right)
+            self.log.info(
+                "%s%s: %r %s %r -> %s",
+                log_indent,
+                condition_type,
+                left,
+                condition_type,
+                right,
+                comparison_result,
+            )
+            return comparison_result
+
+        if condition_type == "In":
+            try:
+                query_value = self._cast(condition["value"])
+                allowed_values = [self._cast(val) for val in 
condition["in_values"]]
+            except KeyError as e:
+                raise ValueError(f"Condition '{condition_type}' missing 
required key {e}.") from None
+            membership_result = query_value in allowed_values
+            self.log.info("%sIn: %r in %r -> %s", log_indent, query_value, 
allowed_values, membership_result)
+            return membership_result
+
+        if condition_type == "Not":
+            try:
+                inner_condition = condition["condition"]
+            except KeyError as e:
+                raise ValueError(f"Condition '{condition_type}' missing 
required key {e}.") from None
+            inner_result = self._evaluate(inner_condition, depth + 1)
+            negated_result = not inner_result
+            self.log.info("%sNot: not %s -> %s", log_indent, inner_result, 
negated_result)
+            return negated_result
+
+        if condition_type == "Or":
+            try:
+                inner_conditions = condition["conditions"]
+            except KeyError as e:
+                raise ValueError(f"Condition '{condition_type}' missing 
required key {e}.") from None
+            inner_results = [self._evaluate(inner_cond, depth + 1) for 
inner_cond in inner_conditions]
+            or_result = any(inner_results)
+            self.log.info("%sOr: any(%r) -> %s", log_indent, inner_results, 
or_result)
+            return or_result

Review Comment:
   `Or` evaluation currently evaluates *all* nested conditions (`inner_results 
= [...]`) rather than short-circuiting. The boolean result is the same, but 
evaluating every branch means the task can fail even when an earlier condition 
has already made the overall `Or` true (e.g., a later nested condition has a 
missing key or a type mismatch, both of which raise). Consider iterating and 
returning `True` on the first true condition (optionally logging the results 
evaluated so far) to match the short-circuit behavior of Python's `or`/`any()`.
   ```suggestion
               inner_results: list[bool] = []
               for inner_cond in inner_conditions:
                   result = self._evaluate(inner_cond, depth + 1)
                   inner_results.append(result)
                   if result:
                       self.log.info("%sOr: any(%r) -> True", log_indent, inner_results)
                       return True
               self.log.info("%sOr: any(%r) -> False", log_indent, inner_results)
               return False
   ```



##########
providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -1991,3 +1996,262 @@ def execute(self, context):
             self.hook.conn.get_waiter("notebook_instance_in_service").wait(
                 NotebookInstanceName=self.instance_name
             )
+
+
+class SageMakerConditionOperator(BaseBranchOperator):
+    """
+    Evaluates a single condition or a list of conditions, and routes tasks 
based on the result.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the 
guide:
+        :ref:`howto/operator:SageMakerConditionOperator`
+
+    :param condition_type: Condition type for the simple (flat) interface.
+        Valid types: ``Equals``, ``GreaterThan``, ``GreaterThanOrEqualTo``,
+        ``LessThan``, ``LessThanOrEqualTo``, ``In``.
+        Mutually exclusive with ``conditions``.
+    :param left_value: Left operand for the flat interface. For ``In`` 
conditions
+        this is the value to check membership of.
+    :param right_value: Right operand for the flat interface. For ``In`` 
conditions
+        this is the list of allowed values.
+    :param conditions: List of condition dicts to evaluate (AND-ed together).
+        Each dict must have a ``type`` key. Minimum 1, maximum 200.

Review Comment:
   The docstring states the `conditions` list is "maximum 200", but `__init__` 
only validates non-empty and does not enforce an upper bound. Either add a 
length check (and a clear error) or remove the documented maximum to avoid 
misleading users.
   ```suggestion
           Each dict must have a ``type`` key. Minimum 1.
   ```



##########
providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -1991,3 +1996,262 @@ def execute(self, context):
             self.hook.conn.get_waiter("notebook_instance_in_service").wait(
                 NotebookInstanceName=self.instance_name
             )
+
+
+class SageMakerConditionOperator(BaseBranchOperator):
+    """
+    Evaluates a single condition or a list of conditions, and routes tasks 
based on the result.
+

Review Comment:
   PR description/title mentions adding a `SageMakerFailOperator`, but this 
change set only introduces `SageMakerConditionOperator` (plus its tests/docs). 
If `SageMakerFailOperator` is still intended, it looks missing from the 
implementation/tests/docs; otherwise, the PR description/title should be 
updated to match the actual changes.



##########
providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -1991,3 +1996,262 @@ def execute(self, context):
             self.hook.conn.get_waiter("notebook_instance_in_service").wait(
                 NotebookInstanceName=self.instance_name
             )
+
+
+class SageMakerConditionOperator(BaseBranchOperator):
+    """
+    Evaluates a single condition or a list of conditions, and routes tasks 
based on the result.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the 
guide:
+        :ref:`howto/operator:SageMakerConditionOperator`
+
+    :param condition_type: Condition type for the simple (flat) interface.
+        Valid types: ``Equals``, ``GreaterThan``, ``GreaterThanOrEqualTo``,
+        ``LessThan``, ``LessThanOrEqualTo``, ``In``.
+        Mutually exclusive with ``conditions``.
+    :param left_value: Left operand for the flat interface. For ``In`` 
conditions
+        this is the value to check membership of.
+    :param right_value: Right operand for the flat interface. For ``In`` 
conditions
+        this is the list of allowed values.
+    :param conditions: List of condition dicts to evaluate (AND-ed together).
+        Each dict must have a ``type`` key. Minimum 1, maximum 200.
+        Mutually exclusive with 
``condition_type``/``left_value``/``right_value``.
+    :param if_task_ids: Task ID(s) to execute when all conditions are True.
+    :param else_task_ids: Task ID(s) to execute when any condition is False.
+        If omitted, the task fails with ``AirflowFailException`` when 
conditions are not met.
+    """
+
+    _VALID_FLAT_TYPES: ClassVar[set[str]] = {
+        "Equals",
+        "GreaterThan",
+        "GreaterThanOrEqualTo",
+        "LessThan",
+        "LessThanOrEqualTo",
+        "In",
+    }
+
+    template_fields: Sequence[str] = (
+        "condition_type",
+        "left_value",
+        "right_value",
+        "conditions",
+        "if_task_ids",
+        "else_task_ids",
+    )
+
+    def __init__(
+        self,
+        *,
+        condition_type: str | None = None,
+        left_value: Any = None,
+        right_value: Any = None,
+        conditions: list[dict] | None = None,
+        if_task_ids: str | list[str],
+        else_task_ids: str | list[str] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        has_flat = condition_type is not None
+        has_list = conditions is not None
+
+        if has_flat and has_list:
+            raise ValueError(
+                "Cannot use 'condition_type' and 'conditions' together. "
+                "Use 'condition_type' with 'left_value'/'right_value' for a 
single condition, "
+                "or 'conditions' for multiple/nested conditions."
+            )
+        if not has_flat and not has_list:
+            raise ValueError(
+                "Missing condition: provide 'condition_type' with 
'left_value'/'right_value' "
+                "for a single condition, or 'conditions' for multiple/nested 
conditions."
+            )
+
+        if has_flat:
+            if condition_type not in self._VALID_FLAT_TYPES:
+                raise ValueError(
+                    f"Unknown condition_type '{condition_type}'. "
+                    f"Expected one of: {', 
'.join(sorted(self._VALID_FLAT_TYPES))}."
+                )
+            self.condition_type: str | None = condition_type
+            self.left_value = left_value
+            self.right_value = right_value
+            if condition_type == "In":
+                self.conditions: list[dict[str, Any]] = [
+                    {"type": "In", "value": left_value, "in_values": 
right_value}
+                ]
+            else:
+                self.conditions = [
+                    {"type": condition_type, "left_value": left_value, 
"right_value": right_value}
+                ]
+        else:
+            self.condition_type = None
+            self.left_value = None
+            self.right_value = None
+            self.conditions = conditions  # type: ignore[assignment]
+
+        if not self.conditions:
+            raise ValueError("At least 1 condition is required, but got an 
empty list.")
+        self.if_task_ids = [if_task_ids] if isinstance(if_task_ids, str) else 
if_task_ids
+        self.else_task_ids = [else_task_ids] if isinstance(else_task_ids, str) 
else (else_task_ids or [])
+
+    @staticmethod
+    def _cast(value: Any) -> Any:
+        """
+        Cast Jinja-rendered string values to appropriate Python types.
+
+        This is a compatibility shim for environments where
+        ``render_template_as_native_obj=True`` is not available at the DAG or
+        task level (e.g., YAML DAGs). Once task-level native rendering
+        is widely supported, this method can be removed in favor of letting
+        Airflow handle the casting natively.
+
+        - Numeric strings become int or float.
+        - ``"true"``/``"false"`` become booleans.
+        - ``"None"`` becomes ``None`` (common when ``xcom_pull`` returns 
nothing).
+        - Other strings are returned unchanged.
+        - Non-string types pass through as-is.
+        """
+        if not isinstance(value, str):
+            return value
+        if value == "None":
+            return None
+        try:
+            return int(value)
+        except (ValueError, TypeError):
+            pass
+        try:
+            return float(value)
+        except (ValueError, TypeError):
+            pass
+        if value.lower() == "true":
+            return True
+        if value.lower() == "false":
+            return False
+        return value
+
+    _COMPARISON_OPERATORS: ClassVar[dict[str, Callable[[Any, Any], bool]]] = {
+        "Equals": lambda left, right: left == right,
+        "GreaterThan": lambda left, right: left > right,
+        "GreaterThanOrEqualTo": lambda left, right: left >= right,
+        "LessThan": lambda left, right: left < right,
+        "LessThanOrEqualTo": lambda left, right: left <= right,
+    }
+
+    def _evaluate(self, condition: dict, depth: int = 0) -> bool:
+        """
+        Recursively evaluate a single condition dict.
+
+        :param condition: A condition dictionary with a ``type`` key.
+        :param depth: Current nesting depth (used for log indentation only).
+        :returns: Boolean result of the condition evaluation.
+        """
+        log_indent = "  " * depth
+        try:
+            condition_type = condition["type"]
+        except KeyError:
+            raise ValueError("Condition dict is missing required key 'type'.")
+
+        if condition_type in self._COMPARISON_OPERATORS:
+            try:
+                left = self._cast(condition["left_value"])
+                right = self._cast(condition["right_value"])
+            except KeyError as e:
+                raise ValueError(f"Condition '{condition_type}' missing 
required key {e}.") from None
+
+            # None check — likely an XCom that was not pushed
+            if left is None or right is None:
+                raise TypeError(
+                    f"Condition '{condition_type}' received None: 
left={left!r}, right={right!r}. "
+                    "This usually means the upstream task did not run or did 
not push a value to XCom."
+                )
+
+            # Type compatibility check
+            left_type = type(left)
+            right_type = type(right)
+            numeric_types = (int, float)
+            left_is_numeric = isinstance(left, numeric_types) and not 
isinstance(left, bool)
+            right_is_numeric = isinstance(right, numeric_types) and not 
isinstance(right, bool)
+
+            if not (left_is_numeric and right_is_numeric) and left_type is not 
right_type:
+                raise TypeError(
+                    f"Cannot compare {left_type.__name__} ({left!r}) with 
{right_type.__name__} ({right!r}) "
+                    f"in condition '{condition_type}'. Both values must be the 
same type."
+                )
+
+            comparison_result = 
self._COMPARISON_OPERATORS[condition_type](left, right)
+            self.log.info(
+                "%s%s: %r %s %r -> %s",
+                log_indent,
+                condition_type,
+                left,
+                condition_type,
+                right,
+                comparison_result,
+            )
+            return comparison_result
+
+        if condition_type == "In":
+            try:
+                query_value = self._cast(condition["value"])
+                allowed_values = [self._cast(val) for val in 
condition["in_values"]]
+            except KeyError as e:
+                raise ValueError(f"Condition '{condition_type}' missing 
required key {e}.") from None
+            membership_result = query_value in allowed_values
+            self.log.info("%sIn: %r in %r -> %s", log_indent, query_value, 
allowed_values, membership_result)
+            return membership_result
+
+        if condition_type == "Not":
+            try:
+                inner_condition = condition["condition"]
+            except KeyError as e:
+                raise ValueError(f"Condition '{condition_type}' missing 
required key {e}.") from None
+            inner_result = self._evaluate(inner_condition, depth + 1)
+            negated_result = not inner_result
+            self.log.info("%sNot: not %s -> %s", log_indent, inner_result, 
negated_result)
+            return negated_result
+
+        if condition_type == "Or":
+            try:
+                inner_conditions = condition["conditions"]
+            except KeyError as e:
+                raise ValueError(f"Condition '{condition_type}' missing 
required key {e}.") from None
+            inner_results = [self._evaluate(inner_cond, depth + 1) for 
inner_cond in inner_conditions]
+            or_result = any(inner_results)
+            self.log.info("%sOr: any(%r) -> %s", log_indent, inner_results, 
or_result)
+            return or_result
+
+        raise ValueError(f"Unknown condition type '{condition_type}'.")
+
+    def choose_branch(self, context: Context) -> list[str]:
+        """
+        Evaluate all conditions and return the appropriate branch task IDs.
+
+        :param context: Airflow context dictionary.
+        :returns: ``if_task_ids`` when all conditions are True, 
``else_task_ids`` otherwise.
+        """
+        condition_count = len(self.conditions)
+        self.log.info("Evaluating %d condition(s).", condition_count)
+
+        evaluation_results = [self._evaluate(condition) for condition in 
self.conditions]
+        all_conditions_met = all(evaluation_results)
+

Review Comment:
   Top-level condition evaluation builds a full `evaluation_results` list 
before calling `all(...)`, so every condition is evaluated even after one has 
already returned `False`. Those later evaluations can raise (e.g., a missing 
XCom value or a type mismatch), even though the outcome is already settled as 
the else branch. Consider evaluating conditions in a loop that exits on the 
first `False`, or explicitly documenting that all conditions are always 
evaluated and may raise.
   ```suggestion
           evaluation_results: list[bool] = []
           all_conditions_met = True
   
           for condition in self.conditions:
               result = self._evaluate(condition)
               evaluation_results.append(result)
               if not result:
                   all_conditions_met = False
                   break
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to