bhavya2109sharma commented on code in PR #64545:
URL: https://github.com/apache/airflow/pull/64545#discussion_r3019184488
##########
providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -1991,3 +1997,270 @@ def execute(self, context):
self.hook.conn.get_waiter("notebook_instance_in_service").wait(
NotebookInstanceName=self.instance_name
)
+
+
+class SageMakerConditionOperator(BaseBranchOperator):
+ """
+ Evaluates a single condition or a list of conditions, and routes tasks
based on the result.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:SageMakerConditionOperator`
+
+ :param condition_type: Condition type for the simple (flat) interface.
+ Valid types: ``Equals``, ``GreaterThan``, ``GreaterThanOrEqualTo``,
+ ``LessThan``, ``LessThanOrEqualTo``, ``In``.
+ Mutually exclusive with ``conditions``.
+ :param left: Left operand for the flat interface. For ``In`` conditions
+ this is the value to check membership of.
+ :param right: Right operand for the flat interface. For ``In`` conditions
+ this is the list of allowed values.
+ :param conditions: List of condition dicts to evaluate (AND-ed together).
+ Each dict must have a ``type`` key. Minimum 1, maximum 200.
+ Mutually exclusive with ``condition_type``/``left``/``right``.
+ :param if_task_ids: Task ID(s) to execute when all conditions are True.
+ :param else_task_ids: Task ID(s) to execute when any condition is False.
+ """
+
+ _VALID_FLAT_TYPES: ClassVar[set[str]] = {
+ "Equals",
+ "GreaterThan",
+ "GreaterThanOrEqualTo",
+ "LessThan",
+ "LessThanOrEqualTo",
+ "In",
+ }
+
+ template_fields: Sequence[str] = (
+ "condition_type",
+ "left",
+ "right",
+ "conditions",
+ "if_task_ids",
+ "else_task_ids",
+ )
+
+ def __init__(
+ self,
+ *,
+ condition_type: str | None = None,
+ left: Any = None,
+ right: Any = None,
+ conditions: list[dict] | None = None,
+ if_task_ids: str | list[str],
+ else_task_ids: str | list[str],
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ has_flat = condition_type is not None
+ has_list = conditions is not None
+
+ if has_flat and has_list:
+ raise ValueError(
+ "Cannot use 'condition_type' and 'conditions' together. "
+ "Use 'condition_type' with 'left'/'right' for a single
condition, "
+ "or 'conditions' for multiple/nested conditions."
+ )
+ if not has_flat and not has_list:
+ raise ValueError(
+ "Missing condition: provide 'condition_type' with
'left'/'right' for a single condition, "
+ "or 'conditions' for multiple/nested conditions."
+ )
+
+ if has_flat:
+ if condition_type not in self._VALID_FLAT_TYPES:
+ raise ValueError(
+ f"Unknown condition_type '{condition_type}'. "
+ f"Expected one of: {',
'.join(sorted(self._VALID_FLAT_TYPES))}."
+ )
+ self.condition_type = condition_type
+ self.left = left
+ self.right = right
+ if condition_type == "In":
+ self.conditions = [{"type": "In", "value": left, "in_values":
right}]
+ else:
+ self.conditions = [{"type": condition_type, "left": left,
"right": right}]
+ else:
+ self.condition_type = None
+ self.left = None
+ self.right = None
+ self.conditions = conditions
+
+ if not self.conditions:
+ raise ValueError("At least 1 condition is required, but got an
empty list.")
+ if len(self.conditions) > 200:
+ raise ValueError(f"At most 200 conditions are allowed, but got
{len(self.conditions)}.")
+ self.if_task_ids = [if_task_ids] if isinstance(if_task_ids, str) else
if_task_ids
+ self.else_task_ids = [else_task_ids] if isinstance(else_task_ids, str)
else else_task_ids
+
+ @staticmethod
+ def _cast(value: Any) -> Any:
+ """
+ Cast Jinja-rendered string values to appropriate Python types.
+
+ - Numeric strings become int or float.
+ - ``"true"``/``"false"`` become booleans.
+ - Other strings are returned unchanged.
+ - Non-string types pass through as-is.
+ """
+ if not isinstance(value, str):
+ return value
+ if value == "None":
+ return None
+ try:
+ return int(value)
+ except (ValueError, TypeError):
+ pass
+ try:
+ return float(value)
+ except (ValueError, TypeError):
+ pass
+ if value.lower() == "true":
+ return True
+ if value.lower() == "false":
+ return False
+ return value
+
+ _COMPARISON_OPERATORS: ClassVar[dict[str, Callable[[Any, Any], bool]]] = {
+ "Equals": lambda left, right: left == right,
+ "GreaterThan": lambda left, right: left > right,
+ "GreaterThanOrEqualTo": lambda left, right: left >= right,
+ "LessThan": lambda left, right: left < right,
+ "LessThanOrEqualTo": lambda left, right: left <= right,
+ }
+
+ def _evaluate(self, condition: dict, depth: int = 0) -> bool:
+ """
+ Recursively evaluate a single condition dict.
+
+ :param condition: A condition dictionary with a ``type`` key.
+ :param depth: Current nesting depth (used for log indentation only).
+ :returns: Boolean result of the condition evaluation.
+ """
+ indent = " " * depth
+ try:
+ cond_type = condition["type"]
+ except KeyError:
+ raise ValueError("Condition dict is missing required key 'type'.")
+
+ if cond_type in self._COMPARISON_OPERATORS:
+ try:
+ left = self._cast(condition["left"])
+ right = self._cast(condition["right"])
+ except KeyError as e:
+ raise ValueError(f"Condition '{cond_type}' missing required
key {e}.") from None
+
+ # None check — likely an XCom that was not pushed
+ if left is None or right is None:
+ raise TypeError(
+ f"Condition '{cond_type}' received None: left={left!r},
right={right!r}. "
+ "This usually means the upstream task did not run or did
not push a value to XCom."
+ )
+
+ # Type compatibility check
+ left_type = type(left)
+ right_type = type(right)
+ numeric_types = (int, float)
+ left_is_numeric = isinstance(left, numeric_types) and not
isinstance(left, bool)
+ right_is_numeric = isinstance(right, numeric_types) and not
isinstance(right, bool)
+
+ if left_is_numeric and right_is_numeric:
+ pass
Review Comment:
no.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]