bhavya2109sharma commented on code in PR #64545:
URL: https://github.com/apache/airflow/pull/64545#discussion_r3019175177
##########
providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -1991,3 +1997,270 @@ def execute(self, context):
self.hook.conn.get_waiter("notebook_instance_in_service").wait(
NotebookInstanceName=self.instance_name
)
+
+
+class SageMakerConditionOperator(BaseBranchOperator):
+ """
+ Evaluates a single condition or a list of conditions, and routes tasks
based on the result.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:SageMakerConditionOperator`
+
+ :param condition_type: Condition type for the simple (flat) interface.
+ Valid types: ``Equals``, ``GreaterThan``, ``GreaterThanOrEqualTo``,
+ ``LessThan``, ``LessThanOrEqualTo``, ``In``.
+ Mutually exclusive with ``conditions``.
+ :param left: Left operand for the flat interface. For ``In`` conditions
+ this is the value to check membership of.
+ :param right: Right operand for the flat interface. For ``In`` conditions
+ this is the list of allowed values.
+ :param conditions: List of condition dicts to evaluate (AND-ed together).
+ Each dict must have a ``type`` key. Minimum 1, maximum 200.
+ Mutually exclusive with ``condition_type``/``left``/``right``.
+ :param if_task_ids: Task ID(s) to execute when all conditions are True.
+ :param else_task_ids: Task ID(s) to execute when any condition is False.
+ """
+
+ _VALID_FLAT_TYPES: ClassVar[set[str]] = {
+ "Equals",
+ "GreaterThan",
+ "GreaterThanOrEqualTo",
+ "LessThan",
+ "LessThanOrEqualTo",
+ "In",
+ }
+
+ template_fields: Sequence[str] = (
+ "condition_type",
+ "left",
+ "right",
+ "conditions",
+ "if_task_ids",
+ "else_task_ids",
+ )
+
def __init__(
    self,
    *,
    condition_type: str | None = None,
    left: Any = None,
    right: Any = None,
    conditions: list[dict] | None = None,
    if_task_ids: str | list[str],
    else_task_ids: str | list[str],
    **kwargs,
) -> None:
    """
    Validate the condition arguments and normalize them into ``self.conditions``.

    Exactly one of the two interfaces must be used. The flat
    ``condition_type``/``left``/``right`` form is turned into a one-element
    ``conditions`` list. An explicit ``conditions`` list is stored as given.
    A task ID passed as a bare string is wrapped in a one-element list.

    :raises ValueError: if both or neither interface is supplied, if
        ``condition_type`` is not one of the supported flat types, if the
        resulting condition list is empty, or if it holds more than 200
        conditions.
    """
    super().__init__(**kwargs)
    uses_flat_form = condition_type is not None
    uses_list_form = conditions is not None

    # The two interfaces are mutually exclusive, and one of them is required.
    if uses_flat_form and uses_list_form:
        raise ValueError(
            "Cannot use 'condition_type' and 'conditions' together. "
            "Use 'condition_type' with 'left'/'right' for a single condition, "
            "or 'conditions' for multiple/nested conditions."
        )
    if not (uses_flat_form or uses_list_form):
        raise ValueError(
            "Missing condition: provide 'condition_type' with 'left'/'right' for a single condition, "
            "or 'conditions' for multiple/nested conditions."
        )

    if uses_list_form:
        # List interface: the flat-only attributes are explicitly cleared.
        self.condition_type = None
        self.left = None
        self.right = None
        self.conditions = conditions
    else:
        if condition_type not in self._VALID_FLAT_TYPES:
            supported = ", ".join(sorted(self._VALID_FLAT_TYPES))
            raise ValueError(f"Unknown condition_type '{condition_type}'. Expected one of: {supported}.")
        self.condition_type = condition_type
        self.left = left
        self.right = right
        # "In" checks membership of ``left`` in ``right``, so its dict uses
        # "value"/"in_values" keys instead of "left"/"right".
        if condition_type == "In":
            flat_condition = {"type": "In", "value": left, "in_values": right}
        else:
            flat_condition = {"type": condition_type, "left": left, "right": right}
        self.conditions = [flat_condition]

    if not self.conditions:
        raise ValueError("At least 1 condition is required, but got an empty list.")
    total = len(self.conditions)
    if total > 200:
        raise ValueError(f"At most 200 conditions are allowed, but got {total}.")

    # A bare string task ID is wrapped in a list; any other value is kept as given.
    self.if_task_ids = if_task_ids if not isinstance(if_task_ids, str) else [if_task_ids]
    self.else_task_ids = else_task_ids if not isinstance(else_task_ids, str) else [else_task_ids]
+
+ @staticmethod
+ def _cast(value: Any) -> Any:
+ """
+ Cast Jinja-rendered string values to appropriate Python types.
+
+ - Numeric strings become int or float.
+ - ``"true"``/``"false"`` become booleans.
+ - Other strings are returned unchanged.
+ - Non-string types pass through as-is.
+ """
+ if not isinstance(value, str):
+ return value
+ if value == "None":
+ return None
+ try:
+ return int(value)
+ except (ValueError, TypeError):
+ pass
+ try:
+ return float(value)
+ except (ValueError, TypeError):
+ pass
+ if value.lower() == "true":
+ return True
+ if value.lower() == "false":
+ return False
+ return value
+
+ _COMPARISON_OPERATORS: ClassVar[dict[str, Callable[[Any, Any], bool]]] = {
+ "Equals": lambda left, right: left == right,
+ "GreaterThan": lambda left, right: left > right,
+ "GreaterThanOrEqualTo": lambda left, right: left >= right,
+ "LessThan": lambda left, right: left < right,
+ "LessThanOrEqualTo": lambda left, right: left <= right,
+ }
+
+ def _evaluate(self, condition: dict, depth: int = 0) -> bool:
+ """
+ Recursively evaluate a single condition dict.
+
+ :param condition: A condition dictionary with a ``type`` key.
+ :param depth: Current nesting depth (used for log indentation only).
+ :returns: Boolean result of the condition evaluation.
+ """
+ indent = " " * depth
Review Comment:
True. The 200 limit mirrors SageMaker's backend constraint, but since all
comparisons here are done in memory, we can drop the upper limit entirely.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]