o-nikolas commented on code in PR #64545:
URL: https://github.com/apache/airflow/pull/64545#discussion_r3018737104
##########
providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -1991,3 +1997,270 @@ def execute(self, context):
self.hook.conn.get_waiter("notebook_instance_in_service").wait(
NotebookInstanceName=self.instance_name
)
+
+
+class SageMakerConditionOperator(BaseBranchOperator):
+ """
+ Evaluates a single condition or a list of conditions, and routes tasks
based on the result.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:SageMakerConditionOperator`
+
+ :param condition_type: Condition type for the simple (flat) interface.
+ Valid types: ``Equals``, ``GreaterThan``, ``GreaterThanOrEqualTo``,
+ ``LessThan``, ``LessThanOrEqualTo``, ``In``.
+ Mutually exclusive with ``conditions``.
+ :param left: Left operand for the flat interface. For ``In`` conditions
+ this is the value to check membership of.
+ :param right: Right operand for the flat interface. For ``In`` conditions
+ this is the list of allowed values.
+ :param conditions: List of condition dicts to evaluate (AND-ed together).
+ Each dict must have a ``type`` key. Minimum 1, maximum 200.
+ Mutually exclusive with ``condition_type``/``left``/``right``.
+ :param if_task_ids: Task ID(s) to execute when all conditions are True.
+ :param else_task_ids: Task ID(s) to execute when any condition is False.
+ """
+
+ _VALID_FLAT_TYPES: ClassVar[set[str]] = {
+ "Equals",
+ "GreaterThan",
+ "GreaterThanOrEqualTo",
+ "LessThan",
+ "LessThanOrEqualTo",
+ "In",
+ }
+
+ template_fields: Sequence[str] = (
+ "condition_type",
+ "left",
+ "right",
+ "conditions",
+ "if_task_ids",
+ "else_task_ids",
+ )
+
+ def __init__(
+ self,
+ *,
+ condition_type: str | None = None,
+ left: Any = None,
+ right: Any = None,
+ conditions: list[dict] | None = None,
+ if_task_ids: str | list[str],
+ else_task_ids: str | list[str],
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ has_flat = condition_type is not None
+ has_list = conditions is not None
+
+ if has_flat and has_list:
+ raise ValueError(
+ "Cannot use 'condition_type' and 'conditions' together. "
+ "Use 'condition_type' with 'left'/'right' for a single
condition, "
+ "or 'conditions' for multiple/nested conditions."
+ )
+ if not has_flat and not has_list:
+ raise ValueError(
+ "Missing condition: provide 'condition_type' with
'left'/'right' for a single condition, "
+ "or 'conditions' for multiple/nested conditions."
+ )
+
+ if has_flat:
+ if condition_type not in self._VALID_FLAT_TYPES:
+ raise ValueError(
+ f"Unknown condition_type '{condition_type}'. "
+ f"Expected one of: {',
'.join(sorted(self._VALID_FLAT_TYPES))}."
+ )
+ self.condition_type = condition_type
+ self.left = left
+ self.right = right
+ if condition_type == "In":
+ self.conditions = [{"type": "In", "value": left, "in_values":
right}]
+ else:
+ self.conditions = [{"type": condition_type, "left": left,
"right": right}]
+ else:
+ self.condition_type = None
+ self.left = None
+ self.right = None
+ self.conditions = conditions
+
+ if not self.conditions:
+ raise ValueError("At least 1 condition is required, but got an
empty list.")
+ if len(self.conditions) > 200:
+ raise ValueError(f"At most 200 conditions are allowed, but got
{len(self.conditions)}.")
+ self.if_task_ids = [if_task_ids] if isinstance(if_task_ids, str) else
if_task_ids
+ self.else_task_ids = [else_task_ids] if isinstance(else_task_ids, str)
else else_task_ids
+
+ @staticmethod
+ def _cast(value: Any) -> Any:
Review Comment:
You can provide `render_template_as_native_obj=True` to your Dag to have
Airflow do the casting for you (and soon it will be possible to set this at
the task level specifically). Let's remove this so we're not re-inventing the wheel.
##########
providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -1991,3 +1997,270 @@ def execute(self, context):
self.hook.conn.get_waiter("notebook_instance_in_service").wait(
NotebookInstanceName=self.instance_name
)
+
+
+class SageMakerConditionOperator(BaseBranchOperator):
+ """
+ Evaluates a single condition or a list of conditions, and routes tasks
based on the result.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:SageMakerConditionOperator`
+
+ :param condition_type: Condition type for the simple (flat) interface.
+ Valid types: ``Equals``, ``GreaterThan``, ``GreaterThanOrEqualTo``,
+ ``LessThan``, ``LessThanOrEqualTo``, ``In``.
+ Mutually exclusive with ``conditions``.
+ :param left: Left operand for the flat interface. For ``In`` conditions
+ this is the value to check membership of.
+ :param right: Right operand for the flat interface. For ``In`` conditions
+ this is the list of allowed values.
+ :param conditions: List of condition dicts to evaluate (AND-ed together).
+ Each dict must have a ``type`` key. Minimum 1, maximum 200.
+ Mutually exclusive with ``condition_type``/``left``/``right``.
+ :param if_task_ids: Task ID(s) to execute when all conditions are True.
+ :param else_task_ids: Task ID(s) to execute when any condition is False.
+ """
+
+ _VALID_FLAT_TYPES: ClassVar[set[str]] = {
+ "Equals",
+ "GreaterThan",
+ "GreaterThanOrEqualTo",
+ "LessThan",
+ "LessThanOrEqualTo",
+ "In",
+ }
+
+ template_fields: Sequence[str] = (
+ "condition_type",
+ "left",
+ "right",
+ "conditions",
+ "if_task_ids",
+ "else_task_ids",
+ )
+
+ def __init__(
+ self,
+ *,
+ condition_type: str | None = None,
+ left: Any = None,
+ right: Any = None,
+ conditions: list[dict] | None = None,
+ if_task_ids: str | list[str],
+ else_task_ids: str | list[str],
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ has_flat = condition_type is not None
+ has_list = conditions is not None
+
+ if has_flat and has_list:
+ raise ValueError(
+ "Cannot use 'condition_type' and 'conditions' together. "
+ "Use 'condition_type' with 'left'/'right' for a single
condition, "
+ "or 'conditions' for multiple/nested conditions."
+ )
+ if not has_flat and not has_list:
+ raise ValueError(
+ "Missing condition: provide 'condition_type' with
'left'/'right' for a single condition, "
+ "or 'conditions' for multiple/nested conditions."
+ )
+
+ if has_flat:
+ if condition_type not in self._VALID_FLAT_TYPES:
+ raise ValueError(
+ f"Unknown condition_type '{condition_type}'. "
+ f"Expected one of: {',
'.join(sorted(self._VALID_FLAT_TYPES))}."
+ )
+ self.condition_type = condition_type
+ self.left = left
+ self.right = right
+ if condition_type == "In":
+ self.conditions = [{"type": "In", "value": left, "in_values":
right}]
+ else:
+ self.conditions = [{"type": condition_type, "left": left,
"right": right}]
+ else:
+ self.condition_type = None
+ self.left = None
+ self.right = None
+ self.conditions = conditions
+
+ if not self.conditions:
+ raise ValueError("At least 1 condition is required, but got an
empty list.")
+ if len(self.conditions) > 200:
+ raise ValueError(f"At most 200 conditions are allowed, but got
{len(self.conditions)}.")
+ self.if_task_ids = [if_task_ids] if isinstance(if_task_ids, str) else
if_task_ids
+ self.else_task_ids = [else_task_ids] if isinstance(else_task_ids, str)
else else_task_ids
+
+ @staticmethod
+ def _cast(value: Any) -> Any:
+ """
+ Cast Jinja-rendered string values to appropriate Python types.
+
+ - Numeric strings become int or float.
+ - ``"true"``/``"false"`` become booleans.
+ - Other strings are returned unchanged.
+ - Non-string types pass through as-is.
+ """
+ if not isinstance(value, str):
+ return value
+ if value == "None":
+ return None
+ try:
+ return int(value)
+ except (ValueError, TypeError):
+ pass
+ try:
+ return float(value)
+ except (ValueError, TypeError):
+ pass
+ if value.lower() == "true":
+ return True
+ if value.lower() == "false":
+ return False
+ return value
+
+ _COMPARISON_OPERATORS: ClassVar[dict[str, Callable[[Any, Any], bool]]] = {
+ "Equals": lambda left, right: left == right,
+ "GreaterThan": lambda left, right: left > right,
+ "GreaterThanOrEqualTo": lambda left, right: left >= right,
+ "LessThan": lambda left, right: left < right,
+ "LessThanOrEqualTo": lambda left, right: left <= right,
+ }
+
+ def _evaluate(self, condition: dict, depth: int = 0) -> bool:
+ """
+ Recursively evaluate a single condition dict.
+
+ :param condition: A condition dictionary with a ``type`` key.
+ :param depth: Current nesting depth (used for log indentation only).
+ :returns: Boolean result of the condition evaluation.
+ """
+ indent = " " * depth
+ try:
+ cond_type = condition["type"]
+ except KeyError:
+ raise ValueError("Condition dict is missing required key 'type'.")
+
+ if cond_type in self._COMPARISON_OPERATORS:
+ try:
+ left = self._cast(condition["left"])
+ right = self._cast(condition["right"])
+ except KeyError as e:
+ raise ValueError(f"Condition '{cond_type}' missing required
key {e}.") from None
+
+ # None check — likely an XCom that was not pushed
+ if left is None or right is None:
+ raise TypeError(
+ f"Condition '{cond_type}' received None: left={left!r},
right={right!r}. "
+ "This usually means the upstream task did not run or did
not push a value to XCom."
+ )
+
+ # Type compatibility check
+ left_type = type(left)
+ right_type = type(right)
+ numeric_types = (int, float)
+ left_is_numeric = isinstance(left, numeric_types) and not
isinstance(left, bool)
+ right_is_numeric = isinstance(right, numeric_types) and not
isinstance(right, bool)
+
+ if left_is_numeric and right_is_numeric:
+ pass
+ elif left_type is not right_type:
+ raise TypeError(
+ f"Cannot compare {left_type.__name__} ({left!r}) with
{right_type.__name__} ({right!r}) "
+ f"in condition '{cond_type}'. Both values must be the same
type."
+ )
+
+ result = self._COMPARISON_OPERATORS[cond_type](left, right)
+ self.log.info("%s%s: %r %s %r -> %s", indent, cond_type, left,
cond_type, right, result)
+ return result
+
+ if cond_type == "In":
+ try:
+ value = self._cast(condition["value"])
+ in_values = [self._cast(v) for v in condition["in_values"]]
+ except KeyError as e:
+ raise ValueError(f"Condition '{cond_type}' missing required
key {e}.") from None
+ result = value in in_values
+ self.log.info("%sIn: %r in %r -> %s", indent, value, in_values,
result)
+ return result
+
+ if cond_type == "Not":
+ try:
+ inner = condition["condition"]
+ except KeyError as e:
+ raise ValueError(f"Condition '{cond_type}' missing required
key {e}.") from None
+ inner_result = self._evaluate(inner, depth + 1)
+ result = not inner_result
+ self.log.info("%sNot: not %s -> %s", indent, inner_result, result)
+ return result
+
+ if cond_type == "Or":
+ try:
+ inner_conditions = condition["conditions"]
+ except KeyError as e:
+ raise ValueError(f"Condition '{cond_type}' missing required
key {e}.") from None
+ inner_results = [self._evaluate(c, depth + 1) for c in
inner_conditions]
+ result = any(inner_results)
+ self.log.info("%sOr: any(%r) -> %s", indent, inner_results, result)
+ return result
+
+ raise ValueError(f"Unknown condition type '{cond_type}'.")
+
+ def choose_branch(self, context: Context) -> list[str]:
+ """
+ Evaluate all conditions and return the appropriate branch task IDs.
+
+ :param context: Airflow context dictionary.
+ :returns: ``if_task_ids`` when all conditions are True,
``else_task_ids`` otherwise.
+ """
+ num_conditions = len(self.conditions)
+ self.log.info("Evaluating %d condition(s).", num_conditions)
+
+ results = [self._evaluate(c) for c in self.conditions]
+ all_true = all(results)
+
+ if all_true:
+ self.log.info(
+ "All %d condition(s) evaluated to True. Routing to
if_task_ids=%r.",
+ num_conditions,
+ self.if_task_ids,
+ )
+ return self.if_task_ids
+ self.log.info(
+ "Not all conditions are True (results=%r). Routing to
else_task_ids=%r.",
+ results,
+ self.else_task_ids,
+ )
+ return self.else_task_ids
+
+
+class SageMakerFailOperator(BaseOperator):
+ """
+ Terminate the workflow with a user-defined error message.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:SageMakerFailOperator`
+
+ :param error_message: Failure reason passed to ``AirflowFailException``.
+ Defaults to ``""``.
+ """
+
+ template_fields: Sequence[str] = ("error_message",)
+
+ def __init__(self, *, error_message: str = "", **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.error_message = error_message
+
+ _ERROR_MESSAGE_MAX_LEN: ClassVar[int] = 3072
Review Comment:
Why are we enforcing a maximum length for the error message?
##########
providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -1991,3 +1997,270 @@ def execute(self, context):
self.hook.conn.get_waiter("notebook_instance_in_service").wait(
NotebookInstanceName=self.instance_name
)
+
+
+class SageMakerConditionOperator(BaseBranchOperator):
+ """
+ Evaluates a single condition or a list of conditions, and routes tasks
based on the result.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:SageMakerConditionOperator`
+
+ :param condition_type: Condition type for the simple (flat) interface.
+ Valid types: ``Equals``, ``GreaterThan``, ``GreaterThanOrEqualTo``,
+ ``LessThan``, ``LessThanOrEqualTo``, ``In``.
+ Mutually exclusive with ``conditions``.
+ :param left: Left operand for the flat interface. For ``In`` conditions
+ this is the value to check membership of.
+ :param right: Right operand for the flat interface. For ``In`` conditions
+ this is the list of allowed values.
+ :param conditions: List of condition dicts to evaluate (AND-ed together).
+ Each dict must have a ``type`` key. Minimum 1, maximum 200.
+ Mutually exclusive with ``condition_type``/``left``/``right``.
+ :param if_task_ids: Task ID(s) to execute when all conditions are True.
+ :param else_task_ids: Task ID(s) to execute when any condition is False.
+ """
+
+ _VALID_FLAT_TYPES: ClassVar[set[str]] = {
+ "Equals",
+ "GreaterThan",
+ "GreaterThanOrEqualTo",
+ "LessThan",
+ "LessThanOrEqualTo",
+ "In",
+ }
+
+ template_fields: Sequence[str] = (
+ "condition_type",
+ "left",
+ "right",
+ "conditions",
+ "if_task_ids",
+ "else_task_ids",
+ )
+
+ def __init__(
+ self,
+ *,
+ condition_type: str | None = None,
+ left: Any = None,
+ right: Any = None,
+ conditions: list[dict] | None = None,
+ if_task_ids: str | list[str],
+ else_task_ids: str | list[str],
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ has_flat = condition_type is not None
+ has_list = conditions is not None
+
+ if has_flat and has_list:
+ raise ValueError(
+ "Cannot use 'condition_type' and 'conditions' together. "
+ "Use 'condition_type' with 'left'/'right' for a single
condition, "
+ "or 'conditions' for multiple/nested conditions."
+ )
+ if not has_flat and not has_list:
+ raise ValueError(
+ "Missing condition: provide 'condition_type' with
'left'/'right' for a single condition, "
+ "or 'conditions' for multiple/nested conditions."
+ )
+
+ if has_flat:
+ if condition_type not in self._VALID_FLAT_TYPES:
+ raise ValueError(
+ f"Unknown condition_type '{condition_type}'. "
+ f"Expected one of: {',
'.join(sorted(self._VALID_FLAT_TYPES))}."
+ )
+ self.condition_type = condition_type
+ self.left = left
+ self.right = right
+ if condition_type == "In":
+ self.conditions = [{"type": "In", "value": left, "in_values":
right}]
+ else:
+ self.conditions = [{"type": condition_type, "left": left,
"right": right}]
+ else:
+ self.condition_type = None
+ self.left = None
+ self.right = None
+ self.conditions = conditions
+
+ if not self.conditions:
+ raise ValueError("At least 1 condition is required, but got an
empty list.")
+ if len(self.conditions) > 200:
+ raise ValueError(f"At most 200 conditions are allowed, but got
{len(self.conditions)}.")
+ self.if_task_ids = [if_task_ids] if isinstance(if_task_ids, str) else
if_task_ids
+ self.else_task_ids = [else_task_ids] if isinstance(else_task_ids, str)
else else_task_ids
+
+ @staticmethod
+ def _cast(value: Any) -> Any:
+ """
+ Cast Jinja-rendered string values to appropriate Python types.
+
+ - Numeric strings become int or float.
+ - ``"true"``/``"false"`` become booleans.
+ - Other strings are returned unchanged.
+ - Non-string types pass through as-is.
+ """
+ if not isinstance(value, str):
+ return value
+ if value == "None":
+ return None
+ try:
+ return int(value)
+ except (ValueError, TypeError):
+ pass
+ try:
+ return float(value)
+ except (ValueError, TypeError):
+ pass
+ if value.lower() == "true":
+ return True
+ if value.lower() == "false":
+ return False
+ return value
+
+ _COMPARISON_OPERATORS: ClassVar[dict[str, Callable[[Any, Any], bool]]] = {
+ "Equals": lambda left, right: left == right,
+ "GreaterThan": lambda left, right: left > right,
+ "GreaterThanOrEqualTo": lambda left, right: left >= right,
+ "LessThan": lambda left, right: left < right,
+ "LessThanOrEqualTo": lambda left, right: left <= right,
+ }
+
+ def _evaluate(self, condition: dict, depth: int = 0) -> bool:
+ """
+ Recursively evaluate a single condition dict.
+
+ :param condition: A condition dictionary with a ``type`` key.
+ :param depth: Current nesting depth (used for log indentation only).
+ :returns: Boolean result of the condition evaluation.
+ """
+ indent = " " * depth
Review Comment:
If we allow up to 200 conditions this will soon become completely
unreadable, no?
##########
providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_fail.py:
##########
Review Comment:
I think there are too many tests in here; the same branches are being tested
over and over by multiple tests. It is very common for AI to write unit tests like this.
But in OSS, every line of code is a line that needs to be maintained and we
don't have a full roster of paid employees to do that (we have some, of
course). So if one or two of the more complicated tests in here are testing the
behaviour of the simpler ones, please remove the simpler ones entirely.
##########
providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -1991,3 +1997,270 @@ def execute(self, context):
self.hook.conn.get_waiter("notebook_instance_in_service").wait(
NotebookInstanceName=self.instance_name
)
+
+
+class SageMakerConditionOperator(BaseBranchOperator):
+ """
+ Evaluates a single condition or a list of conditions, and routes tasks
based on the result.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:SageMakerConditionOperator`
+
+ :param condition_type: Condition type for the simple (flat) interface.
+ Valid types: ``Equals``, ``GreaterThan``, ``GreaterThanOrEqualTo``,
+ ``LessThan``, ``LessThanOrEqualTo``, ``In``.
+ Mutually exclusive with ``conditions``.
+ :param left: Left operand for the flat interface. For ``In`` conditions
+ this is the value to check membership of.
+ :param right: Right operand for the flat interface. For ``In`` conditions
+ this is the list of allowed values.
+ :param conditions: List of condition dicts to evaluate (AND-ed together).
+ Each dict must have a ``type`` key. Minimum 1, maximum 200.
+ Mutually exclusive with ``condition_type``/``left``/``right``.
+ :param if_task_ids: Task ID(s) to execute when all conditions are True.
+ :param else_task_ids: Task ID(s) to execute when any condition is False.
+ """
+
+ _VALID_FLAT_TYPES: ClassVar[set[str]] = {
+ "Equals",
+ "GreaterThan",
+ "GreaterThanOrEqualTo",
+ "LessThan",
+ "LessThanOrEqualTo",
+ "In",
+ }
+
+ template_fields: Sequence[str] = (
+ "condition_type",
+ "left",
+ "right",
+ "conditions",
+ "if_task_ids",
+ "else_task_ids",
+ )
+
+ def __init__(
+ self,
+ *,
+ condition_type: str | None = None,
+ left: Any = None,
+ right: Any = None,
+ conditions: list[dict] | None = None,
+ if_task_ids: str | list[str],
+ else_task_ids: str | list[str],
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ has_flat = condition_type is not None
+ has_list = conditions is not None
+
+ if has_flat and has_list:
+ raise ValueError(
+ "Cannot use 'condition_type' and 'conditions' together. "
+ "Use 'condition_type' with 'left'/'right' for a single
condition, "
+ "or 'conditions' for multiple/nested conditions."
+ )
+ if not has_flat and not has_list:
+ raise ValueError(
+ "Missing condition: provide 'condition_type' with
'left'/'right' for a single condition, "
+ "or 'conditions' for multiple/nested conditions."
+ )
+
+ if has_flat:
+ if condition_type not in self._VALID_FLAT_TYPES:
+ raise ValueError(
+ f"Unknown condition_type '{condition_type}'. "
+ f"Expected one of: {',
'.join(sorted(self._VALID_FLAT_TYPES))}."
+ )
+ self.condition_type = condition_type
+ self.left = left
+ self.right = right
+ if condition_type == "In":
+ self.conditions = [{"type": "In", "value": left, "in_values":
right}]
+ else:
+ self.conditions = [{"type": condition_type, "left": left,
"right": right}]
+ else:
+ self.condition_type = None
+ self.left = None
+ self.right = None
+ self.conditions = conditions
+
+ if not self.conditions:
+ raise ValueError("At least 1 condition is required, but got an
empty list.")
+ if len(self.conditions) > 200:
+ raise ValueError(f"At most 200 conditions are allowed, but got
{len(self.conditions)}.")
+ self.if_task_ids = [if_task_ids] if isinstance(if_task_ids, str) else
if_task_ids
+ self.else_task_ids = [else_task_ids] if isinstance(else_task_ids, str)
else else_task_ids
+
+ @staticmethod
+ def _cast(value: Any) -> Any:
+ """
+ Cast Jinja-rendered string values to appropriate Python types.
+
+ - Numeric strings become int or float.
+ - ``"true"``/``"false"`` become booleans.
+ - Other strings are returned unchanged.
+ - Non-string types pass through as-is.
+ """
+ if not isinstance(value, str):
+ return value
+ if value == "None":
+ return None
+ try:
+ return int(value)
+ except (ValueError, TypeError):
+ pass
+ try:
+ return float(value)
+ except (ValueError, TypeError):
+ pass
+ if value.lower() == "true":
+ return True
+ if value.lower() == "false":
+ return False
+ return value
+
+ _COMPARISON_OPERATORS: ClassVar[dict[str, Callable[[Any, Any], bool]]] = {
+ "Equals": lambda left, right: left == right,
+ "GreaterThan": lambda left, right: left > right,
+ "GreaterThanOrEqualTo": lambda left, right: left >= right,
+ "LessThan": lambda left, right: left < right,
+ "LessThanOrEqualTo": lambda left, right: left <= right,
+ }
+
+ def _evaluate(self, condition: dict, depth: int = 0) -> bool:
+ """
+ Recursively evaluate a single condition dict.
+
+ :param condition: A condition dictionary with a ``type`` key.
+ :param depth: Current nesting depth (used for log indentation only).
+ :returns: Boolean result of the condition evaluation.
+ """
+ indent = " " * depth
+ try:
+ cond_type = condition["type"]
+ except KeyError:
+ raise ValueError("Condition dict is missing required key 'type'.")
+
+ if cond_type in self._COMPARISON_OPERATORS:
+ try:
+ left = self._cast(condition["left"])
+ right = self._cast(condition["right"])
+ except KeyError as e:
+ raise ValueError(f"Condition '{cond_type}' missing required
key {e}.") from None
+
+ # None check — likely an XCom that was not pushed
+ if left is None or right is None:
+ raise TypeError(
+ f"Condition '{cond_type}' received None: left={left!r},
right={right!r}. "
+ "This usually means the upstream task did not run or did
not push a value to XCom."
+ )
+
+ # Type compatibility check
+ left_type = type(left)
+ right_type = type(right)
+ numeric_types = (int, float)
+ left_is_numeric = isinstance(left, numeric_types) and not
isinstance(left, bool)
+ right_is_numeric = isinstance(right, numeric_types) and not
isinstance(right, bool)
+
+ if left_is_numeric and right_is_numeric:
+ pass
+ elif left_type is not right_type:
+ raise TypeError(
+ f"Cannot compare {left_type.__name__} ({left!r}) with
{right_type.__name__} ({right!r}) "
+ f"in condition '{cond_type}'. Both values must be the same
type."
+ )
+
+ result = self._COMPARISON_OPERATORS[cond_type](left, right)
+ self.log.info("%s%s: %r %s %r -> %s", indent, cond_type, left,
cond_type, right, result)
+ return result
+
+ if cond_type == "In":
+ try:
+ value = self._cast(condition["value"])
+ in_values = [self._cast(v) for v in condition["in_values"]]
+ except KeyError as e:
+ raise ValueError(f"Condition '{cond_type}' missing required
key {e}.") from None
+ result = value in in_values
+ self.log.info("%sIn: %r in %r -> %s", indent, value, in_values,
result)
+ return result
+
+ if cond_type == "Not":
+ try:
+ inner = condition["condition"]
+ except KeyError as e:
+ raise ValueError(f"Condition '{cond_type}' missing required
key {e}.") from None
+ inner_result = self._evaluate(inner, depth + 1)
+ result = not inner_result
+ self.log.info("%sNot: not %s -> %s", indent, inner_result, result)
+ return result
+
+ if cond_type == "Or":
+ try:
+ inner_conditions = condition["conditions"]
+ except KeyError as e:
+ raise ValueError(f"Condition '{cond_type}' missing required
key {e}.") from None
+ inner_results = [self._evaluate(c, depth + 1) for c in
inner_conditions]
+ result = any(inner_results)
+ self.log.info("%sOr: any(%r) -> %s", indent, inner_results, result)
+ return result
+
+ raise ValueError(f"Unknown condition type '{cond_type}'.")
+
+ def choose_branch(self, context: Context) -> list[str]:
+ """
+ Evaluate all conditions and return the appropriate branch task IDs.
+
+ :param context: Airflow context dictionary.
+ :returns: ``if_task_ids`` when all conditions are True,
``else_task_ids`` otherwise.
+ """
+ num_conditions = len(self.conditions)
+ self.log.info("Evaluating %d condition(s).", num_conditions)
+
+ results = [self._evaluate(c) for c in self.conditions]
Review Comment:
Small nit, but descriptive variable names are preferred. It's not too many
more keystrokes and it makes the code much nicer to read. This applies here
and elsewhere throughout this PR.
##########
providers/amazon/tests/system/amazon/aws/example_sagemaker_condition.py:
##########
@@ -0,0 +1,202 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+System test for SageMakerConditionOperator.
+
+This operator evaluates conditions against XCom values using pure Python
+(no AWS API calls), so this test only needs Airflow — no SageMaker resources.
+
+The DAG simulates an ML accuracy-gate workflow:
+
+1. ``produce_metrics`` pushes a dict of metrics to XCom.
+2. ``check_accuracy`` uses SageMakerConditionOperator to branch:
+ - accuracy >= 0.9 AND loss < 0.1 -> ``deploy_model``
+ - otherwise -> ``retrain_model``
+3. Only the correct branch task runs; the other is skipped.
+"""
+
+from __future__ import annotations
+
+from datetime import datetime
+
+from airflow.providers.amazon.aws.operators.sagemaker import
SageMakerConditionOperator
+
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
+
+if AIRFLOW_V_3_0_PLUS:
+ from airflow.sdk import DAG, chain, task
+else:
+ from airflow.decorators import task # type: ignore[attr-defined,no-redef]
+ from airflow.models.baseoperator import chain # type:
ignore[attr-defined,no-redef]
+ from airflow.models.dag import DAG # type:
ignore[attr-defined,no-redef,assignment]
+
+from system.amazon.aws.utils import SystemTestContextBuilder
+
+DAG_ID = "example_sagemaker_condition"
+
+sys_test_context_task = SystemTestContextBuilder().build()
+
+
+with DAG(
+ DAG_ID,
+ schedule="@once",
+ start_date=datetime(2021, 1, 1),
+ catchup=False,
+) as dag:
+ test_context = sys_test_context_task()
+
+ # TEST SETUP: push simulated ML metrics to XCom
+
+ @task
+ def produce_metrics():
+ """Simulate an ML training job that returns accuracy and loss
metrics."""
+ return {"accuracy": 0.95, "loss": 0.04}
+
+ metrics = produce_metrics()
+
+ # TEST BODY
+
+ # [START howto_operator_sagemaker_condition]
+ check_accuracy = SageMakerConditionOperator(
+ task_id="check_accuracy",
+ conditions=[
+ {
+ "type": "GreaterThanOrEqualTo",
+ "left": "{{
ti.xcom_pull(task_ids='produce_metrics')['accuracy'] }}",
+ "right": 0.9,
+ },
+ {
+ "type": "LessThan",
+ "left": "{{ ti.xcom_pull(task_ids='produce_metrics')['loss']
}}",
+ "right": 0.1,
+ },
+ ],
+ if_task_ids=["deploy_model"],
+ else_task_ids=["retrain_model"],
+ )
+ # [END howto_operator_sagemaker_condition]
+
+ @task
+ def deploy_model():
+ """Placeholder: model meets quality bar, proceed to deployment."""
+ return "deployed"
+
+ @task
+ def retrain_model():
+ """Placeholder: model does not meet quality bar, retrain."""
+ return "retrained"
+
+ # Scenario 2: condition evaluates to False -> else branch
+
+ @task
+ def produce_bad_metrics():
+ """Simulate a training job with poor accuracy."""
+ return {"accuracy": 0.5, "loss": 0.8}
+
+ bad_metrics = produce_bad_metrics()
+
+ # [START howto_operator_sagemaker_condition_flat]
+ check_bad_accuracy = SageMakerConditionOperator(
+ task_id="check_bad_accuracy",
+ condition_type="GreaterThanOrEqualTo",
+ left="{{ ti.xcom_pull(task_ids='produce_bad_metrics')['accuracy'] }}",
+ right=0.9,
+ if_task_ids=["should_not_run"],
+ else_task_ids=["should_run"],
+ )
+ # [END howto_operator_sagemaker_condition_flat]
+
+ @task
+ def should_not_run():
+ """This task should be skipped because accuracy < 0.9."""
+ return "error: should not have run"
+
+ @task
+ def should_run():
+ """This task should execute because accuracy < 0.9 -> else branch."""
+ return "correctly routed to else branch"
+
+ # Scenario 3: Or condition + Not condition
+
+ # [START howto_operator_sagemaker_condition_not_or]
+ check_logical = SageMakerConditionOperator(
+ task_id="check_logical",
+ conditions=[
+ {
+ "type": "Or",
+ "conditions": [
+ {"type": "Equals", "left": 1, "right": 2},
+ {"type": "Equals", "left": 3, "right": 3},
+ ],
+ },
+ {
+ "type": "Not",
+ "condition": {"type": "Equals", "left": "a", "right": "b"},
+ },
+ ],
+ if_task_ids=["logical_pass"],
+ else_task_ids=["logical_fail"],
+ )
+ # [END howto_operator_sagemaker_condition_not_or]
+
+ @task
+ def logical_pass():
+ """Or(1==2, 3==3) -> True AND Not(a==b) -> True -> if branch."""
+ return "logical conditions passed"
+
+ @task
+ def logical_fail():
+ return "error: logical conditions should have passed"
+
+ # Wire up dependencies
+
+ chain(
+ # TEST SETUP
+ test_context,
Review Comment:
I believe you can run test_context just once. That way you only get a single
test env_id.
##########
providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_fail.py:
##########
@@ -0,0 +1,94 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Unit tests for SageMakerFailOperator.
+
+No system test exists for this operator. It has zero external dependencies
+(no AWS calls) — it validates a string and raises AirflowFailException.
+A system test would require the DAG to fail, which the test framework
+treats as a test failure. Unit tests provide full coverage.
+"""
+
+from __future__ import annotations
+
+from datetime import datetime
+from unittest.mock import MagicMock
+
+import pytest
+
+from airflow.providers.amazon.aws.operators.sagemaker import
SageMakerFailOperator
+from airflow.providers.common.compat.sdk import AirflowFailException
+
+from unit.amazon.aws.utils.test_template_fields import validate_template_fields
+
+
+def test_template_fields():
+ op = SageMakerFailOperator(task_id="test")
+ validate_template_fields(op)
+
+
+def test_default_error_message():
+ op = SageMakerFailOperator(task_id="test")
+ assert op.error_message == ""
Review Comment:
Do we want a better default than an empty string?
##########
providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_fail.py:
##########
@@ -0,0 +1,94 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Unit tests for SageMakerFailOperator.
+
+No system test exists for this operator. It has zero external dependencies
+(no AWS calls) — it validates a string and raises AirflowFailException.
+A system test would require the DAG to fail, which the test framework
+treats as a test failure. Unit tests provide full coverage.
+"""
+
+from __future__ import annotations
+
+from datetime import datetime
+from unittest.mock import MagicMock
+
+import pytest
+
+from airflow.providers.amazon.aws.operators.sagemaker import
SageMakerFailOperator
+from airflow.providers.common.compat.sdk import AirflowFailException
+
+from unit.amazon.aws.utils.test_template_fields import validate_template_fields
+
+
+def test_template_fields():
+ op = SageMakerFailOperator(task_id="test")
+ validate_template_fields(op)
+
+
+def test_default_error_message():
+ op = SageMakerFailOperator(task_id="test")
+ assert op.error_message == ""
+
+
+def test_execute_raises_airflow_fail_exception():
+ op = SageMakerFailOperator(task_id="test", error_message="Model accuracy
below threshold")
+ with pytest.raises(AirflowFailException, match="Model accuracy below
threshold"):
+ op.execute(context=MagicMock(spec=dict))
+
+
+def test_execute_logs_before_raising(caplog):
Review Comment:
Here and elsewhere, please don't use `caplog`. There was some historical drama
in the Airflow community with it. It's very likely fine in this case, but if
you could instead just mock the `log` attribute, that would be more in line
with the community's conventions.
##########
providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -1991,3 +1997,270 @@ def execute(self, context):
self.hook.conn.get_waiter("notebook_instance_in_service").wait(
NotebookInstanceName=self.instance_name
)
+
+
+class SageMakerConditionOperator(BaseBranchOperator):
+ """
+ Evaluates a single condition or a list of conditions, and routes tasks
based on the result.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:SageMakerConditionOperator`
+
+ :param condition_type: Condition type for the simple (flat) interface.
+ Valid types: ``Equals``, ``GreaterThan``, ``GreaterThanOrEqualTo``,
+ ``LessThan``, ``LessThanOrEqualTo``, ``In``.
+ Mutually exclusive with ``conditions``.
+ :param left: Left operand for the flat interface. For ``In`` conditions
+ this is the value to check membership of.
+ :param right: Right operand for the flat interface. For ``In`` conditions
+ this is the list of allowed values.
+ :param conditions: List of condition dicts to evaluate (AND-ed together).
+ Each dict must have a ``type`` key. Minimum 1, maximum 200.
+ Mutually exclusive with ``condition_type``/``left``/``right``.
+ :param if_task_ids: Task ID(s) to execute when all conditions are True.
+ :param else_task_ids: Task ID(s) to execute when any condition is False.
+ """
+
+ _VALID_FLAT_TYPES: ClassVar[set[str]] = {
+ "Equals",
+ "GreaterThan",
+ "GreaterThanOrEqualTo",
+ "LessThan",
+ "LessThanOrEqualTo",
+ "In",
+ }
+
+ template_fields: Sequence[str] = (
+ "condition_type",
+ "left",
+ "right",
+ "conditions",
+ "if_task_ids",
+ "else_task_ids",
+ )
+
+ def __init__(
+ self,
+ *,
+ condition_type: str | None = None,
+ left: Any = None,
+ right: Any = None,
+ conditions: list[dict] | None = None,
+ if_task_ids: str | list[str],
+ else_task_ids: str | list[str],
Review Comment:
It's not immediately obvious to me what each of these is without looking
up the docs. Are these names the convention in SageMaker Pipelines?
##########
providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py:
##########
@@ -1991,3 +1997,270 @@ def execute(self, context):
self.hook.conn.get_waiter("notebook_instance_in_service").wait(
NotebookInstanceName=self.instance_name
)
+
+
+class SageMakerConditionOperator(BaseBranchOperator):
+ """
+ Evaluates a single condition or a list of conditions, and routes tasks
based on the result.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:SageMakerConditionOperator`
+
+ :param condition_type: Condition type for the simple (flat) interface.
+ Valid types: ``Equals``, ``GreaterThan``, ``GreaterThanOrEqualTo``,
+ ``LessThan``, ``LessThanOrEqualTo``, ``In``.
+ Mutually exclusive with ``conditions``.
+ :param left: Left operand for the flat interface. For ``In`` conditions
+ this is the value to check membership of.
+ :param right: Right operand for the flat interface. For ``In`` conditions
+ this is the list of allowed values.
+ :param conditions: List of condition dicts to evaluate (AND-ed together).
+ Each dict must have a ``type`` key. Minimum 1, maximum 200.
+ Mutually exclusive with ``condition_type``/``left``/``right``.
+ :param if_task_ids: Task ID(s) to execute when all conditions are True.
+ :param else_task_ids: Task ID(s) to execute when any condition is False.
+ """
+
+ _VALID_FLAT_TYPES: ClassVar[set[str]] = {
+ "Equals",
+ "GreaterThan",
+ "GreaterThanOrEqualTo",
+ "LessThan",
+ "LessThanOrEqualTo",
+ "In",
+ }
+
+ template_fields: Sequence[str] = (
+ "condition_type",
+ "left",
+ "right",
+ "conditions",
+ "if_task_ids",
+ "else_task_ids",
+ )
+
+ def __init__(
+ self,
+ *,
+ condition_type: str | None = None,
+ left: Any = None,
+ right: Any = None,
+ conditions: list[dict] | None = None,
+ if_task_ids: str | list[str],
+ else_task_ids: str | list[str],
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ has_flat = condition_type is not None
+ has_list = conditions is not None
+
+ if has_flat and has_list:
+ raise ValueError(
+ "Cannot use 'condition_type' and 'conditions' together. "
+ "Use 'condition_type' with 'left'/'right' for a single
condition, "
+ "or 'conditions' for multiple/nested conditions."
+ )
+ if not has_flat and not has_list:
+ raise ValueError(
+ "Missing condition: provide 'condition_type' with
'left'/'right' for a single condition, "
+ "or 'conditions' for multiple/nested conditions."
+ )
+
+ if has_flat:
+ if condition_type not in self._VALID_FLAT_TYPES:
+ raise ValueError(
+ f"Unknown condition_type '{condition_type}'. "
+ f"Expected one of: {',
'.join(sorted(self._VALID_FLAT_TYPES))}."
+ )
+ self.condition_type = condition_type
+ self.left = left
+ self.right = right
+ if condition_type == "In":
+ self.conditions = [{"type": "In", "value": left, "in_values":
right}]
+ else:
+ self.conditions = [{"type": condition_type, "left": left,
"right": right}]
+ else:
+ self.condition_type = None
+ self.left = None
+ self.right = None
+ self.conditions = conditions
+
+ if not self.conditions:
+ raise ValueError("At least 1 condition is required, but got an
empty list.")
+ if len(self.conditions) > 200:
+ raise ValueError(f"At most 200 conditions are allowed, but got
{len(self.conditions)}.")
+ self.if_task_ids = [if_task_ids] if isinstance(if_task_ids, str) else
if_task_ids
+ self.else_task_ids = [else_task_ids] if isinstance(else_task_ids, str)
else else_task_ids
+
+ @staticmethod
+ def _cast(value: Any) -> Any:
+ """
+ Cast Jinja-rendered string values to appropriate Python types.
+
+ - Numeric strings become int or float.
+ - ``"true"``/``"false"`` become booleans.
+ - Other strings are returned unchanged.
+ - Non-string types pass through as-is.
+ """
+ if not isinstance(value, str):
+ return value
+ if value == "None":
+ return None
+ try:
+ return int(value)
+ except (ValueError, TypeError):
+ pass
+ try:
+ return float(value)
+ except (ValueError, TypeError):
+ pass
+ if value.lower() == "true":
+ return True
+ if value.lower() == "false":
+ return False
+ return value
+
+ _COMPARISON_OPERATORS: ClassVar[dict[str, Callable[[Any, Any], bool]]] = {
+ "Equals": lambda left, right: left == right,
+ "GreaterThan": lambda left, right: left > right,
+ "GreaterThanOrEqualTo": lambda left, right: left >= right,
+ "LessThan": lambda left, right: left < right,
+ "LessThanOrEqualTo": lambda left, right: left <= right,
+ }
+
+ def _evaluate(self, condition: dict, depth: int = 0) -> bool:
+ """
+ Recursively evaluate a single condition dict.
+
+ :param condition: A condition dictionary with a ``type`` key.
+ :param depth: Current nesting depth (used for log indentation only).
+ :returns: Boolean result of the condition evaluation.
+ """
+ indent = " " * depth
+ try:
+ cond_type = condition["type"]
+ except KeyError:
+ raise ValueError("Condition dict is missing required key 'type'.")
+
+ if cond_type in self._COMPARISON_OPERATORS:
+ try:
+ left = self._cast(condition["left"])
+ right = self._cast(condition["right"])
+ except KeyError as e:
+ raise ValueError(f"Condition '{cond_type}' missing required
key {e}.") from None
+
+ # None check — likely an XCom that was not pushed
+ if left is None or right is None:
+ raise TypeError(
+ f"Condition '{cond_type}' received None: left={left!r},
right={right!r}. "
+ "This usually means the upstream task did not run or did
not push a value to XCom."
+ )
+
+ # Type compatibility check
+ left_type = type(left)
+ right_type = type(right)
+ numeric_types = (int, float)
+ left_is_numeric = isinstance(left, numeric_types) and not
isinstance(left, bool)
+ right_is_numeric = isinstance(right, numeric_types) and not
isinstance(right, bool)
+
+ if left_is_numeric and right_is_numeric:
+ pass
Review Comment:
Is this if statement helpful?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]