This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1c144ee141 Add deferrable param in SageMakerProcessingOperator (#31062)
1c144ee141 is described below

commit 1c144ee141059a4c7e0450fd086eced2197568cf
Author: Pankaj Koti <[email protected]>
AuthorDate: Mon May 8 23:49:07 2023 +0530

    Add deferrable param in SageMakerProcessingOperator (#31062)
    
    * Add deferrable param in SageMakerProcessingOperator
    
    This will allow running SageMakerProcessingOperator in an async
    fashion, meaning that the worker only submits the job and then
    defers to the trigger, which polls until the job status reaches
    a terminal state. This way, the worker slot won't be occupied
    for the whole period of task execution.
    
    * Remove unneeded mocks
---
 .../providers/amazon/aws/operators/sagemaker.py    |  42 ++++++++-
 airflow/providers/amazon/aws/triggers/sagemaker.py | 101 +++++++++++++++++++++
 .../providers/amazon/aws/waiters/sagemaker.json    |  83 +++++++++++++++++
 .../aws/operators/test_sagemaker_processing.py     |  21 ++++-
 4 files changed, 245 insertions(+), 2 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py 
b/airflow/providers/amazon/aws/operators/sagemaker.py
index 502aa0ac9c..4efb1a863b 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -28,6 +28,7 @@ from airflow.exceptions import AirflowException, 
AirflowProviderDeprecationWarni
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
+from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
 from airflow.providers.amazon.aws.utils import trim_none_values
 from airflow.providers.amazon.aws.utils.sagemaker import ApprovalStatus
 from airflow.providers.amazon.aws.utils.tags import format_tags
@@ -171,11 +172,15 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
     :param print_log: if the operator should print the cloudwatch log during 
processing
     :param check_interval: if wait is set to be true, this is the time interval
         in seconds which the operator will check the status of the processing 
job
+    :param max_attempts: Number of times to poll for the job state before 
returning the current state,
+        defaults to None.
     :param max_ingestion_time: If wait is set to True, the operation fails if 
the processing job
         doesn't finish within max_ingestion_time seconds. If you set this 
parameter to None,
         the operation does not timeout.
     :param action_if_job_exists: Behaviour if the job name already exists. 
Possible options are "timestamp"
         (default), "increment" (deprecated) and "fail".
+    :param deferrable: Run operator in the deferrable mode. This is only 
effective if wait_for_completion is
+        set to True.
     :return Dict: Returns The ARN of the processing job created in Amazon 
SageMaker.
     """
 
@@ -187,8 +192,10 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
         wait_for_completion: bool = True,
         print_log: bool = True,
         check_interval: int = CHECK_INTERVAL_SECOND,
+        max_attempts: int | None = None,
         max_ingestion_time: int | None = None,
         action_if_job_exists: str = "timestamp",
+        deferrable: bool = False,
         **kwargs,
     ):
         super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
@@ -208,7 +215,9 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
         self.wait_for_completion = wait_for_completion
         self.print_log = print_log
         self.check_interval = check_interval
+        self.max_attempts = max_attempts or 60
         self.max_ingestion_time = max_ingestion_time
+        self.deferrable = deferrable
 
     def _create_integer_fields(self) -> None:
         """Set fields which should be cast to integers."""
@@ -234,14 +243,45 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
             self.hook.describe_processing_job,
         )
 
+        if self.deferrable and not self.wait_for_completion:
+            self.log.warning(
+                "Setting deferrable to True does not have effect when 
wait_for_completion is set to False."
+            )
+
+        wait_for_completion = self.wait_for_completion
+        if self.deferrable and self.wait_for_completion:
+            # Set wait_for_completion to False so that it waits for the status 
in the deferred task.
+            wait_for_completion = False
+
         response = self.hook.create_processing_job(
             self.config,
-            wait_for_completion=self.wait_for_completion,
+            wait_for_completion=wait_for_completion,
             check_interval=self.check_interval,
             max_ingestion_time=self.max_ingestion_time,
         )
         if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
             raise AirflowException(f"Sagemaker Processing Job creation failed: 
{response}")
+
+        if self.deferrable and self.wait_for_completion:
+            self.defer(
+                timeout=self.execution_timeout,
+                trigger=SageMakerTrigger(
+                    job_name=self.config["ProcessingJobName"],
+                    job_type="Processing",
+                    poke_interval=self.check_interval,
+                    max_attempts=self.max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+            )
+
+        return {"Processing": 
serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))}
+
+    def execute_complete(self, context, event=None):
+        if event["status"] != "success":
+            raise AirflowException(f"Error while running job: {event}")
+        else:
+            self.log.info(event["message"])
         return {"Processing": 
serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))}
 
 
diff --git a/airflow/providers/amazon/aws/triggers/sagemaker.py 
b/airflow/providers/amazon/aws/triggers/sagemaker.py
new file mode 100644
index 0000000000..773a3243e9
--- /dev/null
+++ b/airflow/providers/amazon/aws/triggers/sagemaker.py
@@ -0,0 +1,101 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from typing import Any
+
+from airflow.compat.functools import cached_property
+from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
class SageMakerTrigger(BaseTrigger):
    """
    SageMakerTrigger is fired as deferred class with params to run the task in triggerer.

    :param job_name: name of the job to check status
    :param job_type: type of the SageMaker job; one of "Training", "Transform"
        or "Processing" (case-insensitive)
    :param poke_interval: polling period in seconds to check for the status
    :param max_attempts: number of times to poll for the job state before giving up;
        ``None`` falls back to 60 attempts.
    :param aws_conn_id: AWS connection ID for SageMaker
    """

    def __init__(
        self,
        job_name: str,
        job_type: str,
        poke_interval: int = 30,
        max_attempts: int | None = None,
        aws_conn_id: str = "aws_default",
    ):
        super().__init__()
        self.job_name = job_name
        self.job_type = job_type
        self.poke_interval = poke_interval
        self.max_attempts = max_attempts
        self.aws_conn_id = aws_conn_id

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serializes SageMakerTrigger arguments and classpath."""
        return (
            "airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrigger",
            {
                "job_name": self.job_name,
                "job_type": self.job_type,
                "poke_interval": self.poke_interval,
                "max_attempts": self.max_attempts,
                "aws_conn_id": self.aws_conn_id,
            },
        )

    @cached_property
    def hook(self) -> SageMakerHook:
        """SageMaker hook, built once per trigger instance."""
        return SageMakerHook(aws_conn_id=self.aws_conn_id)

    @staticmethod
    def _get_job_type_waiter(job_type: str) -> str:
        """Map a job type to its custom waiter name (see waiters/sagemaker.json)."""
        return {
            "training": "TrainingJobComplete",
            "transform": "TransformJobComplete",
            "processing": "ProcessingJobComplete",
        }[job_type.lower()]

    @staticmethod
    def _get_job_type_waiter_job_name_arg(job_type: str) -> str:
        """Map a job type to the Describe* API keyword that carries the job name."""
        return {
            "training": "TrainingJobName",
            "transform": "TransformJobName",
            "processing": "ProcessingJobName",
        }[job_type.lower()]

    async def run(self):
        """Asynchronously wait for the job to reach a terminal state, then emit an event.

        A waiter failure (job Failed/Stopped, or attempts exhausted) raises from
        ``waiter.wait`` and propagates to the triggerer.
        """
        self.log.info("job name is %s and job type is %s", self.job_name, self.job_type)
        async with self.hook.async_conn as client:
            waiter = self.hook.get_waiter(
                self._get_job_type_waiter(self.job_type), deferrable=True, client=client
            )
            waiter_args = {
                self._get_job_type_waiter_job_name_arg(self.job_type): self.job_name,
                "WaiterConfig": {
                    "Delay": self.poke_interval,
                    # botocore requires an integer MaxAttempts; None would break the
                    # waiter config, so fall back to the same default (60) the
                    # operator uses when max_attempts is not provided.
                    "MaxAttempts": self.max_attempts or 60,
                },
            }
            await waiter.wait(**waiter_args)
        yield TriggerEvent({"status": "success", "message": "Job completed."})
diff --git a/airflow/providers/amazon/aws/waiters/sagemaker.json 
b/airflow/providers/amazon/aws/waiters/sagemaker.json
new file mode 100644
index 0000000000..73e3f09925
--- /dev/null
+++ b/airflow/providers/amazon/aws/waiters/sagemaker.json
@@ -0,0 +1,83 @@
+{
+    "version": 2,
+    "waiters": {
+        "TrainingJobComplete": {
+            "delay": 30,
+            "operation": "DescribeTrainingJob",
+            "maxAttempts": 60,
+            "description": "Wait until job is COMPLETED",
+            "acceptors": [
+                {
+                    "matcher": "path",
+                    "argument": "TrainingJobStatus",
+                    "expected": "Completed",
+                    "state": "success"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "TrainingJobStatus",
+                    "expected": "Failed",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "TrainingJobStatus",
+                    "expected": "Stopped",
+                    "state": "failure"
+                }
+            ]
+        },
+        "TransformJobComplete": {
+            "delay": 30,
+            "operation": "DescribeTransformJob",
+            "maxAttempts": 60,
+            "description": "Wait until job is COMPLETED",
+            "acceptors": [
+                {
+                    "matcher": "path",
+                    "argument": "TransformJobStatus",
+                    "expected": "Completed",
+                    "state": "success"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "TransformJobStatus",
+                    "expected": "Failed",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "TransformJobStatus",
+                    "expected": "Stopped",
+                    "state": "failure"
+                }
+            ]
+        },
+        "ProcessingJobComplete": {
+            "delay": 30,
+            "operation": "DescribeProcessingJob",
+            "maxAttempts": 60,
+            "description": "Wait until job is COMPLETED",
+            "acceptors": [
+                {
+                    "matcher": "path",
+                    "argument": "ProcessingJobStatus",
+                    "expected": "Completed",
+                    "state": "success"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "ProcessingJobStatus",
+                    "expected": "Failed",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "ProcessingJobStatus",
+                    "expected": "Stopped",
+                    "state": "failure"
+                }
+            ]
+        }
+    }
+}
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py 
b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
index 19be1eabb1..1d73d44bdf 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
@@ -21,10 +21,11 @@ from unittest import mock
 import pytest
 from botocore.exceptions import ClientError
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
 from airflow.providers.amazon.aws.operators import sagemaker
 from airflow.providers.amazon.aws.operators.sagemaker import 
SageMakerProcessingOperator
+from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
 
 CREATE_PROCESSING_PARAMS: dict = {
     "AppSpecification": {
@@ -236,3 +237,21 @@ class TestSageMakerProcessingOperator:
                 config=CREATE_PROCESSING_PARAMS,
                 action_if_job_exists="not_fail_or_increment",
             )
+
+    @mock.patch.object(SageMakerHook, "create_processing_job")
+    
@mock.patch("airflow.providers.amazon.aws.operators.sagemaker.SageMakerBaseOperator._check_if_job_exists")
+    def test_operator_defer(self, mock_job_exists, mock_processing):
+        mock_processing.return_value = {
+            "ProcessingJobArn": "test_arn",
+            "ResponseMetadata": {"HTTPStatusCode": 200},
+        }
+        mock_job_exists.return_value = False
+        sagemaker_operator = SageMakerProcessingOperator(
+            **self.processing_config_kwargs,
+            config=CREATE_PROCESSING_PARAMS,
+            deferrable=True,
+        )
+        sagemaker_operator.wait_for_completion = True
+        with pytest.raises(TaskDeferred) as exc:
+            sagemaker_operator.execute(context=None)
+        assert isinstance(exc.value.trigger, SageMakerTrigger), "Trigger is 
not a SagemakerTrigger"

Reply via email to