This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 1c144ee141 Add deferrable param in SageMakerProcessingOperator (#31062)
1c144ee141 is described below
commit 1c144ee141059a4c7e0450fd086eced2197568cf
Author: Pankaj Koti <[email protected]>
AuthorDate: Mon May 8 23:49:07 2023 +0530
Add deferrable param in SageMakerProcessingOperator (#31062)
* Add deferrable param in SageMakerProcessingOperator
This will allow running SageMakerProcessingOperator in an async
fashion, meaning that the worker only submits the job and then
defers to the trigger, which polls until the job status reaches
a terminal state. This way, the worker slot won't be occupied
for the whole period of task execution.
* Remove unneeded mocks
---
.../providers/amazon/aws/operators/sagemaker.py | 42 ++++++++-
airflow/providers/amazon/aws/triggers/sagemaker.py | 101 +++++++++++++++++++++
.../providers/amazon/aws/waiters/sagemaker.json | 83 +++++++++++++++++
.../aws/operators/test_sagemaker_processing.py | 21 ++++-
4 files changed, 245 insertions(+), 2 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py
b/airflow/providers/amazon/aws/operators/sagemaker.py
index 502aa0ac9c..4efb1a863b 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -28,6 +28,7 @@ from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarni
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
+from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
from airflow.providers.amazon.aws.utils import trim_none_values
from airflow.providers.amazon.aws.utils.sagemaker import ApprovalStatus
from airflow.providers.amazon.aws.utils.tags import format_tags
@@ -171,11 +172,15 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
:param print_log: if the operator should print the cloudwatch log during
processing
:param check_interval: if wait is set to be true, this is the time interval
in seconds which the operator will check the status of the processing
job
+ :param max_attempts: Number of times to poll for query state before
returning the current state,
+ defaults to None.
:param max_ingestion_time: If wait is set to True, the operation fails if
the processing job
doesn't finish within max_ingestion_time seconds. If you set this
parameter to None,
the operation does not timeout.
:param action_if_job_exists: Behaviour if the job name already exists.
Possible options are "timestamp"
(default), "increment" (deprecated) and "fail".
+ :param deferrable: Run operator in the deferrable mode. This is only
effective if wait_for_completion is
+ set to True.
:return Dict: Returns The ARN of the processing job created in Amazon
SageMaker.
"""
@@ -187,8 +192,10 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
wait_for_completion: bool = True,
print_log: bool = True,
check_interval: int = CHECK_INTERVAL_SECOND,
+ max_attempts: int | None = None,
max_ingestion_time: int | None = None,
action_if_job_exists: str = "timestamp",
+ deferrable: bool = False,
**kwargs,
):
super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
@@ -208,7 +215,9 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
self.wait_for_completion = wait_for_completion
self.print_log = print_log
self.check_interval = check_interval
+ self.max_attempts = max_attempts or 60
self.max_ingestion_time = max_ingestion_time
+ self.deferrable = deferrable
def _create_integer_fields(self) -> None:
"""Set fields which should be cast to integers."""
@@ -234,14 +243,45 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
self.hook.describe_processing_job,
)
+ if self.deferrable and not self.wait_for_completion:
+ self.log.warning(
+ "Setting deferrable to True does not have effect when
wait_for_completion is set to False."
+ )
+
+ wait_for_completion = self.wait_for_completion
+ if self.deferrable and self.wait_for_completion:
+ # Set wait_for_completion to False so that it waits for the status
in the deferred task.
+ wait_for_completion = False
+
response = self.hook.create_processing_job(
self.config,
- wait_for_completion=self.wait_for_completion,
+ wait_for_completion=wait_for_completion,
check_interval=self.check_interval,
max_ingestion_time=self.max_ingestion_time,
)
if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
raise AirflowException(f"Sagemaker Processing Job creation failed:
{response}")
+
+ if self.deferrable and self.wait_for_completion:
+ self.defer(
+ timeout=self.execution_timeout,
+ trigger=SageMakerTrigger(
+ job_name=self.config["ProcessingJobName"],
+ job_type="Processing",
+ poke_interval=self.check_interval,
+ max_attempts=self.max_attempts,
+ aws_conn_id=self.aws_conn_id,
+ ),
+ method_name="execute_complete",
+ )
+
+ return {"Processing":
serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))}
+
+ def execute_complete(self, context, event=None):
+ if event["status"] != "success":
+ raise AirflowException(f"Error while running job: {event}")
+ else:
+ self.log.info(event["message"])
return {"Processing":
serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))}
diff --git a/airflow/providers/amazon/aws/triggers/sagemaker.py
b/airflow/providers/amazon/aws/triggers/sagemaker.py
new file mode 100644
index 0000000000..773a3243e9
--- /dev/null
+++ b/airflow/providers/amazon/aws/triggers/sagemaker.py
@@ -0,0 +1,101 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from typing import Any
+
+from airflow.compat.functools import cached_property
+from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class SageMakerTrigger(BaseTrigger):
+ """
+ SageMakerTrigger is fired as deferred class with params to run the task in
triggerer.
+
+ :param job_name: name of the job to check status
+ :param job_type: Type of the sagemaker job whether it is Transform or
Training
+ :param poke_interval: polling period in seconds to check for the status
+ :param max_attempts: Number of times to poll for query state before
returning the current state,
+ defaults to None.
+ :param aws_conn_id: AWS connection ID for sagemaker
+ """
+
+ def __init__(
+ self,
+ job_name: str,
+ job_type: str,
+ poke_interval: int = 30,
+ max_attempts: int | None = None,
+ aws_conn_id: str = "aws_default",
+ ):
+ super().__init__()
+ self.job_name = job_name
+ self.job_type = job_type
+ self.poke_interval = poke_interval
+ self.max_attempts = max_attempts
+ self.aws_conn_id = aws_conn_id
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ """Serializes SagemakerTrigger arguments and classpath."""
+ return (
+ "airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrigger",
+ {
+ "job_name": self.job_name,
+ "job_type": self.job_type,
+ "poke_interval": self.poke_interval,
+ "max_attempts": self.max_attempts,
+ "aws_conn_id": self.aws_conn_id,
+ },
+ )
+
+ @cached_property
+ def hook(self) -> SageMakerHook:
+ return SageMakerHook(aws_conn_id=self.aws_conn_id)
+
+ @staticmethod
+ def _get_job_type_waiter(job_type: str) -> str:
+ return {
+ "training": "TrainingJobComplete",
+ "transform": "TransformJobComplete",
+ "processing": "ProcessingJobComplete",
+ }[job_type.lower()]
+
+ @staticmethod
+ def _get_job_type_waiter_job_name_arg(job_type: str) -> str:
+ return {
+ "training": "TrainingJobName",
+ "transform": "TransformJobName",
+ "processing": "ProcessingJobName",
+ }[job_type.lower()]
+
+ async def run(self):
+ self.log.info("job name is %s and job type is %s", self.job_name,
self.job_type)
+ async with self.hook.async_conn as client:
+ waiter = self.hook.get_waiter(
+ self._get_job_type_waiter(self.job_type), deferrable=True,
client=client
+ )
+ waiter_args = {
+ self._get_job_type_waiter_job_name_arg(self.job_type):
self.job_name,
+ "WaiterConfig": {
+ "Delay": self.poke_interval,
+ "MaxAttempts": self.max_attempts,
+ },
+ }
+ await waiter.wait(**waiter_args)
+ yield TriggerEvent({"status": "success", "message": "Job completed."})
diff --git a/airflow/providers/amazon/aws/waiters/sagemaker.json
b/airflow/providers/amazon/aws/waiters/sagemaker.json
new file mode 100644
index 0000000000..73e3f09925
--- /dev/null
+++ b/airflow/providers/amazon/aws/waiters/sagemaker.json
@@ -0,0 +1,83 @@
+{
+ "version": 2,
+ "waiters": {
+ "TrainingJobComplete": {
+ "delay": 30,
+ "operation": "DescribeTrainingJob",
+ "maxAttempts": 60,
+ "description": "Wait until job is COMPLETED",
+ "acceptors": [
+ {
+ "matcher": "path",
+ "argument": "TrainingJobStatus",
+ "expected": "Completed",
+ "state": "success"
+ },
+ {
+ "matcher": "path",
+ "argument": "TrainingJobStatus",
+ "expected": "Failed",
+ "state": "failure"
+ },
+ {
+ "matcher": "path",
+ "argument": "TrainingJobStatus",
+ "expected": "Stopped",
+ "state": "failure"
+ }
+ ]
+ },
+ "TransformJobComplete": {
+ "delay": 30,
+ "operation": "DescribeTransformJob",
+ "maxAttempts": 60,
+ "description": "Wait until job is COMPLETED",
+ "acceptors": [
+ {
+ "matcher": "path",
+ "argument": "TransformJobStatus",
+ "expected": "Completed",
+ "state": "success"
+ },
+ {
+ "matcher": "path",
+ "argument": "TransformJobStatus",
+ "expected": "Failed",
+ "state": "failure"
+ },
+ {
+ "matcher": "path",
+ "argument": "TransformJobStatus",
+ "expected": "Stopped",
+ "state": "failure"
+ }
+ ]
+ },
+ "ProcessingJobComplete": {
+ "delay": 30,
+ "operation": "DescribeProcessingJob",
+ "maxAttempts": 60,
+ "description": "Wait until job is COMPLETED",
+ "acceptors": [
+ {
+ "matcher": "path",
+ "argument": "ProcessingJobStatus",
+ "expected": "Completed",
+ "state": "success"
+ },
+ {
+ "matcher": "path",
+ "argument": "ProcessingJobStatus",
+ "expected": "Failed",
+ "state": "failure"
+ },
+ {
+ "matcher": "path",
+ "argument": "ProcessingJobStatus",
+ "expected": "Stopped",
+ "state": "failure"
+ }
+ ]
+ }
+ }
+}
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
index 19be1eabb1..1d73d44bdf 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
@@ -21,10 +21,11 @@ from unittest import mock
import pytest
from botocore.exceptions import ClientError
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators import sagemaker
from airflow.providers.amazon.aws.operators.sagemaker import
SageMakerProcessingOperator
+from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
CREATE_PROCESSING_PARAMS: dict = {
"AppSpecification": {
@@ -236,3 +237,21 @@ class TestSageMakerProcessingOperator:
config=CREATE_PROCESSING_PARAMS,
action_if_job_exists="not_fail_or_increment",
)
+
+ @mock.patch.object(SageMakerHook, "create_processing_job")
+
@mock.patch("airflow.providers.amazon.aws.operators.sagemaker.SageMakerBaseOperator._check_if_job_exists")
+ def test_operator_defer(self, mock_job_exists, mock_processing):
+ mock_processing.return_value = {
+ "ProcessingJobArn": "test_arn",
+ "ResponseMetadata": {"HTTPStatusCode": 200},
+ }
+ mock_job_exists.return_value = False
+ sagemaker_operator = SageMakerProcessingOperator(
+ **self.processing_config_kwargs,
+ config=CREATE_PROCESSING_PARAMS,
+ deferrable=True,
+ )
+ sagemaker_operator.wait_for_completion = True
+ with pytest.raises(TaskDeferred) as exc:
+ sagemaker_operator.execute(context=None)
+ assert isinstance(exc.value.trigger, SageMakerTrigger), "Trigger is
not a SagemakerTrigger"