This is an automated email from the ASF dual-hosted git repository.

onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f3687b68a6 Sagemaker trigger: pass the job name as part of the event (#39671)
f3687b68a6 is described below

commit f3687b68a677f61a57d78a96b6b9323ab8f8258e
Author: Vincent <[email protected]>
AuthorDate: Thu May 16 15:58:31 2024 -0400

    Sagemaker trigger: pass the job name as part of the event (#39671)
---
 .../providers/amazon/aws/operators/sagemaker.py    | 24 +++++++++-------------
 airflow/providers/amazon/aws/triggers/sagemaker.py |  2 +-
 .../aws/operators/test_sagemaker_transform.py      | 16 ++++++++++++---
 .../amazon/aws/triggers/test_sagemaker.py          |  4 +++-
 4 files changed, 27 insertions(+), 19 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py
index 91b4200f18..fdad5dcb9f 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -750,20 +750,18 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
                 method_name="execute_complete",
             )
 
-        return self.serialize_result()
+        return self.serialize_result(transform_config["TransformJobName"])
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]:
         event = validate_execute_complete_event(event)
 
         self.log.info(event["message"])
-        return self.serialize_result()
+        return self.serialize_result(event["job_name"])
 
-    def serialize_result(self) -> dict[str, dict]:
-        transform_config = self.config.get("Transform", self.config)
-        self.serialized_model = serialize(self.hook.describe_model(transform_config["ModelName"]))
-        self.serialized_transform = serialize(
-            self.hook.describe_transform_job(transform_config["TransformJobName"])
-        )
+    def serialize_result(self, job_name: str) -> dict[str, dict]:
+        job_description = self.hook.describe_transform_job(job_name)
+        self.serialized_model = serialize(self.hook.describe_model(job_description["ModelName"]))
+        self.serialized_transform = serialize(job_description)
         return {"Model": self.serialized_model, "Transform": self.serialized_transform}
 
     def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage:
@@ -1154,7 +1152,7 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
                 method_name="execute_complete",
             )
 
-        return self.serialize_result()
+        return self.serialize_result(self.config["TrainingJobName"])
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]:
         event = validate_execute_complete_event(event)
@@ -1163,12 +1161,10 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
             raise AirflowException(f"Error while running job: {event}")
 
         self.log.info(event["message"])
-        return self.serialize_result()
+        return self.serialize_result(event["job_name"])
 
-    def serialize_result(self) -> dict[str, dict]:
-        self.serialized_training_data = serialize(
-            self.hook.describe_training_job(self.config["TrainingJobName"])
-        )
+    def serialize_result(self, job_name: str) -> dict[str, dict]:
+        self.serialized_training_data = serialize(self.hook.describe_training_job(job_name))
         return {"Training": self.serialized_training_data}
 
     def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage:
diff --git a/airflow/providers/amazon/aws/triggers/sagemaker.py b/airflow/providers/amazon/aws/triggers/sagemaker.py
index 8f10418763..343b90f00f 100644
--- a/airflow/providers/amazon/aws/triggers/sagemaker.py
+++ b/airflow/providers/amazon/aws/triggers/sagemaker.py
@@ -121,7 +121,7 @@ class SageMakerTrigger(BaseTrigger):
                 status_message=f"{self.job_type} job not done yet",
                 status_args=[self._get_response_status_key(self.job_type)],
             )
-            yield TriggerEvent({"status": "success", "message": "Job completed."})
+            yield TriggerEvent({"status": "success", "message": "Job completed.", "job_name": self.job_name})
 
 
 class SageMakerPipelineTrigger(BaseTrigger):
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
index 9f702504e9..9123037062 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
@@ -83,7 +83,10 @@ class TestSageMakerTransformOperator:
     @mock.patch.object(SageMakerHook, "create_transform_job")
     @mock.patch.object(sagemaker, "serialize", return_value="")
     def test_integer_fields(self, _, mock_create_transform, __, ___, mock_desc):
-        mock_desc.side_effect = [ClientError({"Error": {"Code": "ValidationException"}}, "op"), None]
+        mock_desc.side_effect = [
+            ClientError({"Error": {"Code": "ValidationException"}}, "op"),
+            {"ModelName": "model_name"},
+        ]
         mock_create_transform.return_value = {
             "TransformJobArn": "test_arn",
             "ResponseMetadata": {"HTTPStatusCode": 200},
@@ -103,7 +106,10 @@ class TestSageMakerTransformOperator:
     @mock.patch.object(SageMakerHook, "create_transform_job")
     @mock.patch.object(sagemaker, "serialize", return_value="")
     def test_execute(self, _, mock_transform, __, mock_model, mock_desc):
-        mock_desc.side_effect = [ClientError({"Error": {"Code": "ValidationException"}}, "op"), None]
+        mock_desc.side_effect = [
+            ClientError({"Error": {"Code": "ValidationException"}}, "op"),
+            {"ModelName": "model_name"},
+        ]
         mock_transform.return_value = {
             "TransformJobArn": "test_arn",
             "ResponseMetadata": {"HTTPStatusCode": 200},
@@ -135,7 +141,10 @@ class TestSageMakerTransformOperator:
     @mock.patch.object(SageMakerHook, "describe_model")
     @mock.patch.object(sagemaker, "serialize", return_value="")
     def test_execute_with_check_if_job_exists(self, _, __, ___, mock_transform, mock_desc):
-        mock_desc.side_effect = [ClientError({"Error": {"Code": "ValidationException"}}, "op"), None]
+        mock_desc.side_effect = [
+            ClientError({"Error": {"Code": "ValidationException"}}, "op"),
+            {"ModelName": "model_name"},
+        ]
         mock_transform.return_value = {
             "TransformJobArn": "test_arn",
             "ResponseMetadata": {"HTTPStatusCode": 200},
@@ -243,6 +252,7 @@ class TestSageMakerTransformOperator:
         mock_desc.return_value = {
             "TransformInput": {"DataSource": {"S3DataSource": {"S3Uri": "s3://input-bucket/input-path"}}},
             "TransformOutput": {"S3OutputPath": "s3://output-bucket/output-path"},
+            "ModelName": "model_name",
         }
         mock_transform.return_value = {
             "TransformJobArn": "test_arn",
diff --git a/tests/providers/amazon/aws/triggers/test_sagemaker.py b/tests/providers/amazon/aws/triggers/test_sagemaker.py
index f2d05f85a6..a69a7837de 100644
--- a/tests/providers/amazon/aws/triggers/test_sagemaker.py
+++ b/tests/providers/amazon/aws/triggers/test_sagemaker.py
@@ -77,4 +77,6 @@ class TestSagemakerTrigger:
         generator = sagemaker_trigger.run()
         response = await generator.asend(None)
 
-        assert response == TriggerEvent({"status": "success", "message": "Job completed."})
+        assert response == TriggerEvent(
+            {"status": "success", "message": "Job completed.", "job_name": JOB_NAME}
+        )

Reply via email to