This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 14785bc84c deferrable mode for `SageMakerTuningOperator` and 
`SageMakerEndpointOperator` (#32112)
14785bc84c is described below

commit 14785bc84c984b8747fa062b84e800d22ddc0477
Author: Raphaël Vandon <[email protected]>
AuthorDate: Tue Jun 27 14:04:09 2023 -0700

    deferrable mode for `SageMakerTuningOperator` and 
`SageMakerEndpointOperator` (#32112)
---
 .../providers/amazon/aws/operators/sagemaker.py    | 98 ++++++++++++++++++----
 airflow/providers/amazon/aws/triggers/sagemaker.py | 38 ++++++---
 .../providers/amazon/aws/waiters/sagemaker.json    | 26 ++++++
 .../aws/operators/test_sagemaker_endpoint.py       | 24 +++++-
 .../amazon/aws/operators/test_sagemaker_tuning.py  | 15 +++-
 .../amazon/aws/triggers/test_sagemaker.py          | 26 +++---
 .../providers/amazon/aws/example_sagemaker.py      |  6 +-
 7 files changed, 182 insertions(+), 51 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py 
b/airflow/providers/amazon/aws/operators/sagemaker.py
index 0d4ba9fcab..8467459188 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+import datetime
 import json
 import time
 import warnings
@@ -375,6 +376,7 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
         finish within max_ingestion_time seconds. If you set this parameter to 
None it never times out.
     :param operation: Whether to create an endpoint or update an endpoint. 
Must be either 'create or 'update'.
     :param aws_conn_id: The AWS connection ID to use.
+    :param deferrable:  Will wait asynchronously for completion.
     :return Dict: Returns The ARN of the endpoint created in Amazon SageMaker.
     """
 
@@ -387,15 +389,17 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
         check_interval: int = CHECK_INTERVAL_SECOND,
         max_ingestion_time: int | None = None,
         operation: str = "create",
+        deferrable: bool = False,
         **kwargs,
     ):
         super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
         self.wait_for_completion = wait_for_completion
         self.check_interval = check_interval
-        self.max_ingestion_time = max_ingestion_time
+        self.max_ingestion_time = max_ingestion_time or 3600 * 10
         self.operation = operation.lower()
         if self.operation not in ["create", "update"]:
             raise ValueError('Invalid value! Argument operation has to be one 
of "create" and "update"')
+        self.deferrable = deferrable
 
     def _create_integer_fields(self) -> None:
         """Set fields which should be cast to integers."""
@@ -436,29 +440,54 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
         try:
             response = sagemaker_operation(
                 endpoint_info,
-                wait_for_completion=self.wait_for_completion,
-                check_interval=self.check_interval,
-                max_ingestion_time=self.max_ingestion_time,
+                wait_for_completion=False,
             )
+            # waiting for completion is handled here in the operator
         except ClientError:
             self.operation = "update"
             sagemaker_operation = self.hook.update_endpoint
-            log_str = "Updating"
             response = sagemaker_operation(
                 endpoint_info,
-                wait_for_completion=self.wait_for_completion,
-                check_interval=self.check_interval,
-                max_ingestion_time=self.max_ingestion_time,
+                wait_for_completion=False,
             )
+
         if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
             raise AirflowException(f"Sagemaker endpoint creation failed: 
{response}")
-        else:
-            return {
-                "EndpointConfig": serialize(
-                    
self.hook.describe_endpoint_config(endpoint_info["EndpointConfigName"])
+
+        if self.deferrable:
+            self.defer(
+                trigger=SageMakerTrigger(
+                    job_name=endpoint_info["EndpointName"],
+                    job_type="endpoint",
+                    poke_interval=self.check_interval,
+                    aws_conn_id=self.aws_conn_id,
                 ),
-                "Endpoint": 
serialize(self.hook.describe_endpoint(endpoint_info["EndpointName"])),
-            }
+                method_name="execute_complete",
+                timeout=datetime.timedelta(seconds=self.max_ingestion_time),
+            )
+        elif self.wait_for_completion:
+            self.hook.get_waiter("endpoint_in_service").wait(
+                EndpointName=endpoint_info["EndpointName"],
+                WaiterConfig={"Delay": self.check_interval, "MaxAttempts": 
self.max_ingestion_time},
+            )
+
+        return {
+            "EndpointConfig": serialize(
+                
self.hook.describe_endpoint_config(endpoint_info["EndpointConfigName"])
+            ),
+            "Endpoint": 
serialize(self.hook.describe_endpoint(endpoint_info["EndpointName"])),
+        }
+
+    def execute_complete(self, context, event=None):
+        if event["status"] != "success":
+            raise AirflowException(f"Error while running job: {event}")
+        endpoint_info = self.config.get("Endpoint", self.config)
+        return {
+            "EndpointConfig": serialize(
+                
self.hook.describe_endpoint_config(endpoint_info["EndpointConfigName"])
+            ),
+            "Endpoint": 
serialize(self.hook.describe_endpoint(endpoint_info["EndpointName"])),
+        }
 
 
 class SageMakerTransformOperator(SageMakerBaseOperator):
@@ -652,6 +681,7 @@ class SageMakerTuningOperator(SageMakerBaseOperator):
     :param max_ingestion_time: If wait is set to True, the operation fails
         if the tuning job doesn't finish within max_ingestion_time seconds. If 
you
         set this parameter to None, the operation does not timeout.
+    :param deferrable: Will wait asynchronously for completion.
     :return Dict: Returns The ARN of the tuning job created in Amazon 
SageMaker.
     """
 
@@ -663,12 +693,14 @@ class SageMakerTuningOperator(SageMakerBaseOperator):
         wait_for_completion: bool = True,
         check_interval: int = CHECK_INTERVAL_SECOND,
         max_ingestion_time: int | None = None,
+        deferrable: bool = False,
         **kwargs,
     ):
         super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
         self.wait_for_completion = wait_for_completion
         self.check_interval = check_interval
         self.max_ingestion_time = max_ingestion_time
+        self.deferrable = deferrable
 
     def expand_role(self) -> None:
         """Expands an IAM role name into an ARN."""
@@ -695,16 +727,46 @@ class SageMakerTuningOperator(SageMakerBaseOperator):
         )
         response = self.hook.create_tuning_job(
             self.config,
-            wait_for_completion=self.wait_for_completion,
+            wait_for_completion=False,  # we handle this here
             check_interval=self.check_interval,
             max_ingestion_time=self.max_ingestion_time,
         )
         if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
             raise AirflowException(f"Sagemaker Tuning Job creation failed: 
{response}")
+
+        if self.deferrable:
+            self.defer(
+                trigger=SageMakerTrigger(
+                    job_name=self.config["HyperParameterTuningJobName"],
+                    job_type="tuning",
+                    poke_interval=self.check_interval,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+                timeout=datetime.timedelta(seconds=self.max_ingestion_time)
+                if self.max_ingestion_time is not None
+                else None,
+            )
+            description = {}  # never executed but makes static checkers happy
+        elif self.wait_for_completion:
+            description = self.hook.check_status(
+                self.config["HyperParameterTuningJobName"],
+                "HyperParameterTuningJobStatus",
+                self.hook.describe_tuning_job,
+                self.check_interval,
+                self.max_ingestion_time,
+            )
         else:
-            return {
-                "Tuning": 
serialize(self.hook.describe_tuning_job(self.config["HyperParameterTuningJobName"]))
-            }
+            description = 
self.hook.describe_tuning_job(self.config["HyperParameterTuningJobName"])
+
+        return {"Tuning": serialize(description)}
+
+    def execute_complete(self, context, event=None):
+        if event["status"] != "success":
+            raise AirflowException(f"Error while running job: {event}")
+        return {
+            "Tuning": 
serialize(self.hook.describe_tuning_job(self.config["HyperParameterTuningJobName"]))
+        }
 
 
 class SageMakerModelOperator(SageMakerBaseOperator):
diff --git a/airflow/providers/amazon/aws/triggers/sagemaker.py 
b/airflow/providers/amazon/aws/triggers/sagemaker.py
index 92266cad5f..ca511a4a46 100644
--- a/airflow/providers/amazon/aws/triggers/sagemaker.py
+++ b/airflow/providers/amazon/aws/triggers/sagemaker.py
@@ -21,6 +21,7 @@ from functools import cached_property
 from typing import Any
 
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
+from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 
 
@@ -41,7 +42,7 @@ class SageMakerTrigger(BaseTrigger):
         job_name: str,
         job_type: str,
         poke_interval: int = 30,
-        max_attempts: int | None = None,
+        max_attempts: int = 480,
         aws_conn_id: str = "aws_default",
     ):
         super().__init__()
@@ -74,14 +75,28 @@ class SageMakerTrigger(BaseTrigger):
             "training": "TrainingJobComplete",
             "transform": "TransformJobComplete",
             "processing": "ProcessingJobComplete",
+            "tuning": "TuningJobComplete",
+            "endpoint": "endpoint_in_service",  # this one is provided by boto
         }[job_type.lower()]
 
     @staticmethod
-    def _get_job_type_waiter_job_name_arg(job_type: str) -> str:
+    def _get_waiter_arg_name(job_type: str) -> str:
         return {
             "training": "TrainingJobName",
             "transform": "TransformJobName",
             "processing": "ProcessingJobName",
+            "tuning": "HyperParameterTuningJobName",
+            "endpoint": "EndpointName",
+        }[job_type.lower()]
+
+    @staticmethod
+    def _get_response_status_key(job_type: str) -> str:
+        return {
+            "training": "TrainingJobStatus",
+            "transform": "TransformJobStatus",
+            "processing": "ProcessingJobStatus",
+            "tuning": "HyperParameterTuningJobStatus",
+            "endpoint": "EndpointStatus",
         }[job_type.lower()]
 
     async def run(self):
@@ -90,12 +105,13 @@ class SageMakerTrigger(BaseTrigger):
             waiter = self.hook.get_waiter(
                 self._get_job_type_waiter(self.job_type), deferrable=True, 
client=client
             )
-            waiter_args = {
-                self._get_job_type_waiter_job_name_arg(self.job_type): 
self.job_name,
-                "WaiterConfig": {
-                    "Delay": self.poke_interval,
-                    "MaxAttempts": self.max_attempts,
-                },
-            }
-            await waiter.wait(**waiter_args)
-        yield TriggerEvent({"status": "success", "message": "Job completed."})
+            await async_wait(
+                waiter=waiter,
+                waiter_delay=self.poke_interval,
+                waiter_max_attempts=self.max_attempts,
+                args={self._get_waiter_arg_name(self.job_type): self.job_name},
+                failure_message=f"Error while waiting for {self.job_type} job",
+                status_message=f"{self.job_type} job not done yet",
+                status_args=[self._get_response_status_key(self.job_type)],
+            )
+            yield TriggerEvent({"status": "success", "message": "Job 
completed."})
diff --git a/airflow/providers/amazon/aws/waiters/sagemaker.json 
b/airflow/providers/amazon/aws/waiters/sagemaker.json
index 73e3f09925..2c2760982c 100644
--- a/airflow/providers/amazon/aws/waiters/sagemaker.json
+++ b/airflow/providers/amazon/aws/waiters/sagemaker.json
@@ -78,6 +78,32 @@
                     "state": "failure"
                 }
             ]
+        },
+        "TuningJobComplete": {
+            "delay": 30,
+            "operation": "DescribeHyperParameterTuningJob",
+            "maxAttempts": 60,
+            "description": "Wait until job is COMPLETED",
+            "acceptors": [
+                {
+                    "matcher": "path",
+                    "argument": "HyperParameterTuningJobStatus",
+                    "expected": "Completed",
+                    "state": "success"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "HyperParameterTuningJobStatus",
+                    "expected": "Failed",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "HyperParameterTuningJobStatus",
+                    "expected": "Stopped",
+                    "state": "failure"
+                }
+            ]
         }
     }
 }
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py 
b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py
index 498a38b816..8a566535b9 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py
@@ -22,10 +22,11 @@ from unittest import mock
 import pytest
 from botocore.exceptions import ClientError
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
 from airflow.providers.amazon.aws.operators import sagemaker
 from airflow.providers.amazon.aws.operators.sagemaker import 
SageMakerEndpointOperator
+from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
 
 CREATE_MODEL_PARAMS: dict = {
     "ModelName": "model_name",
@@ -83,12 +84,12 @@ class TestSageMakerEndpointOperator:
     @mock.patch.object(sagemaker, "serialize", return_value="")
     def test_execute(self, serialize, mock_endpoint, mock_endpoint_config, 
mock_model, mock_client):
         mock_endpoint.return_value = {"EndpointArn": "test_arn", 
"ResponseMetadata": {"HTTPStatusCode": 200}}
+
         self.sagemaker.execute(None)
+
         mock_model.assert_called_once_with(CREATE_MODEL_PARAMS)
         
mock_endpoint_config.assert_called_once_with(CREATE_ENDPOINT_CONFIG_PARAMS)
-        mock_endpoint.assert_called_once_with(
-            CREATE_ENDPOINT_PARAMS, wait_for_completion=False, 
check_interval=5, max_ingestion_time=None
-        )
+        mock_endpoint.assert_called_once_with(CREATE_ENDPOINT_PARAMS, 
wait_for_completion=False)
         assert self.sagemaker.integer_fields == EXPECTED_INTEGER_FIELDS
         for variant in 
self.sagemaker.config["EndpointConfig"]["ProductionVariants"]:
             assert variant["InitialInstanceCount"] == 
int(variant["InitialInstanceCount"])
@@ -120,3 +121,18 @@ class TestSageMakerEndpointOperator:
             "ResponseMetadata": {"HTTPStatusCode": 200},
         }
         self.sagemaker.execute(None)
+
+    @mock.patch.object(SageMakerHook, "create_model")
+    @mock.patch.object(SageMakerHook, "create_endpoint_config")
+    @mock.patch.object(SageMakerHook, "create_endpoint")
+    def test_deferred(self, mock_create_endpoint, _, __):
+        self.sagemaker.deferrable = True
+
+        mock_create_endpoint.return_value = {"ResponseMetadata": 
{"HTTPStatusCode": 200}}
+
+        with pytest.raises(TaskDeferred) as defer:
+            self.sagemaker.execute(None)
+
+        assert isinstance(defer.value.trigger, SageMakerTrigger)
+        assert defer.value.trigger.job_name == "endpoint_name"
+        assert defer.value.trigger.job_type == "endpoint"
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py 
b/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py
index 4862b930f1..4d6805ec74 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_tuning.py
@@ -21,10 +21,11 @@ from unittest import mock
 
 import pytest
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
 from airflow.providers.amazon.aws.operators import sagemaker
 from airflow.providers.amazon.aws.operators.sagemaker import 
SageMakerTuningOperator
+from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
 
 EXPECTED_INTEGER_FIELDS: list[list[str]] = [
     ["HyperParameterTuningJobConfig", "ResourceLimits", 
"MaxNumberOfTrainingJobs"],
@@ -107,3 +108,15 @@ class TestSageMakerTuningOperator:
         mock_tuning.return_value = {"TrainingJobArn": "test_arn", 
"ResponseMetadata": {"HTTPStatusCode": 404}}
         with pytest.raises(AirflowException):
             self.sagemaker.execute(None)
+
+    @mock.patch.object(SageMakerHook, "create_tuning_job")
+    def test_defers(self, create_mock):
+        create_mock.return_value = {"ResponseMetadata": {"HTTPStatusCode": 
200}}
+        self.sagemaker.deferrable = True
+
+        with pytest.raises(TaskDeferred) as defer:
+            self.sagemaker.execute(None)
+
+        assert isinstance(defer.value.trigger, SageMakerTrigger)
+        assert defer.value.trigger.job_name == "job_name"
+        assert defer.value.trigger.job_type == "tuning"
diff --git a/tests/providers/amazon/aws/triggers/test_sagemaker.py 
b/tests/providers/amazon/aws/triggers/test_sagemaker.py
index 5a7f8e3c8e..f2d05f85a6 100644
--- a/tests/providers/amazon/aws/triggers/test_sagemaker.py
+++ b/tests/providers/amazon/aws/triggers/test_sagemaker.py
@@ -49,28 +49,26 @@ class TestSagemakerTrigger:
         assert args["aws_conn_id"] == AWS_CONN_ID
 
     @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "job_type",
+        [
+            "training",
+            "transform",
+            "processing",
+            "tuning",
+            "endpoint",
+        ],
+    )
     
@mock.patch("airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.get_waiter")
     
@mock.patch("airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.async_conn")
-    
@mock.patch("airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrigger._get_job_type_waiter")
-    @mock.patch(
-        
"airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrigger._get_job_type_waiter_job_name_arg"
-    )
-    async def test_sagemaker_trigger_run(
-        self,
-        mock_get_job_type_waiter_job_name_arg,
-        mock_get_job_type_waiter,
-        mock_async_conn,
-        mock_get_waiter,
-    ):
-        mock_get_job_type_waiter_job_name_arg.return_value = "job_name"
-        mock_get_job_type_waiter.return_value = "waiter"
+    async def test_sagemaker_trigger_run_all_job_types(self, mock_async_conn, 
mock_get_waiter, job_type):
         mock_async_conn.__aenter__.return_value = mock.MagicMock()
 
         mock_get_waiter().wait = AsyncMock()
 
         sagemaker_trigger = SageMakerTrigger(
             job_name=JOB_NAME,
-            job_type=JOB_TYPE,
+            job_type=job_type,
             poke_interval=POKE_INTERVAL,
             max_attempts=MAX_ATTEMPTS,
             aws_conn_id=AWS_CONN_ID,
diff --git a/tests/system/providers/amazon/aws/example_sagemaker.py 
b/tests/system/providers/amazon/aws/example_sagemaker.py
index 9506970446..2b0f3fc6ef 100644
--- a/tests/system/providers/amazon/aws/example_sagemaker.py
+++ b/tests/system/providers/amazon/aws/example_sagemaker.py
@@ -159,12 +159,11 @@ def _build_and_upload_docker_image(preprocess_script, 
repository_uri):
         docker_build_and_push_commands = f"""
             cp /root/.aws/credentials /tmp/credentials &&
             # login to public ecr repo containing amazonlinux image
-            docker login --username {creds.username} --password 
{creds.password} public.ecr.aws
+            docker login --username {creds.username} --password 
{creds.password} public.ecr.aws &&
             docker build --platform=linux/amd64 -f {dockerfile.name} -t 
{repository_uri} /tmp &&
             rm /tmp/credentials &&
 
             # login again, this time to the private repo we created to hold 
that specific image
-            aws ecr get-login-password --region {ecr_region} |
             docker login --username {creds.username} --password 
{creds.password} {repository_uri} &&
             docker push {repository_uri}
             """
@@ -178,7 +177,8 @@ def _build_and_upload_docker_image(preprocess_script, 
repository_uri):
         if docker_build.returncode != 0:
             raise RuntimeError(
                 "Failed to prepare docker image for the preprocessing job.\n"
-                f"The following error happened while executing the sequence of 
bash commands:\n{stderr}"
+                "The following error happened while executing the sequence of 
bash commands:\n"
+                f"{stderr.decode()}"
             )
 
 

Reply via email to