Copilot commented on code in PR #64545: URL: https://github.com/apache/airflow/pull/64545#discussion_r3025328220
########## providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_condition.py: ########## @@ -0,0 +1,223 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + +from airflow.providers.amazon.aws.operators.sagemaker import SageMakerConditionOperator + +from unit.amazon.aws.utils.test_template_fields import validate_template_fields + + +def _choose(conditions, if_ids="if_task", else_ids="else_task"): + """Instantiate with conditions list and call choose_branch.""" + op = SageMakerConditionOperator( + task_id="test", + conditions=conditions, + if_task_ids=if_ids, + else_task_ids=else_ids, + ) + return op.choose_branch(context=MagicMock(spec=dict)) + + +def _choose_flat(condition_type, left, right, if_ids="if_task", else_ids="else_task"): + """Instantiate with flat params and call choose_branch.""" + op = SageMakerConditionOperator( + task_id="test", + condition_type=condition_type, + left_value=left, + right_value=right, + if_task_ids=if_ids, + else_task_ids=else_ids, + ) + return op.choose_branch(context=MagicMock(spec=dict)) + + +def test_template_fields(): + op = SageMakerConditionOperator( + task_id="test", + conditions=[{"type": "Equals", "left_value": 1, "right_value": 1}], + if_task_ids=["a"], + else_task_ids=["b"], + ) + validate_template_fields(op) + + +class TestConditionTypes: + """One true + one false per condition type, plus logical combinators.""" + + @pytest.mark.parametrize( + ("cond_type", "left", "right", "expected"), + [ + ("Equals", 1, 1, "if"), + ("Equals", 1, 2, "else"), + ("GreaterThan", 5, 3, "if"), + ("GreaterThan", 3, 5, "else"), + ("GreaterThanOrEqualTo", 3, 3, "if"), + ("GreaterThanOrEqualTo", 2, 3, "else"), + ("LessThan", 3, 5, "if"), + ("LessThan", 5, 3, "else"), + ("LessThanOrEqualTo", 3, 3, "if"), + ("LessThanOrEqualTo", 5, 3, "else"), + ], + ) + def test_comparison(self, cond_type, left, right, expected): + result = _choose([{"type": cond_type, "left_value": left, "right_value": right}]) + assert result == (["if_task"] if expected == "if" else ["else_task"]) + + def test_in_true(self): + assert _choose([{"type": "In", "value": 1, "in_values": [1, 2]}]) == ["if_task"] + + def test_in_false(self): + assert _choose([{"type": "In", "value": 4, "in_values": [1, 2]}]) == ["else_task"] + + def test_not_negates(self): + cond = [{"type": "Not", "condition": {"type": "Equals", "left_value": 1, "right_value": 1}}] + assert _choose(cond) == ["else_task"] + + def test_or_any_true(self): + cond = [ + { + "type": "Or", + "conditions": [ + {"type": "Equals", "left_value": 1, "right_value": 2}, + {"type": "Equals", "left_value": 1, "right_value": 1}, + ], + } + ] + assert _choose(cond) == ["if_task"] + + +class TestAndSemantics: + def test_multiple_conditions_and(self): + """All true -> if, one false -> else.""" + assert _choose( + [ + {"type": "Equals", "left_value": 1, "right_value": 1}, + {"type": "GreaterThan", "left_value": 5, "right_value": 3}, + ] + ) == ["if_task"] + assert _choose( + [ + {"type": "Equals", "left_value": 1, "right_value": 1}, + {"type": "GreaterThan", "left_value": 2, "right_value": 10}, + ] + ) == ["else_task"] + + +class TestValueCasting: + def test_cast_numeric_and_passthrough(self): + """int string, float string, bool string, non-numeric string, non-string type.""" + assert SageMakerConditionOperator._cast("42") == 42 + assert SageMakerConditionOperator._cast("0.9") == 0.9 + assert SageMakerConditionOperator._cast("true") is True + assert SageMakerConditionOperator._cast("us-east-1") == "us-east-1" + assert SageMakerConditionOperator._cast(42) == 42 + + +class TestValidation: + def test_empty_conditions_raises(self): + with pytest.raises(ValueError, match="At least 1 condition is required"): + SageMakerConditionOperator(task_id="t", conditions=[], if_task_ids="a", else_task_ids="b") + + def test_unknown_type_raises(self): + with pytest.raises(ValueError, match="Unknown condition type"): + _choose([{"type": "FooBar", "left_value": 1, "right_value": 2}]) + + def test_type_mismatch_raises(self): + with pytest.raises(TypeError, match="Cannot compare"): + _choose([{"type": "GreaterThanOrEqualTo", "left_value": "hello", "right_value": 0.9}]) + + def test_none_operand_raises(self): + """Covers both Python None and Jinja-rendered string 'None'.""" + with pytest.raises(TypeError, match="received None"): + _choose([{"type": "Equals", "left_value": None, "right_value": 1}]) + with pytest.raises(TypeError, match="received None"): + _choose([{"type": "GreaterThanOrEqualTo", "left_value": "None", "right_value": 0.9}]) + + @pytest.mark.parametrize( + ("condition", "match_pattern"), + [ + ({}, "missing required key 'type'"), + ({"type": "Equals", "left_value": 1}, "missing required key"), + ({"type": "Not"}, "missing required key"), + ({"type": "Or"}, "missing required key"), + ], + ) + def test_missing_key_raises(self, condition, match_pattern): + with pytest.raises(ValueError, match=match_pattern): + _choose([condition]) + + +class TestFlatInterface: + def test_flat_condition(self): + """Flat params work for comparison and In types.""" + assert _choose_flat("Equals", 1, 1) == ["if_task"] + assert _choose_flat("Equals", 1, 2) == ["else_task"] + assert _choose_flat("In", 1, [1, 2, 3]) == ["if_task"] + assert _choose_flat("In", 99, [1, 2, 3]) == ["else_task"] + + def test_invalid_condition_type_raises(self): + with pytest.raises(ValueError, match="Unknown condition_type"): + SageMakerConditionOperator( + task_id="t", + condition_type="NotEquals", + left_value=1, + right_value=1, + if_task_ids="a", + else_task_ids="b", + ) + + def test_mutual_exclusion_raises(self): + with pytest.raises(ValueError, match="Cannot use 'condition_type' and 'conditions' together"): + SageMakerConditionOperator( + task_id="t", + condition_type="Equals", + left_value=1, + right_value=1, + conditions=[{"type": "Equals", "left_value": 1, "right_value": 1}], + if_task_ids="a", + else_task_ids="b", + ) + + def test_neither_provided_raises(self): + with pytest.raises(ValueError, match="Missing condition"): + SageMakerConditionOperator(task_id="t", if_task_ids="a", else_task_ids="b") + + def test_optional_else_task_ids(self): + """else_task_ids defaults to empty list when omitted.""" + op = SageMakerConditionOperator( + task_id="t", + conditions=[{"type": "Equals", "left_value": 1, "right_value": 1}], + if_task_ids=["deploy"], + ) + assert op.else_task_ids == [] + + def test_no_else_branch_raises_on_false(self): + """When else_task_ids is empty and conditions are false, raises AirflowFailException.""" + from airflow.providers.common.compat.sdk import AirflowFailException + Review Comment: Importing `AirflowFailException` inside the test function is inconsistent with the usual top-of-file import style used in these unit tests and makes linting/type-checking harder. Consider moving this import to the module import section. ########## providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py: ########## @@ -1991,3 +1996,262 @@ def execute(self, context): self.hook.conn.get_waiter("notebook_instance_in_service").wait( NotebookInstanceName=self.instance_name ) + + +class SageMakerConditionOperator(BaseBranchOperator): + """ + Evaluates a single condition or a list of conditions, and routes tasks based on the result. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SageMakerConditionOperator` + + :param condition_type: Condition type for the simple (flat) interface. + Valid types: ``Equals``, ``GreaterThan``, ``GreaterThanOrEqualTo``, + ``LessThan``, ``LessThanOrEqualTo``, ``In``. + Mutually exclusive with ``conditions``. + :param left_value: Left operand for the flat interface. For ``In`` conditions + this is the value to check membership of. + :param right_value: Right operand for the flat interface. For ``In`` conditions + this is the list of allowed values. + :param conditions: List of condition dicts to evaluate (AND-ed together). + Each dict must have a ``type`` key. Minimum 1, maximum 200. + Mutually exclusive with ``condition_type``/``left_value``/``right_value``. + :param if_task_ids: Task ID(s) to execute when all conditions are True. + :param else_task_ids: Task ID(s) to execute when any condition is False. + If omitted, the task fails with ``AirflowFailException`` when conditions are not met. + """ + + _VALID_FLAT_TYPES: ClassVar[set[str]] = { + "Equals", + "GreaterThan", + "GreaterThanOrEqualTo", + "LessThan", + "LessThanOrEqualTo", + "In", + } + + template_fields: Sequence[str] = ( + "condition_type", + "left_value", + "right_value", + "conditions", + "if_task_ids", + "else_task_ids", + ) + + def __init__( + self, + *, + condition_type: str | None = None, + left_value: Any = None, + right_value: Any = None, + conditions: list[dict] | None = None, + if_task_ids: str | list[str], + else_task_ids: str | list[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + has_flat = condition_type is not None + has_list = conditions is not None + + if has_flat and has_list: + raise ValueError( + "Cannot use 'condition_type' and 'conditions' together. " + "Use 'condition_type' with 'left_value'/'right_value' for a single condition, " + "or 'conditions' for multiple/nested conditions." + ) + if not has_flat and not has_list: + raise ValueError( + "Missing condition: provide 'condition_type' with 'left_value'/'right_value' " + "for a single condition, or 'conditions' for multiple/nested conditions." + ) + + if has_flat: + if condition_type not in self._VALID_FLAT_TYPES: + raise ValueError( + f"Unknown condition_type '{condition_type}'. " + f"Expected one of: {', '.join(sorted(self._VALID_FLAT_TYPES))}." + ) + self.condition_type: str | None = condition_type + self.left_value = left_value + self.right_value = right_value + if condition_type == "In": + self.conditions: list[dict[str, Any]] = [ + {"type": "In", "value": left_value, "in_values": right_value} + ] + else: + self.conditions = [ + {"type": condition_type, "left_value": left_value, "right_value": right_value} + ] + else: + self.condition_type = None + self.left_value = None + self.right_value = None + self.conditions = conditions # type: ignore[assignment] + + if not self.conditions: + raise ValueError("At least 1 condition is required, but got an empty list.") + self.if_task_ids = [if_task_ids] if isinstance(if_task_ids, str) else if_task_ids + self.else_task_ids = [else_task_ids] if isinstance(else_task_ids, str) else (else_task_ids or []) + + @staticmethod + def _cast(value: Any) -> Any: + """ + Cast Jinja-rendered string values to appropriate Python types. + + This is a compatibility shim for environments where + ``render_template_as_native_obj=True`` is not available at the DAG or + task level (e.g., YAML DAGs). Once task-level native rendering + is widely supported, this method can be removed in favor of letting + Airflow handle the casting natively. + + - Numeric strings become int or float. + - ``"true"``/``"false"`` become booleans. + - ``"None"`` becomes ``None`` (common when ``xcom_pull`` returns nothing). + - Other strings are returned unchanged. + - Non-string types pass through as-is. + """ + if not isinstance(value, str): + return value + if value == "None": + return None + try: + return int(value) + except (ValueError, TypeError): + pass + try: + return float(value) + except (ValueError, TypeError): + pass + if value.lower() == "true": + return True + if value.lower() == "false": + return False + return value + + _COMPARISON_OPERATORS: ClassVar[dict[str, Callable[[Any, Any], bool]]] = { + "Equals": lambda left, right: left == right, + "GreaterThan": lambda left, right: left > right, + "GreaterThanOrEqualTo": lambda left, right: left >= right, + "LessThan": lambda left, right: left < right, + "LessThanOrEqualTo": lambda left, right: left <= right, + } + + def _evaluate(self, condition: dict, depth: int = 0) -> bool: + """ + Recursively evaluate a single condition dict. + + :param condition: A condition dictionary with a ``type`` key. + :param depth: Current nesting depth (used for log indentation only). + :returns: Boolean result of the condition evaluation. + """ + log_indent = " " * depth + try: + condition_type = condition["type"] + except KeyError: + raise ValueError("Condition dict is missing required key 'type'.") + + if condition_type in self._COMPARISON_OPERATORS: + try: + left = self._cast(condition["left_value"]) + right = self._cast(condition["right_value"]) + except KeyError as e: + raise ValueError(f"Condition '{condition_type}' missing required key {e}.") from None + + # None check — likely an XCom that was not pushed + if left is None or right is None: + raise TypeError( + f"Condition '{condition_type}' received None: left={left!r}, right={right!r}. " + "This usually means the upstream task did not run or did not push a value to XCom." + ) + + # Type compatibility check + left_type = type(left) + right_type = type(right) + numeric_types = (int, float) + left_is_numeric = isinstance(left, numeric_types) and not isinstance(left, bool) + right_is_numeric = isinstance(right, numeric_types) and not isinstance(right, bool) + + if not (left_is_numeric and right_is_numeric) and left_type is not right_type: + raise TypeError( + f"Cannot compare {left_type.__name__} ({left!r}) with {right_type.__name__} ({right!r}) " + f"in condition '{condition_type}'. Both values must be the same type." + ) + + comparison_result = self._COMPARISON_OPERATORS[condition_type](left, right) + self.log.info( + "%s%s: %r %s %r -> %s", + log_indent, + condition_type, + left, + condition_type, + right, + comparison_result, + ) + return comparison_result + + if condition_type == "In": + try: + query_value = self._cast(condition["value"]) + allowed_values = [self._cast(val) for val in condition["in_values"]] + except KeyError as e: + raise ValueError(f"Condition '{condition_type}' missing required key {e}.") from None + membership_result = query_value in allowed_values + self.log.info("%sIn: %r in %r -> %s", log_indent, query_value, allowed_values, membership_result) + return membership_result + + if condition_type == "Not": + try: + inner_condition = condition["condition"] + except KeyError as e: + raise ValueError(f"Condition '{condition_type}' missing required key {e}.") from None + inner_result = self._evaluate(inner_condition, depth + 1) + negated_result = not inner_result + self.log.info("%sNot: not %s -> %s", log_indent, inner_result, negated_result) + return negated_result + + if condition_type == "Or": + try: + inner_conditions = condition["conditions"] + except KeyError as e: + raise ValueError(f"Condition '{condition_type}' missing required key {e}.") from None + inner_results = [self._evaluate(inner_cond, depth + 1) for inner_cond in inner_conditions] + or_result = any(inner_results) + self.log.info("%sOr: any(%r) -> %s", log_indent, inner_results, or_result) + return or_result Review Comment: `Or` evaluation currently evaluates *all* nested conditions (`inner_results = [...]`) rather than short-circuiting. This can change expected boolean semantics and even fail the task when an earlier condition already makes the overall `Or` true (e.g., later nested condition has a missing key / type mismatch). Consider iterating and returning `True` on the first true condition (optionally logging evaluated results) to preserve short-circuit behavior. ```suggestion inner_results: list[bool] = [] for inner_cond in inner_conditions: result = self._evaluate(inner_cond, depth + 1) inner_results.append(result) if result: self.log.info("%sOr: any(%r) -> True", log_indent, inner_results) return True self.log.info("%sOr: any(%r) -> False", log_indent, inner_results) return False ``` ########## providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py: ########## @@ -1991,3 +1996,262 @@ def execute(self, context): self.hook.conn.get_waiter("notebook_instance_in_service").wait( NotebookInstanceName=self.instance_name ) + + +class SageMakerConditionOperator(BaseBranchOperator): + """ + Evaluates a single condition or a list of conditions, and routes tasks based on the result. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SageMakerConditionOperator` + + :param condition_type: Condition type for the simple (flat) interface. + Valid types: ``Equals``, ``GreaterThan``, ``GreaterThanOrEqualTo``, + ``LessThan``, ``LessThanOrEqualTo``, ``In``. + Mutually exclusive with ``conditions``. + :param left_value: Left operand for the flat interface. For ``In`` conditions + this is the value to check membership of. + :param right_value: Right operand for the flat interface. For ``In`` conditions + this is the list of allowed values. + :param conditions: List of condition dicts to evaluate (AND-ed together). + Each dict must have a ``type`` key. Minimum 1, maximum 200. Review Comment: The docstring states the `conditions` list is "maximum 200", but `__init__` only validates non-empty and does not enforce an upper bound. Either add a length check (and a clear error) or remove the documented maximum to avoid misleading users. ```suggestion Each dict must have a ``type`` key. Minimum 1. ``` ########## providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py: ########## @@ -1991,3 +1996,262 @@ def execute(self, context): self.hook.conn.get_waiter("notebook_instance_in_service").wait( NotebookInstanceName=self.instance_name ) + + +class SageMakerConditionOperator(BaseBranchOperator): + """ + Evaluates a single condition or a list of conditions, and routes tasks based on the result. + Review Comment: PR description/title mentions adding a `SageMakerFailOperator`, but this change set only introduces `SageMakerConditionOperator` (plus its tests/docs). If `SageMakerFailOperator` is still intended, it looks missing from the implementation/tests/docs; otherwise, the PR description/title should be updated to match the actual changes. ########## providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py: ########## @@ -1991,3 +1996,262 @@ def execute(self, context): self.hook.conn.get_waiter("notebook_instance_in_service").wait( NotebookInstanceName=self.instance_name ) + + +class SageMakerConditionOperator(BaseBranchOperator): + """ + Evaluates a single condition or a list of conditions, and routes tasks based on the result. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SageMakerConditionOperator` + + :param condition_type: Condition type for the simple (flat) interface. + Valid types: ``Equals``, ``GreaterThan``, ``GreaterThanOrEqualTo``, + ``LessThan``, ``LessThanOrEqualTo``, ``In``. + Mutually exclusive with ``conditions``. + :param left_value: Left operand for the flat interface. For ``In`` conditions + this is the value to check membership of. + :param right_value: Right operand for the flat interface. For ``In`` conditions + this is the list of allowed values. + :param conditions: List of condition dicts to evaluate (AND-ed together). + Each dict must have a ``type`` key. Minimum 1, maximum 200. + Mutually exclusive with ``condition_type``/``left_value``/``right_value``. + :param if_task_ids: Task ID(s) to execute when all conditions are True. + :param else_task_ids: Task ID(s) to execute when any condition is False. + If omitted, the task fails with ``AirflowFailException`` when conditions are not met. + """ + + _VALID_FLAT_TYPES: ClassVar[set[str]] = { + "Equals", + "GreaterThan", + "GreaterThanOrEqualTo", + "LessThan", + "LessThanOrEqualTo", + "In", + } + + template_fields: Sequence[str] = ( + "condition_type", + "left_value", + "right_value", + "conditions", + "if_task_ids", + "else_task_ids", + ) + + def __init__( + self, + *, + condition_type: str | None = None, + left_value: Any = None, + right_value: Any = None, + conditions: list[dict] | None = None, + if_task_ids: str | list[str], + else_task_ids: str | list[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + has_flat = condition_type is not None + has_list = conditions is not None + + if has_flat and has_list: + raise ValueError( + "Cannot use 'condition_type' and 'conditions' together. " + "Use 'condition_type' with 'left_value'/'right_value' for a single condition, " + "or 'conditions' for multiple/nested conditions." + ) + if not has_flat and not has_list: + raise ValueError( + "Missing condition: provide 'condition_type' with 'left_value'/'right_value' " + "for a single condition, or 'conditions' for multiple/nested conditions." + ) + + if has_flat: + if condition_type not in self._VALID_FLAT_TYPES: + raise ValueError( + f"Unknown condition_type '{condition_type}'. " + f"Expected one of: {', '.join(sorted(self._VALID_FLAT_TYPES))}." + ) + self.condition_type: str | None = condition_type + self.left_value = left_value + self.right_value = right_value + if condition_type == "In": + self.conditions: list[dict[str, Any]] = [ + {"type": "In", "value": left_value, "in_values": right_value} + ] + else: + self.conditions = [ + {"type": condition_type, "left_value": left_value, "right_value": right_value} + ] + else: + self.condition_type = None + self.left_value = None + self.right_value = None + self.conditions = conditions # type: ignore[assignment] + + if not self.conditions: + raise ValueError("At least 1 condition is required, but got an empty list.") + self.if_task_ids = [if_task_ids] if isinstance(if_task_ids, str) else if_task_ids + self.else_task_ids = [else_task_ids] if isinstance(else_task_ids, str) else (else_task_ids or []) + + @staticmethod + def _cast(value: Any) -> Any: + """ + Cast Jinja-rendered string values to appropriate Python types. + + This is a compatibility shim for environments where + ``render_template_as_native_obj=True`` is not available at the DAG or + task level (e.g., YAML DAGs). Once task-level native rendering + is widely supported, this method can be removed in favor of letting + Airflow handle the casting natively. + + - Numeric strings become int or float. + - ``"true"``/``"false"`` become booleans. + - ``"None"`` becomes ``None`` (common when ``xcom_pull`` returns nothing). + - Other strings are returned unchanged. + - Non-string types pass through as-is. + """ + if not isinstance(value, str): + return value + if value == "None": + return None + try: + return int(value) + except (ValueError, TypeError): + pass + try: + return float(value) + except (ValueError, TypeError): + pass + if value.lower() == "true": + return True + if value.lower() == "false": + return False + return value + + _COMPARISON_OPERATORS: ClassVar[dict[str, Callable[[Any, Any], bool]]] = { + "Equals": lambda left, right: left == right, + "GreaterThan": lambda left, right: left > right, + "GreaterThanOrEqualTo": lambda left, right: left >= right, + "LessThan": lambda left, right: left < right, + "LessThanOrEqualTo": lambda left, right: left <= right, + } + + def _evaluate(self, condition: dict, depth: int = 0) -> bool: + """ + Recursively evaluate a single condition dict. + + :param condition: A condition dictionary with a ``type`` key. + :param depth: Current nesting depth (used for log indentation only). + :returns: Boolean result of the condition evaluation. + """ + log_indent = " " * depth + try: + condition_type = condition["type"] + except KeyError: + raise ValueError("Condition dict is missing required key 'type'.") + + if condition_type in self._COMPARISON_OPERATORS: + try: + left = self._cast(condition["left_value"]) + right = self._cast(condition["right_value"]) + except KeyError as e: + raise ValueError(f"Condition '{condition_type}' missing required key {e}.") from None + + # None check — likely an XCom that was not pushed + if left is None or right is None: + raise TypeError( + f"Condition '{condition_type}' received None: left={left!r}, right={right!r}. " + "This usually means the upstream task did not run or did not push a value to XCom." + ) + + # Type compatibility check + left_type = type(left) + right_type = type(right) + numeric_types = (int, float) + left_is_numeric = isinstance(left, numeric_types) and not isinstance(left, bool) + right_is_numeric = isinstance(right, numeric_types) and not isinstance(right, bool) + + if not (left_is_numeric and right_is_numeric) and left_type is not right_type: + raise TypeError( + f"Cannot compare {left_type.__name__} ({left!r}) with {right_type.__name__} ({right!r}) " + f"in condition '{condition_type}'. Both values must be the same type." + ) + + comparison_result = self._COMPARISON_OPERATORS[condition_type](left, right) + self.log.info( + "%s%s: %r %s %r -> %s", + log_indent, + condition_type, + left, + condition_type, + right, + comparison_result, + ) + return comparison_result + + if condition_type == "In": + try: + query_value = self._cast(condition["value"]) + allowed_values = [self._cast(val) for val in condition["in_values"]] + except KeyError as e: + raise ValueError(f"Condition '{condition_type}' missing required key {e}.") from None + membership_result = query_value in allowed_values + self.log.info("%sIn: %r in %r -> %s", log_indent, query_value, allowed_values, membership_result) + return membership_result + + if condition_type == "Not": + try: + inner_condition = condition["condition"] + except KeyError as e: + raise ValueError(f"Condition '{condition_type}' missing required key {e}.") from None + inner_result = self._evaluate(inner_condition, depth + 1) + negated_result = not inner_result + self.log.info("%sNot: not %s -> %s", log_indent, inner_result, negated_result) + return negated_result + + if condition_type == "Or": + try: + inner_conditions = condition["conditions"] + except KeyError as e: + raise ValueError(f"Condition '{condition_type}' missing required key {e}.") from None + inner_results = [self._evaluate(inner_cond, depth + 1) for inner_cond in inner_conditions] + or_result = any(inner_results) + self.log.info("%sOr: any(%r) -> %s", log_indent, inner_results, or_result) + return or_result + + raise ValueError(f"Unknown condition type '{condition_type}'.") + + def choose_branch(self, context: Context) -> list[str]: + """ + Evaluate all conditions and return the appropriate branch task IDs. + + :param context: Airflow context dictionary. + :returns: ``if_task_ids`` when all conditions are True, ``else_task_ids`` otherwise. + """ + condition_count = len(self.conditions) + self.log.info("Evaluating %d condition(s).", condition_count) + + evaluation_results = [self._evaluate(condition) for condition in self.conditions] + all_conditions_met = all(evaluation_results) + Review Comment: Top-level condition evaluation builds a full `evaluation_results` list before calling `all(...)`. This means all conditions are evaluated even after one is already `False`, which can raise exceptions (e.g., missing XCom / type mismatch) that would otherwise be irrelevant to routing to the else branch. Consider evaluating conditions in a loop with early exit on the first `False` (and/or making exception behavior explicit). ```suggestion evaluation_results: list[bool] = [] all_conditions_met = True for condition in self.conditions: result = self._evaluate(condition) evaluation_results.append(result) if not result: all_conditions_met = False break ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
