This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 14785bc84c deferrable mode for `SageMakerTuningOperator` and
`SageMakerEndpointOperator` (#32112)
14785bc84c is described below
commit 14785bc84c984b8747fa062b84e800d22ddc0477
Author: Raphaël Vandon <[email protected]>
AuthorDate: Tue Jun 27 14:04:09 2023 -0700
deferrable mode for `SageMakerTuningOperator` and
`SageMakerEndpointOperator` (#32112)
---
.../providers/amazon/aws/operators/sagemaker.py | 98 ++++++++++++++++++----
airflow/providers/amazon/aws/triggers/sagemaker.py | 38 ++++++---
.../providers/amazon/aws/waiters/sagemaker.json | 26 ++++++
.../aws/operators/test_sagemaker_endpoint.py | 24 +++++-
.../amazon/aws/operators/test_sagemaker_tuning.py | 15 +++-
.../amazon/aws/triggers/test_sagemaker.py | 26 +++---
.../providers/amazon/aws/example_sagemaker.py | 6 +-
7 files changed, 182 insertions(+), 51 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py
b/airflow/providers/amazon/aws/operators/sagemaker.py
index 0d4ba9fcab..8467459188 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations
+import datetime
import json
import time
import warnings
@@ -375,6 +376,7 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
finish within max_ingestion_time seconds. If you set this parameter to
None it never times out.
:param operation: Whether to create an endpoint or update an endpoint.
Must be either 'create' or 'update'.
:param aws_conn_id: The AWS connection ID to use.
+ :param deferrable: Will wait asynchronously for completion.
:return Dict: Returns The ARN of the endpoint created in Amazon SageMaker.
"""
@@ -387,15 +389,17 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
check_interval: int = CHECK_INTERVAL_SECOND,
max_ingestion_time: int | None = None,
operation: str = "create",
+ deferrable: bool = False,
**kwargs,
):
super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
self.wait_for_completion = wait_for_completion
self.check_interval = check_interval
- self.max_ingestion_time = max_ingestion_time
+ self.max_ingestion_time = max_ingestion_time or 3600 * 10
self.operation = operation.lower()
if self.operation not in ["create", "update"]:
raise ValueError('Invalid value! Argument operation has to be one
of "create" and "update"')
+ self.deferrable = deferrable
def _create_integer_fields(self) -> None:
"""Set fields which should be cast to integers."""
@@ -436,29 +440,54 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
try:
response = sagemaker_operation(
endpoint_info,
- wait_for_completion=self.wait_for_completion,
- check_interval=self.check_interval,
- max_ingestion_time=self.max_ingestion_time,
+ wait_for_completion=False,
)
+ # waiting for completion is handled here in the operator
except ClientError:
self.operation = "update"
sagemaker_operation = self.hook.update_endpoint
- log_str = "Updating"
response = sagemaker_operation(
endpoint_info,
- wait_for_completion=self.wait_for_completion,
- check_interval=self.check_interval,
- max_ingestion_time=self.max_ingestion_time,
+ wait_for_completion=False,
)
+
if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
raise AirflowException(f"Sagemaker endpoint creation failed:
{response}")
- else:
- return {
- "EndpointConfig": serialize(
-
self.hook.describe_endpoint_config(endpoint_info["EndpointConfigName"])
+
+ if self.deferrable:
+ self.defer(
+ trigger=SageMakerTrigger(
+ job_name=endpoint_info["EndpointName"],
+ job_type="endpoint",
+ poke_interval=self.check_interval,
+ aws_conn_id=self.aws_conn_id,
),
- "Endpoint":
serialize(self.hook.describe_endpoint(endpoint_info["EndpointName"])),
- }
+ method_name="execute_complete",
+ timeout=datetime.timedelta(seconds=self.max_ingestion_time),
+ )
+ elif self.wait_for_completion:
+ self.hook.get_waiter("endpoint_in_service").wait(
+ EndpointName=endpoint_info["EndpointName"],
+ WaiterConfig={"Delay": self.check_interval, "MaxAttempts":
self.max_ingestion_time},
+ )
+
+ return {
+ "EndpointConfig": serialize(
+
self.hook.describe_endpoint_config(endpoint_info["EndpointConfigName"])
+ ),
+ "Endpoint":
serialize(self.hook.describe_endpoint(endpoint_info["EndpointName"])),
+ }
+
+ def execute_complete(self, context, event=None):
+ if event["status"] != "success":
+ raise AirflowException(f"Error while running job: {event}")
+ endpoint_info = self.config.get("Endpoint", self.config)
+ return {
+ "EndpointConfig": serialize(
+
self.hook.describe_endpoint_config(endpoint_info["EndpointConfigName"])
+ ),
+ "Endpoint":
serialize(self.hook.describe_endpoint(endpoint_info["EndpointName"])),
+ }
class SageMakerTransformOperator(SageMakerBaseOperator):
@@ -652,6 +681,7 @@ class SageMakerTuningOperator(SageMakerBaseOperator):
:param max_ingestion_time: If wait is set to True, the operation fails
if the tuning job doesn't finish within max_ingestion_time seconds. If
you
set this parameter to None, the operation does not timeout.
+ :param deferrable: Will wait asynchronously for completion.
:return Dict: Returns The ARN of the tuning job created in Amazon
SageMaker.
"""
@@ -663,12 +693,14 @@ class SageMakerTuningOperator(SageMakerBaseOperator):
wait_for_completion: bool = True,
check_interval: int = CHECK_INTERVAL_SECOND,
max_ingestion_time: int | None = None,
+ deferrable: bool = False,
**kwargs,
):
super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
self.wait_for_completion = wait_for_completion
self.check_interval = check_interval
self.max_ingestion_time = max_ingestion_time
+ self.deferrable = deferrable
def expand_role(self) -> None:
"""Expands an IAM role name into an ARN."""
@@ -695,16 +727,46 @@ class SageMakerTuningOperator(SageMakerBaseOperator):
)
response = self.hook.create_tuning_job(
self.config,
- wait_for_completion=self.wait_for_completion,
+ wait_for_completion=False, # we handle this here
check_interval=self.check_interval,
max_ingestion_time=self.max_ingestion_time,
)
if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
raise AirflowException(f"Sagemaker Tuning Job creation failed:
{response}")
+
+ if self.deferrable:
+ self.defer(
+ trigger=SageMakerTrigger(
+ job_name=self.config["HyperParameterTuningJobName"],
+ job_type="tuning",
+ poke_interval=self.check_interval,
+ aws_conn_id=self.aws_conn_id,
+ ),
+ method_name="execute_complete",
+ timeout=datetime.timedelta(seconds=self.max_ingestion_time)
+ if self.max_ingestion_time is not None
+ else None,
+ )
+ description = {} # never executed but makes static checkers happy
+ elif self.wait_for_completion:
+ description = self.hook.check_status(
+ self.config["HyperParameterTuningJobName"],
+ "HyperParameterTuningJobStatus",
+ self.hook.describe_tuning_job,
+ self.check_interval,
+ self.max_ingestion_time,
+ )
else:
- return {
- "Tuning":
serialize(self.hook.describe_tuning_job(self.config["HyperParameterTuningJobName"]))
- }
+ description =
self.hook.describe_tuning_job(self.config["HyperParameterTuningJobName"])
+
+ return {"Tuning": serialize(description)}
+
+ def execute_complete(self, context, event=None):
+ if event["status"] != "success":
+ raise AirflowException(f"Error while running job: {event}")
+ return {
+ "Tuning":
serialize(self.hook.describe_tuning_job(self.config["HyperParameterTuningJobName"]))
+ }
class SageMakerModelOperator(SageMakerBaseOperator):
diff --git a/airflow/providers/amazon/aws/triggers/sagemaker.py
b/airflow/providers/amazon/aws/triggers/sagemaker.py
index 92266cad5f..ca511a4a46 100644
--- a/airflow/providers/amazon/aws/triggers/sagemaker.py
+++ b/airflow/providers/amazon/aws/triggers/sagemaker.py
@@ -21,6 +21,7 @@ from functools import cached_property
from typing import Any
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
+from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -41,7 +42,7 @@ class SageMakerTrigger(BaseTrigger):
job_name: str,
job_type: str,
poke_interval: int = 30,
- max_attempts: int | None = None,
+ max_attempts: int = 480,
aws_conn_id: str = "aws_default",
):
super().__init__()
@@ -74,14 +75,28 @@ class SageMakerTrigger(BaseTrigger):
"training": "TrainingJobComplete",
"transform": "TransformJobComplete",
"processing": "ProcessingJobComplete",
+ "tuning": "TuningJobComplete",
+ "endpoint": "endpoint_in_service", # this one is provided by boto
}[job_type.lower()]
@staticmethod
- def _get_job_type_waiter_job_name_arg(job_type: str) -> str:
+ def _get_waiter_arg_name(job_type: str) -> str:
return {
"training": "TrainingJobName",
"transform": "TransformJobName",
"processing": "ProcessingJobName",
+ "tuning": "HyperParameterTuningJobName",
+ "endpoint": "EndpointName",
+ }[job_type.lower()]
+
+ @staticmethod
+ def _get_response_status_key(job_type: str) -> str:
+ return {
+ "training": "TrainingJobStatus",
+ "transform": "TransformJobStatus",
+ "processing": "ProcessingJobStatus",
+ "tuning": "HyperParameterTuningJobStatus",
+ "endpoint": "EndpointStatus",
}[job_type.lower()]
async def run(self):
@@ -90,12 +105,13 @@ class SageMakerTrigger(BaseTrigger):
waiter = self.hook.get_waiter(
self._get_job_type_waiter(self.job_type), deferrable=True,
client=client
)
- waiter_args = {
- self._get_job_type_waiter_job_name_arg(self.job_type):
self.job_name,
- "WaiterConfig": {
- "Delay": self.poke_interval,
- "MaxAttempts": self.max_attempts,
- },
- }
- await waiter.wait(**waiter_args)
- yield TriggerEvent({"status": "success", "message": "Job completed."})
+ await async_wait(
+ waiter=waiter,
+ waiter_delay=self.poke_interval,
+ waiter_max_attempts=self.max_attempts,
+ args={self._get_waiter_arg_name(self.job_type): self.job_name},
+ failure_message=f"Error while waiting for {self.job_type} job",
+ status_message=f"{self.job_type} job not done yet",
+ status_args=[self._get_response_status_key(self.job_type)],
+ )
+ yield TriggerEvent({"status": "success", "message": "Job
completed."})
diff --git a/airflow/providers/amazon/aws/waiters/sagemaker.json
b/airflow/providers/amazon/aws/waiters/sagemaker.json
index 73e3f09925..2c2760982c 100644
--- a/airflow/providers/amazon/aws/waiters/sagemaker.json
+++ b/airflow/providers/amazon/aws/waiters/sagemaker.json
@@ -78,6 +78,32 @@
"state": "failure"
}
]
+ },
+ "TuningJobComplete": {
+ "delay": 30,
+ "operation": "DescribeHyperParameterTuningJob",
+ "maxAttempts": 60,
+ "description": "Wait until job is COMPLETED",
+ "acceptors": [
+ {
+ "matcher": "path",
+ "argument": "HyperParameterTuningJobStatus",
+ "expected": "Completed",
+ "state": "success"
+ },
+ {
+ "matcher": "path",
+ "argument": "HyperParameterTuningJobStatus",
+ "expected": "Failed",
+ "state": "failure"
+ },
+ {
+ "matcher": "path",
+ "argument": "HyperParameterTuningJobStatus",
+ "expected": "Stopped",
+ "state": "failure"
+ }
+ ]
}
}
}
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py
b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py
index 498a38b816..8a566535b9 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py
@@ -22,10 +22,11 @@ from unittest import mock
import pytest
from botocore.exceptions import ClientError
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators import sagemaker
from airflow.providers.amazon.aws.operators.sagemaker import
SageMakerEndpointOperator
+from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
CREATE_MODEL_PARAMS: dict = {
"ModelName": "model_name",
@@ -83,12 +84,12 @@ class TestSageMakerEndpointOperator:
@mock.patch.object(sagemaker, "serialize", return_value="")
def test_execute(self, serialize, mock_endpoint, mock_endpoint_config,
mock_model, mock_client):
mock_endpoint.return_value = {"EndpointArn": "test_arn",
"ResponseMetadata": {"HTTPStatusCode": 200}}
+
self.sagemaker.execute(None)
+
mock_model.assert_called_once_with(CREATE_MODEL_PARAMS)
mock_endpoint_config.assert_called_once_with(CREATE_ENDPOINT_CONFIG_PARAMS)
- mock_endpoint.assert_called_once_with(
- CREATE_ENDPOINT_PARAMS, wait_for_completion=False,
check_interval=5, max_ingestion_time=None
- )
+ mock_endpoint.assert_called_once_with(CREATE_ENDPOINT_PARAMS,
wait_for_completion=False)
assert self.sagemaker.integer_fields == EXPECTED_INTEGER_FIELDS
for variant in
self.sagemaker.config["EndpointConfig"]["ProductionVariants"]:
assert variant["InitialInstanceCount"] ==
int(variant["InitialInstanceCount"])
@@ -120,3 +121,18 @@ class TestSageMakerEndpointOperator:
"ResponseMetadata": {"HTTPStatusCode": 200},
}
self.sagemaker.execute(None)
+
+ @mock.patch.object(SageMakerHook, "create_model")
+ @mock.patch.object(SageMakerHook, "create_endpoint_config")
+ @mock.patch.object(SageMakerHook, "create_endpoint")
+ def test_deferred(self, mock_create_endpoint, _, __):
+ self.sagemaker.deferrable = True
+
+ mock_create_endpoint.return_value = {"ResponseMetadata":
{"HTTPStatusCode": 200}}
+
+ with pytest.raises(TaskDeferred) as defer:
+ self.sagemaker.execute(None)
+
+ assert isinstance(defer.value.trigger, SageMakerTrigger)
+ assert defer.value.trigger.job_name == "endpoint_name"
+ assert defer.value.trigger.job_type == "endpoint"
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py
b/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py
index 4862b930f1..4d6805ec74 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py
@@ -21,10 +21,11 @@ from unittest import mock
import pytest
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators import sagemaker
from airflow.providers.amazon.aws.operators.sagemaker import
SageMakerTuningOperator
+from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
EXPECTED_INTEGER_FIELDS: list[list[str]] = [
["HyperParameterTuningJobConfig", "ResourceLimits",
"MaxNumberOfTrainingJobs"],
@@ -107,3 +108,15 @@ class TestSageMakerTuningOperator:
mock_tuning.return_value = {"TrainingJobArn": "test_arn",
"ResponseMetadata": {"HTTPStatusCode": 404}}
with pytest.raises(AirflowException):
self.sagemaker.execute(None)
+
+ @mock.patch.object(SageMakerHook, "create_tuning_job")
+ def test_defers(self, create_mock):
+ create_mock.return_value = {"ResponseMetadata": {"HTTPStatusCode":
200}}
+ self.sagemaker.deferrable = True
+
+ with pytest.raises(TaskDeferred) as defer:
+ self.sagemaker.execute(None)
+
+ assert isinstance(defer.value.trigger, SageMakerTrigger)
+ assert defer.value.trigger.job_name == "job_name"
+ assert defer.value.trigger.job_type == "tuning"
diff --git a/tests/providers/amazon/aws/triggers/test_sagemaker.py
b/tests/providers/amazon/aws/triggers/test_sagemaker.py
index 5a7f8e3c8e..f2d05f85a6 100644
--- a/tests/providers/amazon/aws/triggers/test_sagemaker.py
+++ b/tests/providers/amazon/aws/triggers/test_sagemaker.py
@@ -49,28 +49,26 @@ class TestSagemakerTrigger:
assert args["aws_conn_id"] == AWS_CONN_ID
@pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "job_type",
+ [
+ "training",
+ "transform",
+ "processing",
+ "tuning",
+ "endpoint",
+ ],
+ )
@mock.patch("airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.get_waiter")
@mock.patch("airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.async_conn")
-
@mock.patch("airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrigger._get_job_type_waiter")
- @mock.patch(
-
"airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrigger._get_job_type_waiter_job_name_arg"
- )
- async def test_sagemaker_trigger_run(
- self,
- mock_get_job_type_waiter_job_name_arg,
- mock_get_job_type_waiter,
- mock_async_conn,
- mock_get_waiter,
- ):
- mock_get_job_type_waiter_job_name_arg.return_value = "job_name"
- mock_get_job_type_waiter.return_value = "waiter"
+ async def test_sagemaker_trigger_run_all_job_types(self, mock_async_conn,
mock_get_waiter, job_type):
mock_async_conn.__aenter__.return_value = mock.MagicMock()
mock_get_waiter().wait = AsyncMock()
sagemaker_trigger = SageMakerTrigger(
job_name=JOB_NAME,
- job_type=JOB_TYPE,
+ job_type=job_type,
poke_interval=POKE_INTERVAL,
max_attempts=MAX_ATTEMPTS,
aws_conn_id=AWS_CONN_ID,
diff --git a/tests/system/providers/amazon/aws/example_sagemaker.py
b/tests/system/providers/amazon/aws/example_sagemaker.py
index 9506970446..2b0f3fc6ef 100644
--- a/tests/system/providers/amazon/aws/example_sagemaker.py
+++ b/tests/system/providers/amazon/aws/example_sagemaker.py
@@ -159,12 +159,11 @@ def _build_and_upload_docker_image(preprocess_script,
repository_uri):
docker_build_and_push_commands = f"""
cp /root/.aws/credentials /tmp/credentials &&
# login to public ecr repo containing amazonlinux image
- docker login --username {creds.username} --password
{creds.password} public.ecr.aws
+ docker login --username {creds.username} --password
{creds.password} public.ecr.aws &&
docker build --platform=linux/amd64 -f {dockerfile.name} -t
{repository_uri} /tmp &&
rm /tmp/credentials &&
# login again, this time to the private repo we created to hold
that specific image
- aws ecr get-login-password --region {ecr_region} |
docker login --username {creds.username} --password
{creds.password} {repository_uri} &&
docker push {repository_uri}
"""
@@ -178,7 +177,8 @@ def _build_and_upload_docker_image(preprocess_script,
repository_uri):
if docker_build.returncode != 0:
raise RuntimeError(
"Failed to prepare docker image for the preprocessing job.\n"
- f"The following error happened while executing the sequence of
bash commands:\n{stderr}"
+ "The following error happened while executing the sequence of
bash commands:\n"
+ f"{stderr.decode()}"
)