This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 4c9b5fe4c1 Add deferrable param in SageMakerTransformOperator (#31063)
4c9b5fe4c1 is described below
commit 4c9b5fe4c15ff9d813a34e5f31b5e2910f70cff8
Author: Pankaj Koti <[email protected]>
AuthorDate: Tue May 9 12:13:47 2023 +0530
Add deferrable param in SageMakerTransformOperator (#31063)
This allows SageMakerTransformOperator to run in deferrable (async)
mode: when deferrable=True and wait_for_completion=True, the worker only
submits the transform job and then defers to a trigger, which polls the
job status until it reaches a terminal state. This way, the worker slot
is not occupied for the whole duration of the task.
---
.../providers/amazon/aws/operators/sagemaker.py | 53 +++++++++++++++++++---
.../aws/operators/test_sagemaker_transform.py | 17 ++++++-
2 files changed, 62 insertions(+), 8 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py
b/airflow/providers/amazon/aws/operators/sagemaker.py
index 587bf5b0a1..f4041b465a 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -493,6 +493,8 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
:param wait_for_completion: Set to True to wait until the transform job
finishes.
:param check_interval: If wait is set to True, the time interval, in
seconds,
that this operation waits to check the status of the transform job.
+ :param max_attempts: Number of times to poll for query state before
returning the current state,
+ defaults to None.
:param max_ingestion_time: If wait is set to True, the operation fails
if the transform job doesn't finish within max_ingestion_time seconds.
If you
set this parameter to None, the operation does not timeout.
@@ -511,14 +513,17 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
aws_conn_id: str = DEFAULT_CONN_ID,
wait_for_completion: bool = True,
check_interval: int = CHECK_INTERVAL_SECOND,
+ max_attempts: int | None = None,
max_ingestion_time: int | None = None,
check_if_job_exists: bool = True,
action_if_job_exists: str = "timestamp",
+ deferrable: bool = False,
**kwargs,
):
super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
self.wait_for_completion = wait_for_completion
self.check_interval = check_interval
+ self.max_attempts = max_attempts or 60
self.max_ingestion_time = max_ingestion_time
self.check_if_job_exists = check_if_job_exists
if action_if_job_exists in ("increment", "fail", "timestamp"):
@@ -535,6 +540,7 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
f"Argument action_if_job_exists accepts only 'timestamp',
'increment' and 'fail'. \
Provided value: '{action_if_job_exists}'."
)
+ self.deferrable = deferrable
def _create_integer_fields(self) -> None:
"""Set fields which should be cast to integers."""
@@ -573,21 +579,54 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
self.hook.create_model(model_config)
self.log.info("Creating SageMaker transform Job %s.",
transform_config["TransformJobName"])
+
+ if self.deferrable and not self.wait_for_completion:
+ self.log.warning(
+ "Setting deferrable to True does not have effect when
wait_for_completion is set to False."
+ )
+
+ wait_for_completion = self.wait_for_completion
+ if self.deferrable and self.wait_for_completion:
+ # Set wait_for_completion to False so that it waits for the status
in the deferred task.
+ wait_for_completion = False
+
response = self.hook.create_transform_job(
transform_config,
- wait_for_completion=self.wait_for_completion,
+ wait_for_completion=wait_for_completion,
check_interval=self.check_interval,
max_ingestion_time=self.max_ingestion_time,
)
if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
raise AirflowException(f"Sagemaker transform Job creation failed:
{response}")
- else:
- return {
- "Model":
serialize(self.hook.describe_model(transform_config["ModelName"])),
- "Transform": serialize(
-
self.hook.describe_transform_job(transform_config["TransformJobName"])
+
+ if self.deferrable and self.wait_for_completion:
+ self.defer(
+ timeout=self.execution_timeout,
+ trigger=SageMakerTrigger(
+ job_name=transform_config["TransformJobName"],
+ job_type="Transform",
+ poke_interval=self.check_interval,
+ max_attempts=self.max_attempts,
+ aws_conn_id=self.aws_conn_id,
),
- }
+ method_name="execute_complete",
+ )
+
+ return {
+ "Model":
serialize(self.hook.describe_model(transform_config["ModelName"])),
+ "Transform":
serialize(self.hook.describe_transform_job(transform_config["TransformJobName"])),
+ }
+
+ def execute_complete(self, context, event=None):
+ if event["status"] != "success":
+ raise AirflowException(f"Error while running job: {event}")
+ else:
+ self.log.info(event["message"])
+ transform_config = self.config.get("Transform", self.config)
+ return {
+ "Model":
serialize(self.hook.describe_model(transform_config["ModelName"])),
+ "Transform":
serialize(self.hook.describe_transform_job(transform_config["TransformJobName"])),
+ }
class SageMakerTuningOperator(SageMakerBaseOperator):
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
index 482c7201cc..76a4d877b6 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
@@ -23,10 +23,11 @@ from unittest import mock
import pytest
from botocore.exceptions import ClientError
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators import sagemaker
from airflow.providers.amazon.aws.operators.sagemaker import
SageMakerTransformOperator
+from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
EXPECTED_INTEGER_FIELDS: list[list[str]] = [
["Transform", "TransformResources", "InstanceCount"],
@@ -163,3 +164,17 @@ class TestSageMakerTransformOperator:
check_interval=5,
max_ingestion_time=None,
)
+
+ @mock.patch.object(SageMakerHook, "create_transform_job")
+ @mock.patch.object(SageMakerHook, "create_model")
+ def test_operator_defer(self, _, mock_transform):
+ mock_transform.return_value = {
+ "TransformJobArn": "test_arn",
+ "ResponseMetadata": {"HTTPStatusCode": 200},
+ }
+ self.sagemaker.deferrable = True
+ self.sagemaker.wait_for_completion = True
+ self.sagemaker.check_if_job_exists = False
+ with pytest.raises(TaskDeferred) as exc:
+ self.sagemaker.execute(context=None)
+ assert isinstance(exc.value.trigger, SageMakerTrigger), "Trigger is
not a SagemakerTrigger"