This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9570cb1482 Make Start and Stop SageMaker Pipelines operators 
deferrable (#32683)
9570cb1482 is described below

commit 9570cb1482d25f288e607aaa1210b2457bc5ed12
Author: Raphaël Vandon <[email protected]>
AuthorDate: Tue Jul 25 11:46:08 2023 -0700

    Make Start and Stop SageMaker Pipelines operators deferrable (#32683)
---
 airflow/providers/amazon/aws/hooks/sagemaker.py    | 38 ++++++----
 .../providers/amazon/aws/operators/sagemaker.py    | 82 +++++++++++++++++++--
 airflow/providers/amazon/aws/triggers/ecs.py       |  1 +
 airflow/providers/amazon/aws/triggers/sagemaker.py | 85 +++++++++++++++++++++-
 .../providers/amazon/aws/waiters/sagemaker.json    | 46 ++++++++++++
 .../aws/operators/test_sagemaker_pipeline.py       | 48 +++++++++---
 6 files changed, 269 insertions(+), 31 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/sagemaker.py 
b/airflow/providers/amazon/aws/hooks/sagemaker.py
index 72354c9c8b..758ba7e9c5 100644
--- a/airflow/providers/amazon/aws/hooks/sagemaker.py
+++ b/airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -23,6 +23,7 @@ import re
 import tarfile
 import tempfile
 import time
+import warnings
 from collections import Counter
 from datetime import datetime
 from functools import partial
@@ -30,7 +31,7 @@ from typing import Any, Callable, Generator, cast
 
 from botocore.exceptions import ClientError
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, 
AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
@@ -1061,7 +1062,7 @@ class SageMakerHook(AwsBaseHook):
         display_name: str = "airflow-triggered-execution",
         pipeline_params: dict | None = None,
         wait_for_completion: bool = False,
-        check_interval: int = 30,
+        check_interval: int | None = None,
         verbose: bool = True,
     ) -> str:
         """Start a new execution for a SageMaker pipeline.
@@ -1073,14 +1074,19 @@ class SageMakerHook(AwsBaseHook):
         :param display_name: The name this pipeline execution will have in the 
UI. Doesn't need to be unique.
         :param pipeline_params: Optional parameters for the pipeline.
             All parameters supplied need to already be present in the pipeline 
definition.
-        :param wait_for_completion: Will only return once the pipeline is 
complete if true.
-        :param check_interval: How long to wait between checks for pipeline 
status when waiting for
-            completion.
-        :param verbose: Whether to print steps details when waiting for 
completion.
-            Defaults to true, consider turning off for pipelines that have 
thousands of steps.
 
         :return: the ARN of the pipeline execution launched.
         """
+        if wait_for_completion or check_interval is not None:
+            warnings.warn(
+                "parameter `wait_for_completion` and `check_interval` are 
deprecated, "
+                "remove them and call check_status yourself if you want to 
wait for completion",
+                AirflowProviderDeprecationWarning,
+                stacklevel=2,
+            )
+        if check_interval is None:
+            check_interval = 30
+
         formatted_params = format_tags(pipeline_params, key_label="Name")
 
         try:
@@ -1108,7 +1114,7 @@ class SageMakerHook(AwsBaseHook):
         self,
         pipeline_exec_arn: str,
         wait_for_completion: bool = False,
-        check_interval: int = 10,
+        check_interval: int | None = None,
         verbose: bool = True,
         fail_if_not_running: bool = False,
     ) -> str:
@@ -1119,12 +1125,6 @@ class SageMakerHook(AwsBaseHook):
 
         :param pipeline_exec_arn: Amazon Resource Name (ARN) of the pipeline 
execution.
             It's the ARN of the pipeline itself followed by "/execution/" and 
an id.
-        :param wait_for_completion: Whether to wait for the pipeline to reach 
a final state.
-            (i.e. either 'Stopped' or 'Failed')
-        :param check_interval: How long to wait between checks for pipeline 
status when waiting for
-            completion.
-        :param verbose: Whether to print steps details when waiting for 
completion.
-            Defaults to true, consider turning off for pipelines that have 
thousands of steps.
         :param fail_if_not_running: This method will raise an exception if the 
pipeline we're trying to stop
             is not in an "Executing" state when the call is sent (which would 
mean that the pipeline is
             already either stopping or stopped).
@@ -1133,6 +1133,16 @@ class SageMakerHook(AwsBaseHook):
         :return: Status of the pipeline execution after the operation.
             One of 'Executing'|'Stopping'|'Stopped'|'Failed'|'Succeeded'.
         """
+        if wait_for_completion or check_interval is not None:
+            warnings.warn(
+                "parameter `wait_for_completion` and `check_interval` are 
deprecated, "
+                "remove them and call check_status yourself if you want to 
wait for completion",
+                AirflowProviderDeprecationWarning,
+                stacklevel=2,
+            )
+        if check_interval is None:
+            check_interval = 10
+
         retries = 2  # i.e. 3 calls max, 1 initial + 2 retries
         while True:
             try:
diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py 
b/airflow/providers/amazon/aws/operators/sagemaker.py
index ac1b7a73d2..83a1e4f3d2 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -30,7 +30,10 @@ from airflow.exceptions import AirflowException, 
AirflowProviderDeprecationWarni
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
-from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
+from airflow.providers.amazon.aws.triggers.sagemaker import (
+    SageMakerPipelineTrigger,
+    SageMakerTrigger,
+)
 from airflow.providers.amazon.aws.utils import trim_none_values
 from airflow.providers.amazon.aws.utils.sagemaker import ApprovalStatus
 from airflow.providers.amazon.aws.utils.tags import format_tags
@@ -998,8 +1001,10 @@ class 
SageMakerStartPipelineOperator(SageMakerBaseOperator):
         All parameters supplied need to already be present in the pipeline 
definition.
     :param wait_for_completion: If true, this operator will only complete once 
the pipeline is complete.
     :param check_interval: How long to wait between checks for pipeline status 
when waiting for completion.
+    :param waiter_max_attempts: How many times to check the status before 
failing.
     :param verbose: Whether to print steps details when waiting for completion.
         Defaults to true, consider turning off for pipelines that have 
thousands of steps.
+    :param deferrable: Run operator in the deferrable mode.
 
     :return str: Returns The ARN of the pipeline execution created in Amazon 
SageMaker.
     """
@@ -1015,7 +1020,9 @@ class 
SageMakerStartPipelineOperator(SageMakerBaseOperator):
         pipeline_params: dict | None = None,
         wait_for_completion: bool = False,
         check_interval: int = CHECK_INTERVAL_SECOND,
+        waiter_max_attempts: int = 9999,
         verbose: bool = True,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
         **kwargs,
     ):
         super().__init__(config={}, aws_conn_id=aws_conn_id, **kwargs)
@@ -1024,22 +1031,46 @@ class 
SageMakerStartPipelineOperator(SageMakerBaseOperator):
         self.pipeline_params = pipeline_params
         self.wait_for_completion = wait_for_completion
         self.check_interval = check_interval
+        self.waiter_max_attempts = waiter_max_attempts
         self.verbose = verbose
+        self.deferrable = deferrable
 
     def execute(self, context: Context) -> str:
         arn = self.hook.start_pipeline(
             pipeline_name=self.pipeline_name,
             display_name=self.display_name,
             pipeline_params=self.pipeline_params,
-            wait_for_completion=self.wait_for_completion,
-            check_interval=self.check_interval,
-            verbose=self.verbose,
         )
         self.log.info(
             "Starting a new execution for pipeline %s, running with ARN %s", 
self.pipeline_name, arn
         )
+        if self.deferrable:
+            self.defer(
+                trigger=SageMakerPipelineTrigger(
+                    waiter_type=SageMakerPipelineTrigger.Type.COMPLETE,
+                    pipeline_execution_arn=arn,
+                    waiter_delay=self.check_interval,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+            )
+        elif self.wait_for_completion:
+            self.hook.check_status(
+                arn,
+                "PipelineExecutionStatus",
+                lambda p: self.hook.describe_pipeline_exec(p, self.verbose),
+                self.check_interval,
+                non_terminal_states=self.hook.pipeline_non_terminal_states,
+                max_ingestion_time=self.waiter_max_attempts * 
self.check_interval,
+            )
         return arn
 
+    def execute_complete(self, context: Context, event: dict[str, Any] | None 
= None) -> str:
+        if event is None or event["status"] != "success":
+            raise AirflowException(f"Failure during pipeline execution: 
{event}")
+        return event["value"]
+
 
 class SageMakerStopPipelineOperator(SageMakerBaseOperator):
     """
@@ -1057,6 +1088,7 @@ class 
SageMakerStopPipelineOperator(SageMakerBaseOperator):
     :param verbose: Whether to print steps details when waiting for completion.
         Defaults to true, consider turning off for pipelines that have 
thousands of steps.
     :param fail_if_not_running: raises an exception if the pipeline stopped or 
succeeded before this was run
+    :param deferrable: Run operator in the deferrable mode.
 
     :return str: Returns the status of the pipeline execution after the 
operation has been done.
     """
@@ -1073,23 +1105,24 @@ class 
SageMakerStopPipelineOperator(SageMakerBaseOperator):
         pipeline_exec_arn: str,
         wait_for_completion: bool = False,
         check_interval: int = CHECK_INTERVAL_SECOND,
+        waiter_max_attempts: int = 9999,
         verbose: bool = True,
         fail_if_not_running: bool = False,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
         **kwargs,
     ):
         super().__init__(config={}, aws_conn_id=aws_conn_id, **kwargs)
         self.pipeline_exec_arn = pipeline_exec_arn
         self.wait_for_completion = wait_for_completion
         self.check_interval = check_interval
+        self.waiter_max_attempts = waiter_max_attempts
         self.verbose = verbose
         self.fail_if_not_running = fail_if_not_running
+        self.deferrable = deferrable
 
     def execute(self, context: Context) -> str:
         status = self.hook.stop_pipeline(
             pipeline_exec_arn=self.pipeline_exec_arn,
-            wait_for_completion=self.wait_for_completion,
-            check_interval=self.check_interval,
-            verbose=self.verbose,
             fail_if_not_running=self.fail_if_not_running,
         )
         self.log.info(
@@ -1097,8 +1130,43 @@ class 
SageMakerStopPipelineOperator(SageMakerBaseOperator):
             self.pipeline_exec_arn,
             status,
         )
+
+        if status not in self.hook.pipeline_non_terminal_states:
+            # pipeline already stopped
+            return status
+
+        # else, eventually wait for completion
+        if self.deferrable:
+            self.defer(
+                trigger=SageMakerPipelineTrigger(
+                    waiter_type=SageMakerPipelineTrigger.Type.STOPPED,
+                    pipeline_execution_arn=self.pipeline_exec_arn,
+                    waiter_delay=self.check_interval,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+            )
+        elif self.wait_for_completion:
+            status = self.hook.check_status(
+                self.pipeline_exec_arn,
+                "PipelineExecutionStatus",
+                lambda p: self.hook.describe_pipeline_exec(p, self.verbose),
+                self.check_interval,
+                non_terminal_states=self.hook.pipeline_non_terminal_states,
+                max_ingestion_time=self.waiter_max_attempts * 
self.check_interval,
+            )["PipelineExecutionStatus"]
+
         return status
 
+    def execute_complete(self, context: Context, event: dict[str, Any] | None 
= None) -> str:
+        if event is None or event["status"] != "success":
+            raise AirflowException(f"Failure during pipeline execution: 
{event}")
+        else:
+            # theoretically we should do a `describe` call to know this,
+            # but if we reach this point, this is the only possible status
+            return "Stopped"
+
 
 class SageMakerRegisterModelVersionOperator(SageMakerBaseOperator):
     """
diff --git a/airflow/providers/amazon/aws/triggers/ecs.py 
b/airflow/providers/amazon/aws/triggers/ecs.py
index d0fdfb63b6..cc203207d1 100644
--- a/airflow/providers/amazon/aws/triggers/ecs.py
+++ b/airflow/providers/amazon/aws/triggers/ecs.py
@@ -173,6 +173,7 @@ class TaskDoneTrigger(BaseTrigger):
                     )
                     # we reach this point only if the waiter met a success 
criteria
                     yield TriggerEvent({"status": "success", "task_arn": 
self.task_arn})
+                    return
                 except WaiterError as error:
                     if "terminal failure" in str(error):
                         raise
diff --git a/airflow/providers/amazon/aws/triggers/sagemaker.py 
b/airflow/providers/amazon/aws/triggers/sagemaker.py
index ca511a4a46..ec11323d42 100644
--- a/airflow/providers/amazon/aws/triggers/sagemaker.py
+++ b/airflow/providers/amazon/aws/triggers/sagemaker.py
@@ -17,9 +17,15 @@
 
 from __future__ import annotations
 
+import asyncio
+from collections import Counter
+from enum import IntEnum
 from functools import cached_property
-from typing import Any
+from typing import Any, AsyncIterator
 
+from botocore.exceptions import WaiterError
+
+from airflow import AirflowException
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
 from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
 from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -115,3 +121,80 @@ class SageMakerTrigger(BaseTrigger):
                 status_args=[self._get_response_status_key(self.job_type)],
             )
             yield TriggerEvent({"status": "success", "message": "Job 
completed."})
+
+
+class SageMakerPipelineTrigger(BaseTrigger):
+    """Trigger to wait for a sagemaker pipeline execution to finish."""
+
+    class Type(IntEnum):
+        """Type of waiter to use."""
+
+        COMPLETE = 1
+        STOPPED = 2
+
+    def __init__(
+        self,
+        waiter_type: Type,
+        pipeline_execution_arn: str,
+        waiter_delay: int,
+        waiter_max_attempts: int,
+        aws_conn_id: str,
+    ):
+        self.waiter_type = waiter_type
+        self.pipeline_execution_arn = pipeline_execution_arn
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+        self.aws_conn_id = aws_conn_id
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        return (
+            self.__class__.__module__ + "." + self.__class__.__qualname__,
+            {
+                "waiter_type": self.waiter_type.value,  # saving the int value 
here
+                "pipeline_execution_arn": self.pipeline_execution_arn,
+                "waiter_delay": self.waiter_delay,
+                "waiter_max_attempts": self.waiter_max_attempts,
+                "aws_conn_id": self.aws_conn_id,
+            },
+        )
+
+    _waiter_name = {
+        Type.COMPLETE: "PipelineExecutionComplete",
+        Type.STOPPED: "PipelineExecutionStopped",
+    }
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        attempts = 0
+        hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
+        async with hook.async_conn as conn:
+            waiter = hook.get_waiter(self._waiter_name[self.waiter_type], 
deferrable=True, client=conn)
+            while attempts < self.waiter_max_attempts:
+                attempts = attempts + 1
+                try:
+                    await waiter.wait(
+                        PipelineExecutionArn=self.pipeline_execution_arn, 
WaiterConfig={"MaxAttempts": 1}
+                    )
+                    # we reach this point only if the waiter met a success 
criteria
+                    yield TriggerEvent({"status": "success", "value": 
self.pipeline_execution_arn})
+                    return
+                except WaiterError as error:
+                    if "terminal failure" in str(error):
+                        raise
+
+                    self.log.info(
+                        "Status of the pipeline execution: %s", 
error.last_response["PipelineExecutionStatus"]
+                    )
+
+                    res = await conn.list_pipeline_execution_steps(
+                        PipelineExecutionArn=self.pipeline_execution_arn
+                    )
+                    count_by_state = Counter(s["StepStatus"] for s in 
res["PipelineExecutionSteps"])
+                    running_steps = [
+                        s["StepName"] for s in res["PipelineExecutionSteps"] 
if s["StepStatus"] == "Executing"
+                    ]
+                    self.log.info("State of the pipeline steps: %s", 
count_by_state)
+                    self.log.info("Steps currently in progress: %s", 
running_steps)
+
+                    await asyncio.sleep(int(self.waiter_delay))
+
+            raise AirflowException("Waiter error: max attempts reached")
diff --git a/airflow/providers/amazon/aws/waiters/sagemaker.json 
b/airflow/providers/amazon/aws/waiters/sagemaker.json
index 2c2760982c..eba60c7266 100644
--- a/airflow/providers/amazon/aws/waiters/sagemaker.json
+++ b/airflow/providers/amazon/aws/waiters/sagemaker.json
@@ -104,6 +104,52 @@
                     "state": "failure"
                 }
             ]
+        },
+        "PipelineExecutionComplete": {
+            "delay": 30,
+            "operation": "DescribePipelineExecution",
+            "maxAttempts": 60,
+            "description": "Wait until pipeline execution is Succeeded",
+            "acceptors": [
+                {
+                    "matcher": "path",
+                    "argument": "PipelineExecutionStatus",
+                    "expected": "Succeeded",
+                    "state": "success"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "PipelineExecutionStatus",
+                    "expected": "Failed",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "PipelineExecutionStatus",
+                    "expected": "Stopped",
+                    "state": "failure"
+                }
+            ]
+        },
+        "PipelineExecutionStopped": {
+            "delay": 10,
+            "operation": "DescribePipelineExecution",
+            "maxAttempts": 120,
+            "description": "Wait until pipeline execution is Stopped",
+            "acceptors": [
+                {
+                    "matcher": "path",
+                    "argument": "PipelineExecutionStatus",
+                    "expected": "Stopped",
+                    "state": "success"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "PipelineExecutionStatus",
+                    "expected": "Failed",
+                    "state": "failure"
+                }
+            ]
         }
     }
 }
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_pipeline.py 
b/tests/providers/amazon/aws/operators/test_sagemaker_pipeline.py
index dde3627458..2d509625c5 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_pipeline.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_pipeline.py
@@ -18,17 +18,23 @@
 from __future__ import annotations
 
 from unittest import mock
+from unittest.mock import MagicMock
 
+import pytest
+
+from airflow.exceptions import TaskDeferred
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
 from airflow.providers.amazon.aws.operators.sagemaker import (
     SageMakerStartPipelineOperator,
     SageMakerStopPipelineOperator,
 )
+from airflow.providers.amazon.aws.triggers.sagemaker import 
SageMakerPipelineTrigger
 
 
 class TestSageMakerStartPipelineOperator:
     @mock.patch.object(SageMakerHook, "start_pipeline")
-    def test_execute(self, start_pipeline):
+    @mock.patch.object(SageMakerHook, "check_status")
+    def test_execute(self, check_status, start_pipeline):
         op = SageMakerStartPipelineOperator(
             task_id="test_sagemaker_operator",
             pipeline_name="my_pipeline",
@@ -39,17 +45,29 @@ class TestSageMakerStartPipelineOperator:
             verbose=False,
         )
 
-        op.execute(None)
+        op.execute({})
 
         start_pipeline.assert_called_once_with(
             pipeline_name="my_pipeline",
             display_name="test_disp_name",
             pipeline_params={"is_a_test": "yes"},
-            wait_for_completion=True,
-            check_interval=12,
-            verbose=False,
+        )
+        check_status.assert_called_once()
+
+    @mock.patch.object(SageMakerHook, "start_pipeline")
+    def test_defer(self, start_mock):
+        op = SageMakerStartPipelineOperator(
+            task_id="test_sagemaker_operator",
+            pipeline_name="my_pipeline",
+            deferrable=True,
         )
 
+        with pytest.raises(TaskDeferred) as defer:
+            op.execute({})
+
+        assert isinstance(defer.value.trigger, SageMakerPipelineTrigger)
+        assert defer.value.trigger.waiter_type == 
SageMakerPipelineTrigger.Type.COMPLETE
+
 
 class TestSageMakerStopPipelineOperator:
     @mock.patch.object(SageMakerHook, "stop_pipeline")
@@ -58,12 +76,24 @@ class TestSageMakerStopPipelineOperator:
             task_id="test_sagemaker_operator", pipeline_exec_arn="pipeline_arn"
         )
 
-        op.execute(None)
+        op.execute({})
 
         stop_pipeline.assert_called_once_with(
             pipeline_exec_arn="pipeline_arn",
-            wait_for_completion=False,
-            check_interval=30,
             fail_if_not_running=False,
-            verbose=True,
         )
+
+    @mock.patch.object(SageMakerHook, "stop_pipeline")
+    def test_defer(self, stop_mock: MagicMock):
+        stop_mock.return_value = "Stopping"
+        op = SageMakerStopPipelineOperator(
+            task_id="test_sagemaker_operator",
+            pipeline_exec_arn="my_pipeline_arn",
+            deferrable=True,
+        )
+
+        with pytest.raises(TaskDeferred) as defer:
+            op.execute({})
+
+        assert isinstance(defer.value.trigger, SageMakerPipelineTrigger)
+        assert defer.value.trigger.waiter_type == 
SageMakerPipelineTrigger.Type.STOPPED

Reply via email to