This is an automated email from the ASF dual-hosted git repository.
phanikumv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 5a6f959bd5 check sagemaker processing job status before deferring (#36658)
5a6f959bd5 is described below
commit 5a6f959bd5826409a8d15a894edf36d0e76ef77a
Author: Wei Lee <[email protected]>
AuthorDate: Wed Jan 10 20:52:35 2024 +0800
check sagemaker processing job status before deferring (#36658)
---
.../providers/amazon/aws/operators/sagemaker.py | 15 +++++-
.../aws/operators/test_sagemaker_processing.py | 60 ++++++++++++++++++++--
2 files changed, 71 insertions(+), 4 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py
index e8f5f0880c..1b4ffc45af 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -283,8 +283,20 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
raise AirflowException(f"Sagemaker Processing Job creation failed: {response}")
if self.deferrable and self.wait_for_completion:
+ response = self.hook.describe_processing_job(self.config["ProcessingJobName"])
+ status = response["ProcessingJobStatus"]
+ if status in self.hook.failed_states:
+ raise AirflowException(f"SageMaker job failed because {response['FailureReason']}")
+ elif status == "Completed":
+ self.log.info("%s completed successfully.", self.task_id)
+ return {"Processing": serialize(response)}
+
+ timeout = self.execution_timeout
+ if self.max_ingestion_time:
+ timeout = datetime.timedelta(seconds=self.max_ingestion_time)
+
self.defer(
- timeout=self.execution_timeout,
+ timeout=timeout,
trigger=SageMakerTrigger(
job_name=self.config["ProcessingJobName"],
job_type="Processing",
@@ -304,6 +316,7 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
else:
self.log.info(event["message"])
self.serialized_job = serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))
+ self.log.info("%s completed successfully.", self.task_id)
return {"Processing": self.serialized_job}
def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage:
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
index 0135ba13fe..3a9c9c21f1 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py
@@ -101,6 +101,10 @@ class TestSageMakerProcessingOperator:
check_interval=5,
)
+ self.defer_processing_config_kwargs = dict(
+ task_id="test_sagemaker_operator", wait_for_completion=True, check_interval=5, deferrable=True
+ )
+
@mock.patch.object(SageMakerHook, "describe_processing_job")
@mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=0)
@mock.patch.object(
@@ -243,6 +247,9 @@ class TestSageMakerProcessingOperator:
action_if_job_exists="not_fail_or_increment",
)
+ @mock.patch.object(
+ SageMakerHook, "describe_processing_job", return_value={"ProcessingJobStatus": "InProgress"}
+ )
@mock.patch.object(
SageMakerHook,
"create_processing_job",
@@ -252,17 +259,64 @@ class TestSageMakerProcessingOperator:
},
)
@mock.patch.object(SageMakerBaseOperator, "_check_if_job_exists", return_value=False)
- def test_operator_defer(self, mock_job_exists, mock_processing):
+ def test_operator_defer(self, mock_job_exists, mock_processing, mock_describe):
sagemaker_operator = SageMakerProcessingOperator(
- **self.processing_config_kwargs,
+ **self.defer_processing_config_kwargs,
config=CREATE_PROCESSING_PARAMS,
- deferrable=True,
)
sagemaker_operator.wait_for_completion = True
with pytest.raises(TaskDeferred) as exc:
sagemaker_operator.execute(context=None)
assert isinstance(exc.value.trigger, SageMakerTrigger), "Trigger is not a SagemakerTrigger"
+ @mock.patch("airflow.providers.amazon.aws.operators.sagemaker.SageMakerProcessingOperator.defer")
+ @mock.patch.object(
+ SageMakerHook, "describe_processing_job", return_value={"ProcessingJobStatus": "Completed"}
+ )
+ @mock.patch.object(
+ SageMakerHook,
+ "create_processing_job",
+ return_value={"ProcessingJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}},
+ )
+ @mock.patch.object(SageMakerBaseOperator, "_check_if_job_exists", return_value=False)
+ def test_operator_complete_before_defer(
+ self, mock_job_exists, mock_processing, mock_describe, mock_defer
+ ):
+ sagemaker_operator = SageMakerProcessingOperator(
+ **self.defer_processing_config_kwargs,
+ config=CREATE_PROCESSING_PARAMS,
+ )
+ sagemaker_operator.execute(context=None)
+ assert not mock_defer.called
+
+ @mock.patch("airflow.providers.amazon.aws.operators.sagemaker.SageMakerProcessingOperator.defer")
+ @mock.patch.object(
+ SageMakerHook,
+ "describe_processing_job",
+ return_value={"ProcessingJobStatus": "Failed", "FailureReason": "It failed"},
+ )
+ @mock.patch.object(
+ SageMakerHook,
+ "create_processing_job",
+ return_value={"ProcessingJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}},
+ )
+ @mock.patch.object(SageMakerBaseOperator, "_check_if_job_exists", return_value=False)
+ def test_operator_failed_before_defer(
+ self,
+ mock_job_exists,
+ mock_processing,
+ mock_describe,
+ mock_defer,
+ ):
+ sagemaker_operator = SageMakerProcessingOperator(
+ **self.defer_processing_config_kwargs,
+ config=CREATE_PROCESSING_PARAMS,
+ )
+ with pytest.raises(AirflowException):
+ sagemaker_operator.execute(context=None)
+
+ assert not mock_defer.called
+
@mock.patch.object(
SageMakerHook,
"describe_processing_job",