This is an automated email from the ASF dual-hosted git repository.

phanikumv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b260367208 check transform job status before deferring SageMakerTransformOperator (#36680)
b260367208 is described below

commit b260367208f9c3c09bc1da2a32abf59867ddd789
Author: Wei Lee <[email protected]>
AuthorDate: Thu Jan 11 13:47:45 2024 +0800

    check transform job status before deferring SageMakerTransformOperator (#36680)
---
 .../providers/amazon/aws/operators/sagemaker.py    | 39 ++++++++++++-----
 .../aws/operators/test_sagemaker_transform.py      | 50 +++++++++++++++++++++-
 2 files changed, 77 insertions(+), 12 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py
index 1b4ffc45af..66a616811e 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -713,8 +713,24 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
             raise AirflowException(f"Sagemaker transform Job creation failed: 
{response}")
 
         if self.deferrable and self.wait_for_completion:
+            response = 
self.hook.describe_transform_job(transform_config["TransformJobName"])
+            status = response["TransformJobStatus"]
+            if status in self.hook.failed_states:
+                raise AirflowException(f"SageMaker job failed because 
{response['FailureReason']}")
+
+            if status == "Completed":
+                self.log.info("%s completed successfully.", self.task_id)
+                return {
+                    "Model": 
serialize(self.hook.describe_model(transform_config["ModelName"])),
+                    "Transform": serialize(response),
+                }
+
+            timeout = self.execution_timeout
+            if self.max_ingestion_time:
+                timeout = datetime.timedelta(seconds=self.max_ingestion_time)
+
             self.defer(
-                timeout=self.execution_timeout,
+                timeout=timeout,
                 trigger=SageMakerTrigger(
                     job_name=transform_config["TransformJobName"],
                     job_type="Transform",
@@ -725,17 +741,18 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
                 method_name="execute_complete",
             )
 
-        self.serialized_model = 
serialize(self.hook.describe_model(transform_config["ModelName"]))
-        self.serialized_transform = serialize(
-            
self.hook.describe_transform_job(transform_config["TransformJobName"])
-        )
-        return {"Model": self.serialized_model, "Transform": 
self.serialized_transform}
+        return self.serialize_result()
 
-    def execute_complete(self, context, event=None):
-        if event["status"] != "success":
-            raise AirflowException(f"Error while running job: {event}")
-        else:
-            self.log.info(event["message"])
+    def execute_complete(self, context: Context, event: dict[str, Any] | None 
= None) -> dict[str, dict]:
+        if event is None:
+            err_msg = "Trigger error: event is None"
+            self.log.error(err_msg)
+            raise AirflowException(err_msg)
+
+        self.log.info(event["message"])
+        return self.serialize_result()
+
+    def serialize_result(self) -> dict[str, dict]:
         transform_config = self.config.get("Transform", self.config)
         self.serialized_model = 
serialize(self.hook.describe_model(transform_config["ModelName"]))
         self.serialized_transform = serialize(
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
index d3bf9f8837..9f702504e9 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
@@ -167,9 +167,15 @@ class TestSageMakerTransformOperator:
             max_ingestion_time=None,
         )
 
+    
@mock.patch("airflow.providers.amazon.aws.operators.sagemaker.SageMakerTransformOperator.defer")
+    @mock.patch.object(
+        SageMakerHook,
+        "describe_transform_job",
+        return_value={"TransformJobStatus": "Failed", "FailureReason": "it 
failed"},
+    )
     @mock.patch.object(SageMakerHook, "create_transform_job")
     @mock.patch.object(SageMakerHook, "create_model")
-    def test_operator_defer(self, _, mock_transform):
+    def test_operator_failed_before_defer(self, _, mock_transform, 
mock_describe_transform_job, mock_defer):
         mock_transform.return_value = {
             "TransformJobArn": "test_arn",
             "ResponseMetadata": {"HTTPStatusCode": 200},
@@ -177,8 +183,50 @@ class TestSageMakerTransformOperator:
         self.sagemaker.deferrable = True
         self.sagemaker.wait_for_completion = True
         self.sagemaker.check_if_job_exists = False
+
+        with pytest.raises(AirflowException):
+            self.sagemaker.execute(context=None)
+        assert not mock_defer.called
+
+    
@mock.patch("airflow.providers.amazon.aws.operators.sagemaker.SageMakerTransformOperator.defer")
+    @mock.patch.object(SageMakerHook, "describe_model")
+    @mock.patch.object(
+        SageMakerHook, "describe_transform_job", 
return_value={"TransformJobStatus": "Completed"}
+    )
+    @mock.patch.object(SageMakerHook, "create_transform_job")
+    @mock.patch.object(SageMakerHook, "create_model")
+    def test_operator_complete_before_defer(
+        self, _, mock_transform, mock_describe_transform_job, 
mock_describe_model, mock_defer
+    ):
+        mock_transform.return_value = {
+            "TransformJobArn": "test_arn",
+            "ResponseMetadata": {"HTTPStatusCode": 200},
+        }
+        mock_describe_model.return_value = {"PrimaryContainer": 
{"ModelPackageName": "package-name"}}
+        self.sagemaker.deferrable = True
+        self.sagemaker.wait_for_completion = True
+        self.sagemaker.check_if_job_exists = False
+
+        self.sagemaker.execute(context=None)
+        assert not mock_defer.called
+
+    @mock.patch.object(
+        SageMakerHook, "describe_transform_job", 
return_value={"TransformJobStatus": "InProgress"}
+    )
+    @mock.patch.object(SageMakerHook, "create_transform_job")
+    @mock.patch.object(SageMakerHook, "create_model")
+    def test_operator_defer(self, _, mock_transform, 
mock_describe_transform_job):
+        mock_transform.return_value = {
+            "TransformJobArn": "test_arn",
+            "ResponseMetadata": {"HTTPStatusCode": 200},
+        }
+        self.sagemaker.deferrable = True
+        self.sagemaker.wait_for_completion = True
+        self.sagemaker.check_if_job_exists = False
+
         with pytest.raises(TaskDeferred) as exc:
             self.sagemaker.execute(context=None)
+
         assert isinstance(exc.value.trigger, SageMakerTrigger), "Trigger is 
not a SagemakerTrigger"
 
     @mock.patch.object(SageMakerHook, "describe_transform_job")

Reply via email to