This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 9570cb1482 Make Start and Stop SageMaker Pipelines operators
deferrable (#32683)
9570cb1482 is described below
commit 9570cb1482d25f288e607aaa1210b2457bc5ed12
Author: Raphaël Vandon <[email protected]>
AuthorDate: Tue Jul 25 11:46:08 2023 -0700
Make Start and Stop SageMaker Pipelines operators deferrable (#32683)
---
airflow/providers/amazon/aws/hooks/sagemaker.py | 38 ++++++----
.../providers/amazon/aws/operators/sagemaker.py | 82 +++++++++++++++++++--
airflow/providers/amazon/aws/triggers/ecs.py | 1 +
airflow/providers/amazon/aws/triggers/sagemaker.py | 85 +++++++++++++++++++++-
.../providers/amazon/aws/waiters/sagemaker.json | 46 ++++++++++++
.../aws/operators/test_sagemaker_pipeline.py | 48 +++++++++---
6 files changed, 269 insertions(+), 31 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/sagemaker.py
b/airflow/providers/amazon/aws/hooks/sagemaker.py
index 72354c9c8b..758ba7e9c5 100644
--- a/airflow/providers/amazon/aws/hooks/sagemaker.py
+++ b/airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -23,6 +23,7 @@ import re
import tarfile
import tempfile
import time
+import warnings
from collections import Counter
from datetime import datetime
from functools import partial
@@ -30,7 +31,7 @@ from typing import Any, Callable, Generator, cast
from botocore.exceptions import ClientError
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
@@ -1061,7 +1062,7 @@ class SageMakerHook(AwsBaseHook):
display_name: str = "airflow-triggered-execution",
pipeline_params: dict | None = None,
wait_for_completion: bool = False,
- check_interval: int = 30,
+ check_interval: int | None = None,
verbose: bool = True,
) -> str:
"""Start a new execution for a SageMaker pipeline.
@@ -1073,14 +1074,19 @@ class SageMakerHook(AwsBaseHook):
:param display_name: The name this pipeline execution will have in the
UI. Doesn't need to be unique.
:param pipeline_params: Optional parameters for the pipeline.
All parameters supplied need to already be present in the pipeline
definition.
- :param wait_for_completion: Will only return once the pipeline is
complete if true.
- :param check_interval: How long to wait between checks for pipeline
status when waiting for
- completion.
- :param verbose: Whether to print steps details when waiting for
completion.
- Defaults to true, consider turning off for pipelines that have
thousands of steps.
:return: the ARN of the pipeline execution launched.
"""
+ if wait_for_completion or check_interval is not None:
+ warnings.warn(
+ "parameter `wait_for_completion` and `check_interval` are
deprecated, "
+ "remove them and call check_status yourself if you want to
wait for completion",
+ AirflowProviderDeprecationWarning,
+ stacklevel=2,
+ )
+ if check_interval is None:
+ check_interval = 30
+
formatted_params = format_tags(pipeline_params, key_label="Name")
try:
@@ -1108,7 +1114,7 @@ class SageMakerHook(AwsBaseHook):
self,
pipeline_exec_arn: str,
wait_for_completion: bool = False,
- check_interval: int = 10,
+ check_interval: int | None = None,
verbose: bool = True,
fail_if_not_running: bool = False,
) -> str:
@@ -1119,12 +1125,6 @@ class SageMakerHook(AwsBaseHook):
:param pipeline_exec_arn: Amazon Resource Name (ARN) of the pipeline
execution.
It's the ARN of the pipeline itself followed by "/execution/" and
an id.
- :param wait_for_completion: Whether to wait for the pipeline to reach
a final state.
- (i.e. either 'Stopped' or 'Failed')
- :param check_interval: How long to wait between checks for pipeline
status when waiting for
- completion.
- :param verbose: Whether to print steps details when waiting for
completion.
- Defaults to true, consider turning off for pipelines that have
thousands of steps.
:param fail_if_not_running: This method will raise an exception if the
pipeline we're trying to stop
is not in an "Executing" state when the call is sent (which would
mean that the pipeline is
already either stopping or stopped).
@@ -1133,6 +1133,16 @@ class SageMakerHook(AwsBaseHook):
:return: Status of the pipeline execution after the operation.
One of 'Executing'|'Stopping'|'Stopped'|'Failed'|'Succeeded'.
"""
+ if wait_for_completion or check_interval is not None:
+ warnings.warn(
+ "parameter `wait_for_completion` and `check_interval` are
deprecated, "
+ "remove them and call check_status yourself if you want to
wait for completion",
+ AirflowProviderDeprecationWarning,
+ stacklevel=2,
+ )
+ if check_interval is None:
+ check_interval = 10
+
retries = 2 # i.e. 3 calls max, 1 initial + 2 retries
while True:
try:
diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py
b/airflow/providers/amazon/aws/operators/sagemaker.py
index ac1b7a73d2..83a1e4f3d2 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -30,7 +30,10 @@ from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarni
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
-from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
+from airflow.providers.amazon.aws.triggers.sagemaker import (
+ SageMakerPipelineTrigger,
+ SageMakerTrigger,
+)
from airflow.providers.amazon.aws.utils import trim_none_values
from airflow.providers.amazon.aws.utils.sagemaker import ApprovalStatus
from airflow.providers.amazon.aws.utils.tags import format_tags
@@ -998,8 +1001,10 @@ class
SageMakerStartPipelineOperator(SageMakerBaseOperator):
All parameters supplied need to already be present in the pipeline
definition.
:param wait_for_completion: If true, this operator will only complete once
the pipeline is complete.
:param check_interval: How long to wait between checks for pipeline status
when waiting for completion.
+ :param waiter_max_attempts: How many times to check the status before
failing.
:param verbose: Whether to print steps details when waiting for completion.
Defaults to true, consider turning off for pipelines that have
thousands of steps.
+ :param deferrable: Run operator in the deferrable mode.
:return str: Returns The ARN of the pipeline execution created in Amazon
SageMaker.
"""
@@ -1015,7 +1020,9 @@ class
SageMakerStartPipelineOperator(SageMakerBaseOperator):
pipeline_params: dict | None = None,
wait_for_completion: bool = False,
check_interval: int = CHECK_INTERVAL_SECOND,
+ waiter_max_attempts: int = 9999,
verbose: bool = True,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
**kwargs,
):
super().__init__(config={}, aws_conn_id=aws_conn_id, **kwargs)
@@ -1024,22 +1031,46 @@ class
SageMakerStartPipelineOperator(SageMakerBaseOperator):
self.pipeline_params = pipeline_params
self.wait_for_completion = wait_for_completion
self.check_interval = check_interval
+ self.waiter_max_attempts = waiter_max_attempts
self.verbose = verbose
+ self.deferrable = deferrable
def execute(self, context: Context) -> str:
arn = self.hook.start_pipeline(
pipeline_name=self.pipeline_name,
display_name=self.display_name,
pipeline_params=self.pipeline_params,
- wait_for_completion=self.wait_for_completion,
- check_interval=self.check_interval,
- verbose=self.verbose,
)
self.log.info(
"Starting a new execution for pipeline %s, running with ARN %s",
self.pipeline_name, arn
)
+ if self.deferrable:
+ self.defer(
+ trigger=SageMakerPipelineTrigger(
+ waiter_type=SageMakerPipelineTrigger.Type.COMPLETE,
+ pipeline_execution_arn=arn,
+ waiter_delay=self.check_interval,
+ waiter_max_attempts=self.waiter_max_attempts,
+ aws_conn_id=self.aws_conn_id,
+ ),
+ method_name="execute_complete",
+ )
+ elif self.wait_for_completion:
+ self.hook.check_status(
+ arn,
+ "PipelineExecutionStatus",
+ lambda p: self.hook.describe_pipeline_exec(p, self.verbose),
+ self.check_interval,
+ non_terminal_states=self.hook.pipeline_non_terminal_states,
+ max_ingestion_time=self.waiter_max_attempts *
self.check_interval,
+ )
return arn
+ def execute_complete(self, context: Context, event: dict[str, Any] | None
= None) -> str:
+ if event is None or event["status"] != "success":
+ raise AirflowException(f"Failure during pipeline execution:
{event}")
+ return event["value"]
+
class SageMakerStopPipelineOperator(SageMakerBaseOperator):
"""
@@ -1057,6 +1088,7 @@ class
SageMakerStopPipelineOperator(SageMakerBaseOperator):
:param verbose: Whether to print steps details when waiting for completion.
Defaults to true, consider turning off for pipelines that have
thousands of steps.
:param fail_if_not_running: raises an exception if the pipeline stopped or
succeeded before this was run
+ :param deferrable: Run operator in the deferrable mode.
:return str: Returns the status of the pipeline execution after the
operation has been done.
"""
@@ -1073,23 +1105,24 @@ class
SageMakerStopPipelineOperator(SageMakerBaseOperator):
pipeline_exec_arn: str,
wait_for_completion: bool = False,
check_interval: int = CHECK_INTERVAL_SECOND,
+ waiter_max_attempts: int = 9999,
verbose: bool = True,
fail_if_not_running: bool = False,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
**kwargs,
):
super().__init__(config={}, aws_conn_id=aws_conn_id, **kwargs)
self.pipeline_exec_arn = pipeline_exec_arn
self.wait_for_completion = wait_for_completion
self.check_interval = check_interval
+ self.waiter_max_attempts = waiter_max_attempts
self.verbose = verbose
self.fail_if_not_running = fail_if_not_running
+ self.deferrable = deferrable
def execute(self, context: Context) -> str:
status = self.hook.stop_pipeline(
pipeline_exec_arn=self.pipeline_exec_arn,
- wait_for_completion=self.wait_for_completion,
- check_interval=self.check_interval,
- verbose=self.verbose,
fail_if_not_running=self.fail_if_not_running,
)
self.log.info(
@@ -1097,8 +1130,43 @@ class
SageMakerStopPipelineOperator(SageMakerBaseOperator):
self.pipeline_exec_arn,
status,
)
+
+ if status not in self.hook.pipeline_non_terminal_states:
+ # pipeline already stopped
+ return status
+
+ # else, eventually wait for completion
+ if self.deferrable:
+ self.defer(
+ trigger=SageMakerPipelineTrigger(
+ waiter_type=SageMakerPipelineTrigger.Type.STOPPED,
+ pipeline_execution_arn=self.pipeline_exec_arn,
+ waiter_delay=self.check_interval,
+ waiter_max_attempts=self.waiter_max_attempts,
+ aws_conn_id=self.aws_conn_id,
+ ),
+ method_name="execute_complete",
+ )
+ elif self.wait_for_completion:
+ status = self.hook.check_status(
+ self.pipeline_exec_arn,
+ "PipelineExecutionStatus",
+ lambda p: self.hook.describe_pipeline_exec(p, self.verbose),
+ self.check_interval,
+ non_terminal_states=self.hook.pipeline_non_terminal_states,
+ max_ingestion_time=self.waiter_max_attempts *
self.check_interval,
+ )["PipelineExecutionStatus"]
+
return status
+ def execute_complete(self, context: Context, event: dict[str, Any] | None
= None) -> str:
+ if event is None or event["status"] != "success":
+ raise AirflowException(f"Failure during pipeline execution:
{event}")
+ else:
+ # theoretically we should do a `describe` call to know this,
+ # but if we reach this point, this is the only possible status
+ return "Stopped"
+
class SageMakerRegisterModelVersionOperator(SageMakerBaseOperator):
"""
diff --git a/airflow/providers/amazon/aws/triggers/ecs.py
b/airflow/providers/amazon/aws/triggers/ecs.py
index d0fdfb63b6..cc203207d1 100644
--- a/airflow/providers/amazon/aws/triggers/ecs.py
+++ b/airflow/providers/amazon/aws/triggers/ecs.py
@@ -173,6 +173,7 @@ class TaskDoneTrigger(BaseTrigger):
)
# we reach this point only if the waiter met a success
criteria
yield TriggerEvent({"status": "success", "task_arn":
self.task_arn})
+ return
except WaiterError as error:
if "terminal failure" in str(error):
raise
diff --git a/airflow/providers/amazon/aws/triggers/sagemaker.py
b/airflow/providers/amazon/aws/triggers/sagemaker.py
index ca511a4a46..ec11323d42 100644
--- a/airflow/providers/amazon/aws/triggers/sagemaker.py
+++ b/airflow/providers/amazon/aws/triggers/sagemaker.py
@@ -17,9 +17,15 @@
from __future__ import annotations
+import asyncio
+from collections import Counter
+from enum import IntEnum
from functools import cached_property
-from typing import Any
+from typing import Any, AsyncIterator
+from botocore.exceptions import WaiterError
+
+from airflow import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -115,3 +121,80 @@ class SageMakerTrigger(BaseTrigger):
status_args=[self._get_response_status_key(self.job_type)],
)
yield TriggerEvent({"status": "success", "message": "Job
completed."})
+
+
+class SageMakerPipelineTrigger(BaseTrigger):
+ """Trigger to wait for a sagemaker pipeline execution to finish."""
+
+ class Type(IntEnum):
+ """Type of waiter to use."""
+
+ COMPLETE = 1
+ STOPPED = 2
+
+ def __init__(
+ self,
+ waiter_type: Type,
+ pipeline_execution_arn: str,
+ waiter_delay: int,
+ waiter_max_attempts: int,
+ aws_conn_id: str,
+ ):
+ self.waiter_type = waiter_type
+ self.pipeline_execution_arn = pipeline_execution_arn
+ self.waiter_delay = waiter_delay
+ self.waiter_max_attempts = waiter_max_attempts
+ self.aws_conn_id = aws_conn_id
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ return (
+ self.__class__.__module__ + "." + self.__class__.__qualname__,
+ {
+ "waiter_type": self.waiter_type.value, # saving the int value
here
+ "pipeline_execution_arn": self.pipeline_execution_arn,
+ "waiter_delay": self.waiter_delay,
+ "waiter_max_attempts": self.waiter_max_attempts,
+ "aws_conn_id": self.aws_conn_id,
+ },
+ )
+
+ _waiter_name = {
+ Type.COMPLETE: "PipelineExecutionComplete",
+ Type.STOPPED: "PipelineExecutionStopped",
+ }
+
+ async def run(self) -> AsyncIterator[TriggerEvent]:
+ attempts = 0
+ hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
+ async with hook.async_conn as conn:
+ waiter = hook.get_waiter(self._waiter_name[self.waiter_type],
deferrable=True, client=conn)
+ while attempts < self.waiter_max_attempts:
+ attempts = attempts + 1
+ try:
+ await waiter.wait(
+ PipelineExecutionArn=self.pipeline_execution_arn,
WaiterConfig={"MaxAttempts": 1}
+ )
+ # we reach this point only if the waiter met a success
criteria
+ yield TriggerEvent({"status": "success", "value":
self.pipeline_execution_arn})
+ return
+ except WaiterError as error:
+ if "terminal failure" in str(error):
+ raise
+
+ self.log.info(
+ "Status of the pipeline execution: %s",
error.last_response["PipelineExecutionStatus"]
+ )
+
+ res = await conn.list_pipeline_execution_steps(
+ PipelineExecutionArn=self.pipeline_execution_arn
+ )
+ count_by_state = Counter(s["StepStatus"] for s in
res["PipelineExecutionSteps"])
+ running_steps = [
+ s["StepName"] for s in res["PipelineExecutionSteps"]
if s["StepStatus"] == "Executing"
+ ]
+ self.log.info("State of the pipeline steps: %s",
count_by_state)
+ self.log.info("Steps currently in progress: %s",
running_steps)
+
+ await asyncio.sleep(int(self.waiter_delay))
+
+ raise AirflowException("Waiter error: max attempts reached")
diff --git a/airflow/providers/amazon/aws/waiters/sagemaker.json
b/airflow/providers/amazon/aws/waiters/sagemaker.json
index 2c2760982c..eba60c7266 100644
--- a/airflow/providers/amazon/aws/waiters/sagemaker.json
+++ b/airflow/providers/amazon/aws/waiters/sagemaker.json
@@ -104,6 +104,52 @@
"state": "failure"
}
]
+ },
+ "PipelineExecutionComplete": {
+ "delay": 30,
+ "operation": "DescribePipelineExecution",
+ "maxAttempts": 60,
+ "description": "Wait until pipeline execution is Succeeded",
+ "acceptors": [
+ {
+ "matcher": "path",
+ "argument": "PipelineExecutionStatus",
+ "expected": "Succeeded",
+ "state": "success"
+ },
+ {
+ "matcher": "path",
+ "argument": "PipelineExecutionStatus",
+ "expected": "Failed",
+ "state": "failure"
+ },
+ {
+ "matcher": "path",
+ "argument": "PipelineExecutionStatus",
+ "expected": "Stopped",
+ "state": "failure"
+ }
+ ]
+ },
+ "PipelineExecutionStopped": {
+ "delay": 10,
+ "operation": "DescribePipelineExecution",
+ "maxAttempts": 120,
+ "description": "Wait until pipeline execution is Stopped",
+ "acceptors": [
+ {
+ "matcher": "path",
+ "argument": "PipelineExecutionStatus",
+ "expected": "Stopped",
+ "state": "success"
+ },
+ {
+ "matcher": "path",
+ "argument": "PipelineExecutionStatus",
+ "expected": "Failed",
+ "state": "failure"
+ }
+ ]
}
}
}
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_pipeline.py
b/tests/providers/amazon/aws/operators/test_sagemaker_pipeline.py
index dde3627458..2d509625c5 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_pipeline.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_pipeline.py
@@ -18,17 +18,23 @@
from __future__ import annotations
from unittest import mock
+from unittest.mock import MagicMock
+import pytest
+
+from airflow.exceptions import TaskDeferred
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators.sagemaker import (
SageMakerStartPipelineOperator,
SageMakerStopPipelineOperator,
)
+from airflow.providers.amazon.aws.triggers.sagemaker import
SageMakerPipelineTrigger
class TestSageMakerStartPipelineOperator:
@mock.patch.object(SageMakerHook, "start_pipeline")
- def test_execute(self, start_pipeline):
+ @mock.patch.object(SageMakerHook, "check_status")
+ def test_execute(self, check_status, start_pipeline):
op = SageMakerStartPipelineOperator(
task_id="test_sagemaker_operator",
pipeline_name="my_pipeline",
@@ -39,17 +45,29 @@ class TestSageMakerStartPipelineOperator:
verbose=False,
)
- op.execute(None)
+ op.execute({})
start_pipeline.assert_called_once_with(
pipeline_name="my_pipeline",
display_name="test_disp_name",
pipeline_params={"is_a_test": "yes"},
- wait_for_completion=True,
- check_interval=12,
- verbose=False,
+ )
+ check_status.assert_called_once()
+
+ @mock.patch.object(SageMakerHook, "start_pipeline")
+ def test_defer(self, start_mock):
+ op = SageMakerStartPipelineOperator(
+ task_id="test_sagemaker_operator",
+ pipeline_name="my_pipeline",
+ deferrable=True,
)
+ with pytest.raises(TaskDeferred) as defer:
+ op.execute({})
+
+ assert isinstance(defer.value.trigger, SageMakerPipelineTrigger)
+ assert defer.value.trigger.waiter_type ==
SageMakerPipelineTrigger.Type.COMPLETE
+
class TestSageMakerStopPipelineOperator:
@mock.patch.object(SageMakerHook, "stop_pipeline")
@@ -58,12 +76,24 @@ class TestSageMakerStopPipelineOperator:
task_id="test_sagemaker_operator", pipeline_exec_arn="pipeline_arn"
)
- op.execute(None)
+ op.execute({})
stop_pipeline.assert_called_once_with(
pipeline_exec_arn="pipeline_arn",
- wait_for_completion=False,
- check_interval=30,
fail_if_not_running=False,
- verbose=True,
)
+
+ @mock.patch.object(SageMakerHook, "stop_pipeline")
+ def test_defer(self, stop_mock: MagicMock):
+ stop_mock.return_value = "Stopping"
+ op = SageMakerStopPipelineOperator(
+ task_id="test_sagemaker_operator",
+ pipeline_exec_arn="my_pipeline_arn",
+ deferrable=True,
+ )
+
+ with pytest.raises(TaskDeferred) as defer:
+ op.execute({})
+
+ assert isinstance(defer.value.trigger, SageMakerPipelineTrigger)
+ assert defer.value.trigger.waiter_type ==
SageMakerPipelineTrigger.Type.STOPPED