This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 4c9b5fe4c1 Add deferrable param in SageMakerTransformOperator (#31063)
4c9b5fe4c1 is described below

commit 4c9b5fe4c15ff9d813a34e5f31b5e2910f70cff8
Author: Pankaj Koti <[email protected]>
AuthorDate: Tue May 9 12:13:47 2023 +0530

    Add deferrable param in SageMakerTransformOperator (#31063)
    
    This allows running SageMakerTransformOperator asynchronously:
    the worker only submits the transform job and then defers to
    the trigger, which polls until the job status reaches a terminal
    state. This way, the worker slot is not occupied for the whole
    period of task execution.
---
 .../providers/amazon/aws/operators/sagemaker.py    | 53 +++++++++++++++++++---
 .../aws/operators/test_sagemaker_transform.py      | 17 ++++++-
 2 files changed, 62 insertions(+), 8 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py 
b/airflow/providers/amazon/aws/operators/sagemaker.py
index 587bf5b0a1..f4041b465a 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -493,6 +493,8 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
     :param wait_for_completion: Set to True to wait until the transform job 
finishes.
     :param check_interval: If wait is set to True, the time interval, in 
seconds,
         that this operation waits to check the status of the transform job.
+    :param max_attempts: Number of times to poll for query state before 
returning the current state,
+        defaults to None.
     :param max_ingestion_time: If wait is set to True, the operation fails
         if the transform job doesn't finish within max_ingestion_time seconds. 
If you
         set this parameter to None, the operation does not timeout.
@@ -511,14 +513,17 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
         aws_conn_id: str = DEFAULT_CONN_ID,
         wait_for_completion: bool = True,
         check_interval: int = CHECK_INTERVAL_SECOND,
+        max_attempts: int | None = None,
         max_ingestion_time: int | None = None,
         check_if_job_exists: bool = True,
         action_if_job_exists: str = "timestamp",
+        deferrable: bool = False,
         **kwargs,
     ):
         super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
         self.wait_for_completion = wait_for_completion
         self.check_interval = check_interval
+        self.max_attempts = max_attempts or 60
         self.max_ingestion_time = max_ingestion_time
         self.check_if_job_exists = check_if_job_exists
         if action_if_job_exists in ("increment", "fail", "timestamp"):
@@ -535,6 +540,7 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
                 f"Argument action_if_job_exists accepts only 'timestamp', 
'increment' and 'fail'. \
                 Provided value: '{action_if_job_exists}'."
             )
+        self.deferrable = deferrable
 
     def _create_integer_fields(self) -> None:
         """Set fields which should be cast to integers."""
@@ -573,21 +579,54 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
             self.hook.create_model(model_config)
 
         self.log.info("Creating SageMaker transform Job %s.", 
transform_config["TransformJobName"])
+
+        if self.deferrable and not self.wait_for_completion:
+            self.log.warning(
+                "Setting deferrable to True does not have effect when 
wait_for_completion is set to False."
+            )
+
+        wait_for_completion = self.wait_for_completion
+        if self.deferrable and self.wait_for_completion:
+            # Set wait_for_completion to False so that it waits for the status 
in the deferred task.
+            wait_for_completion = False
+
         response = self.hook.create_transform_job(
             transform_config,
-            wait_for_completion=self.wait_for_completion,
+            wait_for_completion=wait_for_completion,
             check_interval=self.check_interval,
             max_ingestion_time=self.max_ingestion_time,
         )
         if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
             raise AirflowException(f"Sagemaker transform Job creation failed: 
{response}")
-        else:
-            return {
-                "Model": 
serialize(self.hook.describe_model(transform_config["ModelName"])),
-                "Transform": serialize(
-                    
self.hook.describe_transform_job(transform_config["TransformJobName"])
+
+        if self.deferrable and self.wait_for_completion:
+            self.defer(
+                timeout=self.execution_timeout,
+                trigger=SageMakerTrigger(
+                    job_name=transform_config["TransformJobName"],
+                    job_type="Transform",
+                    poke_interval=self.check_interval,
+                    max_attempts=self.max_attempts,
+                    aws_conn_id=self.aws_conn_id,
                 ),
-            }
+                method_name="execute_complete",
+            )
+
+        return {
+            "Model": 
serialize(self.hook.describe_model(transform_config["ModelName"])),
+            "Transform": 
serialize(self.hook.describe_transform_job(transform_config["TransformJobName"])),
+        }
+
+    def execute_complete(self, context, event=None):
+        if event["status"] != "success":
+            raise AirflowException(f"Error while running job: {event}")
+        else:
+            self.log.info(event["message"])
+        transform_config = self.config.get("Transform", self.config)
+        return {
+            "Model": 
serialize(self.hook.describe_model(transform_config["ModelName"])),
+            "Transform": 
serialize(self.hook.describe_transform_job(transform_config["TransformJobName"])),
+        }
 
 
 class SageMakerTuningOperator(SageMakerBaseOperator):
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py 
b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
index 482c7201cc..76a4d877b6 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
@@ -23,10 +23,11 @@ from unittest import mock
 import pytest
 from botocore.exceptions import ClientError
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
 from airflow.providers.amazon.aws.operators import sagemaker
 from airflow.providers.amazon.aws.operators.sagemaker import 
SageMakerTransformOperator
+from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
 
 EXPECTED_INTEGER_FIELDS: list[list[str]] = [
     ["Transform", "TransformResources", "InstanceCount"],
@@ -163,3 +164,17 @@ class TestSageMakerTransformOperator:
             check_interval=5,
             max_ingestion_time=None,
         )
+
+    @mock.patch.object(SageMakerHook, "create_transform_job")
+    @mock.patch.object(SageMakerHook, "create_model")
+    def test_operator_defer(self, _, mock_transform):
+        mock_transform.return_value = {
+            "TransformJobArn": "test_arn",
+            "ResponseMetadata": {"HTTPStatusCode": 200},
+        }
+        self.sagemaker.deferrable = True
+        self.sagemaker.wait_for_completion = True
+        self.sagemaker.check_if_job_exists = False
+        with pytest.raises(TaskDeferred) as exc:
+            self.sagemaker.execute(context=None)
+        assert isinstance(exc.value.trigger, SageMakerTrigger), "Trigger is 
not a SagemakerTrigger"

Reply via email to