This is an automated email from the ASF dual-hosted git repository.
onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f3687b68a6 Sagemaker trigger: pass the job name as part of the event (#39671)
f3687b68a6 is described below
commit f3687b68a677f61a57d78a96b6b9323ab8f8258e
Author: Vincent <[email protected]>
AuthorDate: Thu May 16 15:58:31 2024 -0400
Sagemaker trigger: pass the job name as part of the event (#39671)
---
.../providers/amazon/aws/operators/sagemaker.py | 24 +++++++++-------------
airflow/providers/amazon/aws/triggers/sagemaker.py | 2 +-
.../aws/operators/test_sagemaker_transform.py | 16 ++++++++++++---
.../amazon/aws/triggers/test_sagemaker.py | 4 +++-
4 files changed, 27 insertions(+), 19 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py
b/airflow/providers/amazon/aws/operators/sagemaker.py
index 91b4200f18..fdad5dcb9f 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -750,20 +750,18 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
method_name="execute_complete",
)
- return self.serialize_result()
+ return self.serialize_result(transform_config["TransformJobName"])
def execute_complete(self, context: Context, event: dict[str, Any] | None
= None) -> dict[str, dict]:
event = validate_execute_complete_event(event)
self.log.info(event["message"])
- return self.serialize_result()
+ return self.serialize_result(event["job_name"])
- def serialize_result(self) -> dict[str, dict]:
- transform_config = self.config.get("Transform", self.config)
- self.serialized_model =
serialize(self.hook.describe_model(transform_config["ModelName"]))
- self.serialized_transform = serialize(
-
self.hook.describe_transform_job(transform_config["TransformJobName"])
- )
+ def serialize_result(self, job_name: str) -> dict[str, dict]:
+ job_description = self.hook.describe_transform_job(job_name)
+ self.serialized_model =
serialize(self.hook.describe_model(job_description["ModelName"]))
+ self.serialized_transform = serialize(job_description)
return {"Model": self.serialized_model, "Transform":
self.serialized_transform}
def get_openlineage_facets_on_complete(self, task_instance) ->
OperatorLineage:
@@ -1154,7 +1152,7 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
method_name="execute_complete",
)
- return self.serialize_result()
+ return self.serialize_result(self.config["TrainingJobName"])
def execute_complete(self, context: Context, event: dict[str, Any] | None
= None) -> dict[str, dict]:
event = validate_execute_complete_event(event)
@@ -1163,12 +1161,10 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
raise AirflowException(f"Error while running job: {event}")
self.log.info(event["message"])
- return self.serialize_result()
+ return self.serialize_result(event["job_name"])
- def serialize_result(self) -> dict[str, dict]:
- self.serialized_training_data = serialize(
- self.hook.describe_training_job(self.config["TrainingJobName"])
- )
+ def serialize_result(self, job_name: str) -> dict[str, dict]:
+ self.serialized_training_data =
serialize(self.hook.describe_training_job(job_name))
return {"Training": self.serialized_training_data}
def get_openlineage_facets_on_complete(self, task_instance) ->
OperatorLineage:
diff --git a/airflow/providers/amazon/aws/triggers/sagemaker.py
b/airflow/providers/amazon/aws/triggers/sagemaker.py
index 8f10418763..343b90f00f 100644
--- a/airflow/providers/amazon/aws/triggers/sagemaker.py
+++ b/airflow/providers/amazon/aws/triggers/sagemaker.py
@@ -121,7 +121,7 @@ class SageMakerTrigger(BaseTrigger):
status_message=f"{self.job_type} job not done yet",
status_args=[self._get_response_status_key(self.job_type)],
)
- yield TriggerEvent({"status": "success", "message": "Job
completed."})
+ yield TriggerEvent({"status": "success", "message": "Job
completed.", "job_name": self.job_name})
class SageMakerPipelineTrigger(BaseTrigger):
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
index 9f702504e9..9123037062 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
@@ -83,7 +83,10 @@ class TestSageMakerTransformOperator:
@mock.patch.object(SageMakerHook, "create_transform_job")
@mock.patch.object(sagemaker, "serialize", return_value="")
def test_integer_fields(self, _, mock_create_transform, __, ___,
mock_desc):
- mock_desc.side_effect = [ClientError({"Error": {"Code":
"ValidationException"}}, "op"), None]
+ mock_desc.side_effect = [
+ ClientError({"Error": {"Code": "ValidationException"}}, "op"),
+ {"ModelName": "model_name"},
+ ]
mock_create_transform.return_value = {
"TransformJobArn": "test_arn",
"ResponseMetadata": {"HTTPStatusCode": 200},
@@ -103,7 +106,10 @@ class TestSageMakerTransformOperator:
@mock.patch.object(SageMakerHook, "create_transform_job")
@mock.patch.object(sagemaker, "serialize", return_value="")
def test_execute(self, _, mock_transform, __, mock_model, mock_desc):
- mock_desc.side_effect = [ClientError({"Error": {"Code":
"ValidationException"}}, "op"), None]
+ mock_desc.side_effect = [
+ ClientError({"Error": {"Code": "ValidationException"}}, "op"),
+ {"ModelName": "model_name"},
+ ]
mock_transform.return_value = {
"TransformJobArn": "test_arn",
"ResponseMetadata": {"HTTPStatusCode": 200},
@@ -135,7 +141,10 @@ class TestSageMakerTransformOperator:
@mock.patch.object(SageMakerHook, "describe_model")
@mock.patch.object(sagemaker, "serialize", return_value="")
def test_execute_with_check_if_job_exists(self, _, __, ___,
mock_transform, mock_desc):
- mock_desc.side_effect = [ClientError({"Error": {"Code":
"ValidationException"}}, "op"), None]
+ mock_desc.side_effect = [
+ ClientError({"Error": {"Code": "ValidationException"}}, "op"),
+ {"ModelName": "model_name"},
+ ]
mock_transform.return_value = {
"TransformJobArn": "test_arn",
"ResponseMetadata": {"HTTPStatusCode": 200},
@@ -243,6 +252,7 @@ class TestSageMakerTransformOperator:
mock_desc.return_value = {
"TransformInput": {"DataSource": {"S3DataSource": {"S3Uri":
"s3://input-bucket/input-path"}}},
"TransformOutput": {"S3OutputPath":
"s3://output-bucket/output-path"},
+ "ModelName": "model_name",
}
mock_transform.return_value = {
"TransformJobArn": "test_arn",
diff --git a/tests/providers/amazon/aws/triggers/test_sagemaker.py
b/tests/providers/amazon/aws/triggers/test_sagemaker.py
index f2d05f85a6..a69a7837de 100644
--- a/tests/providers/amazon/aws/triggers/test_sagemaker.py
+++ b/tests/providers/amazon/aws/triggers/test_sagemaker.py
@@ -77,4 +77,6 @@ class TestSagemakerTrigger:
generator = sagemaker_trigger.run()
response = await generator.asend(None)
- assert response == TriggerEvent({"status": "success", "message": "Job
completed."})
+ assert response == TriggerEvent(
+ {"status": "success", "message": "Job completed.", "job_name":
JOB_NAME}
+ )