bhavya2109sharma commented on code in PR #64545:
URL: https://github.com/apache/airflow/pull/64545#discussion_r3025369309


##########
providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -1991,3 +1996,262 @@ def execute(self, context):
             self.hook.conn.get_waiter("notebook_instance_in_service").wait(
                 NotebookInstanceName=self.instance_name
             )
+
+
+class SageMakerConditionOperator(BaseBranchOperator):
+    """
+    Evaluates a single condition or a list of conditions, and routes tasks 
based on the result.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the 
guide:
+        :ref:`howto/operator:SageMakerConditionOperator`
+
+    :param condition_type: Condition type for the simple (flat) interface.
+        Valid types: ``Equals``, ``GreaterThan``, ``GreaterThanOrEqualTo``,
+        ``LessThan``, ``LessThanOrEqualTo``, ``In``.
+        Mutually exclusive with ``conditions``.
+    :param left_value: Left operand for the flat interface. For ``In`` 
conditions
+        this is the value to check membership of.
+    :param right_value: Right operand for the flat interface. For ``In`` 
conditions
+        this is the list of allowed values.
+    :param conditions: List of condition dicts to evaluate (AND-ed together).
+        Each dict must have a ``type`` key. Minimum 1, maximum 200.
+        Mutually exclusive with 
``condition_type``/``left_value``/``right_value``.
+    :param if_task_ids: Task ID(s) to execute when all conditions are True.
+    :param else_task_ids: Task ID(s) to execute when any condition is False.
+        If omitted, the task fails with ``AirflowFailException`` when 
conditions are not met.
+    """
+
+    _VALID_FLAT_TYPES: ClassVar[set[str]] = {
+        "Equals",
+        "GreaterThan",
+        "GreaterThanOrEqualTo",
+        "LessThan",
+        "LessThanOrEqualTo",
+        "In",
+    }
+
+    template_fields: Sequence[str] = (
+        "condition_type",
+        "left_value",
+        "right_value",
+        "conditions",
+        "if_task_ids",
+        "else_task_ids",
+    )
+
+    def __init__(
+        self,
+        *,
+        condition_type: str | None = None,
+        left_value: Any = None,
+        right_value: Any = None,
+        conditions: list[dict] | None = None,
+        if_task_ids: str | list[str],
+        else_task_ids: str | list[str] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        has_flat = condition_type is not None
+        has_list = conditions is not None
+
+        if has_flat and has_list:
+            raise ValueError(
+                "Cannot use 'condition_type' and 'conditions' together. "
+                "Use 'condition_type' with 'left_value'/'right_value' for a 
single condition, "
+                "or 'conditions' for multiple/nested conditions."
+            )
+        if not has_flat and not has_list:
+            raise ValueError(
+                "Missing condition: provide 'condition_type' with 
'left_value'/'right_value' "
+                "for a single condition, or 'conditions' for multiple/nested 
conditions."
+            )
+
+        if has_flat:
+            if condition_type not in self._VALID_FLAT_TYPES:
+                raise ValueError(
+                    f"Unknown condition_type '{condition_type}'. "
+                    f"Expected one of: {', 
'.join(sorted(self._VALID_FLAT_TYPES))}."
+                )
+            self.condition_type: str | None = condition_type
+            self.left_value = left_value
+            self.right_value = right_value
+            if condition_type == "In":
+                self.conditions: list[dict[str, Any]] = [
+                    {"type": "In", "value": left_value, "in_values": 
right_value}
+                ]
+            else:
+                self.conditions = [
+                    {"type": condition_type, "left_value": left_value, 
"right_value": right_value}
+                ]
+        else:
+            self.condition_type = None
+            self.left_value = None
+            self.right_value = None
+            self.conditions = conditions  # type: ignore[assignment]
+
+        if not self.conditions:
+            raise ValueError("At least 1 condition is required, but got an 
empty list.")
+        self.if_task_ids = [if_task_ids] if isinstance(if_task_ids, str) else 
if_task_ids
+        self.else_task_ids = [else_task_ids] if isinstance(else_task_ids, str) 
else (else_task_ids or [])
+
+    @staticmethod
+    def _cast(value: Any) -> Any:
+        """
+        Cast Jinja-rendered string values to appropriate Python types.
+
+        This is a compatibility shim for environments where
+        ``render_template_as_native_obj=True`` is not available at the DAG or
+        task level (e.g., YAML DAGs). Once task-level native rendering
+        is widely supported, this method can be removed in favor of letting
+        Airflow handle the casting natively.
+
+        - Numeric strings become int or float.
+        - ``"true"``/``"false"`` become booleans.
+        - ``"None"`` becomes ``None`` (common when ``xcom_pull`` returns 
nothing).
+        - Other strings are returned unchanged.
+        - Non-string types pass through as-is.
+        """
+        if not isinstance(value, str):
+            return value
+        if value == "None":
+            return None
+        try:
+            return int(value)
+        except (ValueError, TypeError):
+            pass
+        try:
+            return float(value)
+        except (ValueError, TypeError):
+            pass
+        if value.lower() == "true":
+            return True
+        if value.lower() == "false":
+            return False
+        return value
+
+    _COMPARISON_OPERATORS: ClassVar[dict[str, Callable[[Any, Any], bool]]] = {
+        "Equals": lambda left, right: left == right,
+        "GreaterThan": lambda left, right: left > right,
+        "GreaterThanOrEqualTo": lambda left, right: left >= right,
+        "LessThan": lambda left, right: left < right,
+        "LessThanOrEqualTo": lambda left, right: left <= right,
+    }
+
+    def _evaluate(self, condition: dict, depth: int = 0) -> bool:
+        """
+        Recursively evaluate a single condition dict.
+
+        :param condition: A condition dictionary with a ``type`` key.
+        :param depth: Current nesting depth (used for log indentation only).
+        :returns: Boolean result of the condition evaluation.
+        """
+        log_indent = "  " * depth
+        try:
+            condition_type = condition["type"]
+        except KeyError:
+            raise ValueError("Condition dict is missing required key 'type'.")
+
+        if condition_type in self._COMPARISON_OPERATORS:
+            try:
+                left = self._cast(condition["left_value"])
+                right = self._cast(condition["right_value"])
+            except KeyError as e:
+                raise ValueError(f"Condition '{condition_type}' missing 
required key {e}.") from None
+
+            # None check — likely an XCom that was not pushed
+            if left is None or right is None:
+                raise TypeError(
+                    f"Condition '{condition_type}' received None: 
left={left!r}, right={right!r}. "
+                    "This usually means the upstream task did not run or did 
not push a value to XCom."
+                )
+
+            # Type compatibility check
+            left_type = type(left)
+            right_type = type(right)
+            numeric_types = (int, float)
+            left_is_numeric = isinstance(left, numeric_types) and not 
isinstance(left, bool)
+            right_is_numeric = isinstance(right, numeric_types) and not 
isinstance(right, bool)
+
+            if not (left_is_numeric and right_is_numeric) and left_type is not 
right_type:
+                raise TypeError(
+                    f"Cannot compare {left_type.__name__} ({left!r}) with 
{right_type.__name__} ({right!r}) "
+                    f"in condition '{condition_type}'. Both values must be the 
same type."
+                )
+
+            comparison_result = 
self._COMPARISON_OPERATORS[condition_type](left, right)
+            self.log.info(
+                "%s%s: %r %s %r -> %s",
+                log_indent,
+                condition_type,
+                left,
+                condition_type,
+                right,
+                comparison_result,
+            )
+            return comparison_result
+
+        if condition_type == "In":
+            try:
+                query_value = self._cast(condition["value"])
+                allowed_values = [self._cast(val) for val in 
condition["in_values"]]
+            except KeyError as e:
+                raise ValueError(f"Condition '{condition_type}' missing 
required key {e}.") from None
+            membership_result = query_value in allowed_values
+            self.log.info("%sIn: %r in %r -> %s", log_indent, query_value, 
allowed_values, membership_result)
+            return membership_result
+
+        if condition_type == "Not":
+            try:
+                inner_condition = condition["condition"]
+            except KeyError as e:
+                raise ValueError(f"Condition '{condition_type}' missing 
required key {e}.") from None
+            inner_result = self._evaluate(inner_condition, depth + 1)
+            negated_result = not inner_result
+            self.log.info("%sNot: not %s -> %s", log_indent, inner_result, 
negated_result)
+            return negated_result
+
+        if condition_type == "Or":
+            try:
+                inner_conditions = condition["conditions"]
+            except KeyError as e:
+                raise ValueError(f"Condition '{condition_type}' missing 
required key {e}.") from None
+            inner_results = [self._evaluate(inner_cond, depth + 1) for 
inner_cond in inner_conditions]
+            or_result = any(inner_results)
+            self.log.info("%sOr: any(%r) -> %s", log_indent, inner_results, 
or_result)
+            return or_result
+
+        raise ValueError(f"Unknown condition type '{condition_type}'.")
+
+    def choose_branch(self, context: Context) -> list[str]:
+        """
+        Evaluate all conditions and return the appropriate branch task IDs.
+
+        :param context: Airflow context dictionary.
+        :returns: ``if_task_ids`` when all conditions are True, 
``else_task_ids`` otherwise.
+        """
+        condition_count = len(self.conditions)
+        self.log.info("Evaluating %d condition(s).", condition_count)
+
+        evaluation_results = [self._evaluate(condition) for condition in 
self.conditions]
+        all_conditions_met = all(evaluation_results)
+

Review Comment:
   Evaluating all conditions (rather than short-circuiting on the first False) provides complete diagnostic output in the logs,
which will help users debug which specific conditions failed.



##########
providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -1991,3 +1996,262 @@ def execute(self, context):
             self.hook.conn.get_waiter("notebook_instance_in_service").wait(
                 NotebookInstanceName=self.instance_name
             )
+
+
+class SageMakerConditionOperator(BaseBranchOperator):
+    """
+    Evaluates a single condition or a list of conditions, and routes tasks 
based on the result.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the 
guide:
+        :ref:`howto/operator:SageMakerConditionOperator`
+
+    :param condition_type: Condition type for the simple (flat) interface.
+        Valid types: ``Equals``, ``GreaterThan``, ``GreaterThanOrEqualTo``,
+        ``LessThan``, ``LessThanOrEqualTo``, ``In``.
+        Mutually exclusive with ``conditions``.
+    :param left_value: Left operand for the flat interface. For ``In`` 
conditions
+        this is the value to check membership of.
+    :param right_value: Right operand for the flat interface. For ``In`` 
conditions
+        this is the list of allowed values.
+    :param conditions: List of condition dicts to evaluate (AND-ed together).
+        Each dict must have a ``type`` key. Minimum 1, maximum 200.
+        Mutually exclusive with 
``condition_type``/``left_value``/``right_value``.
+    :param if_task_ids: Task ID(s) to execute when all conditions are True.
+    :param else_task_ids: Task ID(s) to execute when any condition is False.
+        If omitted, the task fails with ``AirflowFailException`` when 
conditions are not met.
+    """
+
+    _VALID_FLAT_TYPES: ClassVar[set[str]] = {
+        "Equals",
+        "GreaterThan",
+        "GreaterThanOrEqualTo",
+        "LessThan",
+        "LessThanOrEqualTo",
+        "In",
+    }
+
+    template_fields: Sequence[str] = (
+        "condition_type",
+        "left_value",
+        "right_value",
+        "conditions",
+        "if_task_ids",
+        "else_task_ids",
+    )
+
+    def __init__(
+        self,
+        *,
+        condition_type: str | None = None,
+        left_value: Any = None,
+        right_value: Any = None,
+        conditions: list[dict] | None = None,
+        if_task_ids: str | list[str],
+        else_task_ids: str | list[str] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        has_flat = condition_type is not None
+        has_list = conditions is not None
+
+        if has_flat and has_list:
+            raise ValueError(
+                "Cannot use 'condition_type' and 'conditions' together. "
+                "Use 'condition_type' with 'left_value'/'right_value' for a 
single condition, "
+                "or 'conditions' for multiple/nested conditions."
+            )
+        if not has_flat and not has_list:
+            raise ValueError(
+                "Missing condition: provide 'condition_type' with 
'left_value'/'right_value' "
+                "for a single condition, or 'conditions' for multiple/nested 
conditions."
+            )
+
+        if has_flat:
+            if condition_type not in self._VALID_FLAT_TYPES:
+                raise ValueError(
+                    f"Unknown condition_type '{condition_type}'. "
+                    f"Expected one of: {', 
'.join(sorted(self._VALID_FLAT_TYPES))}."
+                )
+            self.condition_type: str | None = condition_type
+            self.left_value = left_value
+            self.right_value = right_value
+            if condition_type == "In":
+                self.conditions: list[dict[str, Any]] = [
+                    {"type": "In", "value": left_value, "in_values": 
right_value}
+                ]
+            else:
+                self.conditions = [
+                    {"type": condition_type, "left_value": left_value, 
"right_value": right_value}
+                ]
+        else:
+            self.condition_type = None
+            self.left_value = None
+            self.right_value = None
+            self.conditions = conditions  # type: ignore[assignment]
+
+        if not self.conditions:
+            raise ValueError("At least 1 condition is required, but got an 
empty list.")
+        self.if_task_ids = [if_task_ids] if isinstance(if_task_ids, str) else 
if_task_ids
+        self.else_task_ids = [else_task_ids] if isinstance(else_task_ids, str) 
else (else_task_ids or [])
+
+    @staticmethod
+    def _cast(value: Any) -> Any:
+        """
+        Cast Jinja-rendered string values to appropriate Python types.
+
+        This is a compatibility shim for environments where
+        ``render_template_as_native_obj=True`` is not available at the DAG or
+        task level (e.g., YAML DAGs). Once task-level native rendering
+        is widely supported, this method can be removed in favor of letting
+        Airflow handle the casting natively.
+
+        - Numeric strings become int or float.
+        - ``"true"``/``"false"`` become booleans.
+        - ``"None"`` becomes ``None`` (common when ``xcom_pull`` returns 
nothing).
+        - Other strings are returned unchanged.
+        - Non-string types pass through as-is.
+        """
+        if not isinstance(value, str):
+            return value
+        if value == "None":
+            return None
+        try:
+            return int(value)
+        except (ValueError, TypeError):
+            pass
+        try:
+            return float(value)
+        except (ValueError, TypeError):
+            pass
+        if value.lower() == "true":
+            return True
+        if value.lower() == "false":
+            return False
+        return value
+
+    _COMPARISON_OPERATORS: ClassVar[dict[str, Callable[[Any, Any], bool]]] = {
+        "Equals": lambda left, right: left == right,
+        "GreaterThan": lambda left, right: left > right,
+        "GreaterThanOrEqualTo": lambda left, right: left >= right,
+        "LessThan": lambda left, right: left < right,
+        "LessThanOrEqualTo": lambda left, right: left <= right,
+    }
+
+    def _evaluate(self, condition: dict, depth: int = 0) -> bool:
+        """
+        Recursively evaluate a single condition dict.
+
+        :param condition: A condition dictionary with a ``type`` key.
+        :param depth: Current nesting depth (used for log indentation only).
+        :returns: Boolean result of the condition evaluation.
+        """
+        log_indent = "  " * depth
+        try:
+            condition_type = condition["type"]
+        except KeyError:
+            raise ValueError("Condition dict is missing required key 'type'.")
+
+        if condition_type in self._COMPARISON_OPERATORS:
+            try:
+                left = self._cast(condition["left_value"])
+                right = self._cast(condition["right_value"])
+            except KeyError as e:
+                raise ValueError(f"Condition '{condition_type}' missing 
required key {e}.") from None
+
+            # None check — likely an XCom that was not pushed
+            if left is None or right is None:
+                raise TypeError(
+                    f"Condition '{condition_type}' received None: 
left={left!r}, right={right!r}. "
+                    "This usually means the upstream task did not run or did 
not push a value to XCom."
+                )
+
+            # Type compatibility check
+            left_type = type(left)
+            right_type = type(right)
+            numeric_types = (int, float)
+            left_is_numeric = isinstance(left, numeric_types) and not 
isinstance(left, bool)
+            right_is_numeric = isinstance(right, numeric_types) and not 
isinstance(right, bool)
+
+            if not (left_is_numeric and right_is_numeric) and left_type is not 
right_type:
+                raise TypeError(
+                    f"Cannot compare {left_type.__name__} ({left!r}) with 
{right_type.__name__} ({right!r}) "
+                    f"in condition '{condition_type}'. Both values must be the 
same type."
+                )
+
+            comparison_result = 
self._COMPARISON_OPERATORS[condition_type](left, right)
+            self.log.info(
+                "%s%s: %r %s %r -> %s",
+                log_indent,
+                condition_type,
+                left,
+                condition_type,
+                right,
+                comparison_result,
+            )
+            return comparison_result
+
+        if condition_type == "In":
+            try:
+                query_value = self._cast(condition["value"])
+                allowed_values = [self._cast(val) for val in 
condition["in_values"]]
+            except KeyError as e:
+                raise ValueError(f"Condition '{condition_type}' missing 
required key {e}.") from None
+            membership_result = query_value in allowed_values
+            self.log.info("%sIn: %r in %r -> %s", log_indent, query_value, 
allowed_values, membership_result)
+            return membership_result
+
+        if condition_type == "Not":
+            try:
+                inner_condition = condition["condition"]
+            except KeyError as e:
+                raise ValueError(f"Condition '{condition_type}' missing 
required key {e}.") from None
+            inner_result = self._evaluate(inner_condition, depth + 1)
+            negated_result = not inner_result
+            self.log.info("%sNot: not %s -> %s", log_indent, inner_result, 
negated_result)
+            return negated_result
+
+        if condition_type == "Or":
+            try:
+                inner_conditions = condition["conditions"]
+            except KeyError as e:
+                raise ValueError(f"Condition '{condition_type}' missing 
required key {e}.") from None
+            inner_results = [self._evaluate(inner_cond, depth + 1) for 
inner_cond in inner_conditions]
+            or_result = any(inner_results)
+            self.log.info("%sOr: any(%r) -> %s", log_indent, inner_results, 
or_result)
+            return or_result

Review Comment:
   Evaluating all conditions (rather than short-circuiting on the first False) provides complete diagnostic output in the logs,
which will help users debug which specific conditions failed.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to