This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a809c91528 Add deferrable param in SageMakerTrainingOperator (#31042)
a809c91528 is described below

commit a809c91528bebacdb5c3bac75ae3c7bf33a99308
Author: Pankaj Koti <[email protected]>
AuthorDate: Mon May 8 23:49:55 2023 +0530

    Add deferrable param in SageMakerTrainingOperator (#31042)
    
    * Add deferrable param in SageMakerTrainingOperator
    
    This will allow running SageMakerTrainingOperator in an async fashion,
    meaning that the worker only submits the training job and then defers
    to the trigger, which polls until the job status reaches a terminal
    state. This way, the worker slot won't be occupied for the whole
    period of task execution.
---
 .../providers/amazon/aws/operators/sagemaker.py    | 45 +++++++++++-
 .../aws/operators/test_sagemaker_training.py       | 16 ++++-
 .../amazon/aws/triggers/test_sagemaker.py          | 81 ++++++++++++++++++++++
 3 files changed, 139 insertions(+), 3 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py 
b/airflow/providers/amazon/aws/operators/sagemaker.py
index 4efb1a863b..587bf5b0a1 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -723,6 +723,8 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
     :param print_log: if the operator should print the cloudwatch log during 
training
     :param check_interval: if wait is set to be true, this is the time interval
         in seconds which the operator will check the status of the training job
+    :param max_attempts: Number of times to poll for query state before 
returning the current state,
+        defaults to None.
     :param max_ingestion_time: If wait is set to True, the operation fails if 
the training job
         doesn't finish within max_ingestion_time seconds. If you set this 
parameter to None,
         the operation does not timeout.
@@ -731,6 +733,8 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
     :param action_if_job_exists: Behaviour if the job name already exists. 
Possible options are "timestamp"
         (default), "increment" (deprecated) and "fail".
         This is only relevant if check_if_job_exists is True.
+    :param deferrable: Run operator in the deferrable mode. This is only 
effective if wait_for_completion is
+        set to True.
     :return Dict: Returns The ARN of the training job created in Amazon 
SageMaker.
     """
 
@@ -742,15 +746,18 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
         wait_for_completion: bool = True,
         print_log: bool = True,
         check_interval: int = CHECK_INTERVAL_SECOND,
+        max_attempts: int | None = None,
         max_ingestion_time: int | None = None,
         check_if_job_exists: bool = True,
         action_if_job_exists: str = "timestamp",
+        deferrable: bool = False,
         **kwargs,
     ):
         super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
         self.wait_for_completion = wait_for_completion
         self.print_log = print_log
         self.check_interval = check_interval
+        self.max_attempts = max_attempts or 60
         self.max_ingestion_time = max_ingestion_time
         self.check_if_job_exists = check_if_job_exists
         if action_if_job_exists in {"timestamp", "increment", "fail"}:
@@ -767,6 +774,7 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
                 f"Argument action_if_job_exists accepts only 'timestamp', 
'increment' and 'fail'. \
                 Provided value: '{action_if_job_exists}'."
             )
+        self.deferrable = deferrable
 
     def expand_role(self) -> None:
         """Expands an IAM role name into an ARN."""
@@ -793,17 +801,50 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
             )
 
         self.log.info("Creating SageMaker training job %s.", 
self.config["TrainingJobName"])
+
+        if self.deferrable and not self.wait_for_completion:
+            self.log.warning(
+                "Setting deferrable to True does not have effect when 
wait_for_completion is set to False."
+            )
+
+        wait_for_completion = self.wait_for_completion
+        if self.deferrable and self.wait_for_completion:
+            # Set wait_for_completion to False so that it waits for the status 
in the deferred task.
+            wait_for_completion = False
+
         response = self.hook.create_training_job(
             self.config,
-            wait_for_completion=self.wait_for_completion,
+            wait_for_completion=wait_for_completion,
             print_log=self.print_log,
             check_interval=self.check_interval,
             max_ingestion_time=self.max_ingestion_time,
         )
         if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
             raise AirflowException(f"Sagemaker Training Job creation failed: 
{response}")
+
+        if self.deferrable and self.wait_for_completion:
+            self.defer(
+                timeout=self.execution_timeout,
+                trigger=SageMakerTrigger(
+                    job_name=self.config["TrainingJobName"],
+                    job_type="Training",
+                    poke_interval=self.check_interval,
+                    max_attempts=self.max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+            )
+
+        result = {"Training": 
serialize(self.hook.describe_training_job(self.config["TrainingJobName"]))}
+        return result
+
+    def execute_complete(self, context, event=None):
+        if event["status"] != "success":
+            raise AirflowException(f"Error while running job: {event}")
         else:
-            return {"Training": 
serialize(self.hook.describe_training_job(self.config["TrainingJobName"]))}
+            self.log.info(event["message"])
+        result = {"Training": 
serialize(self.hook.describe_training_job(self.config["TrainingJobName"]))}
+        return result
 
 
 class SageMakerDeleteModelOperator(SageMakerBaseOperator):
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_training.py 
b/tests/providers/amazon/aws/operators/test_sagemaker_training.py
index 82b76fdee0..e551317d33 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_training.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_training.py
@@ -21,10 +21,11 @@ from unittest import mock
 import pytest
 from botocore.exceptions import ClientError
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
 from airflow.providers.amazon.aws.operators import sagemaker
 from airflow.providers.amazon.aws.operators.sagemaker import 
SageMakerTrainingOperator
+from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
 
 EXPECTED_INTEGER_FIELDS: list[list[str]] = [
     ["ResourceConfig", "InstanceCount"],
@@ -113,3 +114,16 @@ class TestSageMakerTrainingOperator:
         }
         with pytest.raises(AirflowException):
             self.sagemaker.execute(None)
+
+    @mock.patch.object(SageMakerHook, "create_training_job")
+    def test_operator_defer(self, mock_training):
+        mock_training.return_value = {
+            "TrainingJobArn": "test_arn",
+            "ResponseMetadata": {"HTTPStatusCode": 200},
+        }
+        self.sagemaker.deferrable = True
+        self.sagemaker.wait_for_completion = True
+        self.sagemaker.check_if_job_exists = False
+        with pytest.raises(TaskDeferred) as exc:
+            self.sagemaker.execute(context=None)
+        assert isinstance(exc.value.trigger, SageMakerTrigger), "Trigger is 
not a SagemakerTrigger"
diff --git a/tests/providers/amazon/aws/triggers/test_sagemaker.py 
b/tests/providers/amazon/aws/triggers/test_sagemaker.py
new file mode 100644
index 0000000000..d8f0bd1cee
--- /dev/null
+++ b/tests/providers/amazon/aws/triggers/test_sagemaker.py
@@ -0,0 +1,81 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
+from airflow.triggers.base import TriggerEvent
+from tests.providers.amazon.aws.utils.compat import AsyncMock, async_mock
+
+JOB_NAME = "job_name"
+JOB_TYPE = "job_type"
+AWS_CONN_ID = "aws_sagemaker_conn"
+POKE_INTERVAL = 30
+MAX_ATTEMPTS = 60
+
+
+class TestSagemakerTrigger:
+    def test_sagemaker_trigger_serialize(self):
+        sagemaker_trigger = SageMakerTrigger(
+            job_name=JOB_NAME,
+            job_type=JOB_TYPE,
+            poke_interval=POKE_INTERVAL,
+            max_attempts=MAX_ATTEMPTS,
+            aws_conn_id=AWS_CONN_ID,
+        )
+        class_path, args = sagemaker_trigger.serialize()
+        assert class_path == 
"airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrigger"
+        assert args["job_name"] == JOB_NAME
+        assert args["job_type"] == JOB_TYPE
+        assert args["poke_interval"] == POKE_INTERVAL
+        assert args["max_attempts"] == MAX_ATTEMPTS
+        assert args["aws_conn_id"] == AWS_CONN_ID
+
+    @pytest.mark.asyncio
+    
@async_mock.patch("airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.get_waiter")
+    
@async_mock.patch("airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.async_conn")
+    
@async_mock.patch("airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrigger._get_job_type_waiter")
+    @async_mock.patch(
+        
"airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrigger._get_job_type_waiter_job_name_arg"
+    )
+    async def test_sagemaker_trigger_run(
+        self,
+        mock_get_job_type_waiter_job_name_arg,
+        mock_get_job_type_waiter,
+        mock_async_conn,
+        mock_get_waiter,
+    ):
+        mock_get_job_type_waiter_job_name_arg.return_value = "job_name"
+        mock_get_job_type_waiter.return_value = "waiter"
+        mock = async_mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = mock
+
+        mock_get_waiter().wait = AsyncMock()
+
+        sagemaker_trigger = SageMakerTrigger(
+            job_name=JOB_NAME,
+            job_type=JOB_TYPE,
+            poke_interval=POKE_INTERVAL,
+            max_attempts=MAX_ATTEMPTS,
+            aws_conn_id=AWS_CONN_ID,
+        )
+
+        generator = sagemaker_trigger.run()
+        response = await generator.asend(None)
+
+        assert response == TriggerEvent({"status": "success", "message": "Job 
completed."})

Reply via email to