This is an automated email from the ASF dual-hosted git repository.
phanikumv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b260367208 check transform job status before deferring SageMakerTransformOperator (#36680)
b260367208 is described below
commit b260367208f9c3c09bc1da2a32abf59867ddd789
Author: Wei Lee <[email protected]>
AuthorDate: Thu Jan 11 13:47:45 2024 +0800
check transform job status before deferring SageMakerTransformOperator (#36680)
---
.../providers/amazon/aws/operators/sagemaker.py | 39 ++++++++++++-----
.../aws/operators/test_sagemaker_transform.py | 50 +++++++++++++++++++++-
2 files changed, 77 insertions(+), 12 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py
index 1b4ffc45af..66a616811e 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -713,8 +713,24 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
raise AirflowException(f"Sagemaker transform Job creation failed: {response}")
if self.deferrable and self.wait_for_completion:
+ response = self.hook.describe_transform_job(transform_config["TransformJobName"])
+ status = response["TransformJobStatus"]
+ if status in self.hook.failed_states:
+ raise AirflowException(f"SageMaker job failed because {response['FailureReason']}")
+
+ if status == "Completed":
+ self.log.info("%s completed successfully.", self.task_id)
+ return {
+ "Model": serialize(self.hook.describe_model(transform_config["ModelName"])),
+ "Transform": serialize(response),
+ }
+
+ timeout = self.execution_timeout
+ if self.max_ingestion_time:
+ timeout = datetime.timedelta(seconds=self.max_ingestion_time)
+
self.defer(
- timeout=self.execution_timeout,
+ timeout=timeout,
trigger=SageMakerTrigger(
job_name=transform_config["TransformJobName"],
job_type="Transform",
@@ -725,17 +741,18 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
method_name="execute_complete",
)
- self.serialized_model = serialize(self.hook.describe_model(transform_config["ModelName"]))
- self.serialized_transform = serialize(
- self.hook.describe_transform_job(transform_config["TransformJobName"])
- )
- return {"Model": self.serialized_model, "Transform": self.serialized_transform}
+ return self.serialize_result()
- def execute_complete(self, context, event=None):
- if event["status"] != "success":
- raise AirflowException(f"Error while running job: {event}")
- else:
- self.log.info(event["message"])
+ def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]:
+ if event is None:
+ err_msg = "Trigger error: event is None"
+ self.log.error(err_msg)
+ raise AirflowException(err_msg)
+
+ self.log.info(event["message"])
+ return self.serialize_result()
+
+ def serialize_result(self) -> dict[str, dict]:
transform_config = self.config.get("Transform", self.config)
self.serialized_model = serialize(self.hook.describe_model(transform_config["ModelName"]))
self.serialized_transform = serialize(
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
index d3bf9f8837..9f702504e9 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
@@ -167,9 +167,15 @@ class TestSageMakerTransformOperator:
max_ingestion_time=None,
)
+ @mock.patch("airflow.providers.amazon.aws.operators.sagemaker.SageMakerTransformOperator.defer")
+ @mock.patch.object(
+ SageMakerHook,
+ "describe_transform_job",
+ return_value={"TransformJobStatus": "Failed", "FailureReason": "it failed"},
+ )
@mock.patch.object(SageMakerHook, "create_transform_job")
@mock.patch.object(SageMakerHook, "create_model")
- def test_operator_defer(self, _, mock_transform):
+ def test_operator_failed_before_defer(self, _, mock_transform, mock_describe_transform_job, mock_defer):
mock_transform.return_value = {
"TransformJobArn": "test_arn",
"ResponseMetadata": {"HTTPStatusCode": 200},
@@ -177,8 +183,50 @@ class TestSageMakerTransformOperator:
self.sagemaker.deferrable = True
self.sagemaker.wait_for_completion = True
self.sagemaker.check_if_job_exists = False
+
+ with pytest.raises(AirflowException):
+ self.sagemaker.execute(context=None)
+ assert not mock_defer.called
+
+ @mock.patch("airflow.providers.amazon.aws.operators.sagemaker.SageMakerTransformOperator.defer")
+ @mock.patch.object(SageMakerHook, "describe_model")
+ @mock.patch.object(
+ SageMakerHook, "describe_transform_job", return_value={"TransformJobStatus": "Completed"}
+ )
+ @mock.patch.object(SageMakerHook, "create_transform_job")
+ @mock.patch.object(SageMakerHook, "create_model")
+ def test_operator_complete_before_defer(
+ self, _, mock_transform, mock_describe_transform_job, mock_describe_model, mock_defer
+ ):
+ mock_transform.return_value = {
+ "TransformJobArn": "test_arn",
+ "ResponseMetadata": {"HTTPStatusCode": 200},
+ }
+ mock_describe_model.return_value = {"PrimaryContainer": {"ModelPackageName": "package-name"}}
+ self.sagemaker.deferrable = True
+ self.sagemaker.wait_for_completion = True
+ self.sagemaker.check_if_job_exists = False
+
+ self.sagemaker.execute(context=None)
+ assert not mock_defer.called
+
+ @mock.patch.object(
+ SageMakerHook, "describe_transform_job", return_value={"TransformJobStatus": "InProgress"}
+ )
+ @mock.patch.object(SageMakerHook, "create_transform_job")
+ @mock.patch.object(SageMakerHook, "create_model")
+ def test_operator_defer(self, _, mock_transform, mock_describe_transform_job):
+ mock_transform.return_value = {
+ "TransformJobArn": "test_arn",
+ "ResponseMetadata": {"HTTPStatusCode": 200},
+ }
+ self.sagemaker.deferrable = True
+ self.sagemaker.wait_for_completion = True
+ self.sagemaker.check_if_job_exists = False
+
with pytest.raises(TaskDeferred) as exc:
self.sagemaker.execute(context=None)
+
assert isinstance(exc.value.trigger, SageMakerTrigger), "Trigger is not a SagemakerTrigger"
@mock.patch.object(SageMakerHook, "describe_transform_job")