This is an automated email from the ASF dual-hosted git repository.
ferruzzi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new ab3429c318 Add STOPPED to the failure cases for Sagemaker Training Jobs (#42423)
ab3429c318 is described below
commit ab3429c3189ceb244eb3d78062159859dbe611ce
Author: D. Ferruzzi <[email protected]>
AuthorDate: Tue Sep 24 15:07:40 2024 -0700
Add STOPPED to the failure cases for Sagemaker Training Jobs (#42423)
---
airflow/providers/amazon/aws/hooks/sagemaker.py | 3 ++-
airflow/providers/amazon/aws/sensors/sagemaker.py | 2 +-
2 files changed, 3 insertions(+), 2 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/sagemaker.py b/airflow/providers/amazon/aws/hooks/sagemaker.py
index af131697a5..2c0f4fb25e 100644
--- a/airflow/providers/amazon/aws/hooks/sagemaker.py
+++ b/airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -155,6 +155,7 @@ class SageMakerHook(AwsBaseHook):
endpoint_non_terminal_states = {"Creating", "Updating", "SystemUpdating", "RollingBack", "Deleting"}
pipeline_non_terminal_states = {"Executing", "Stopping"}
failed_states = {"Failed"}
+ training_failed_states = {*failed_states, "Stopped"}
def __init__(self, *args, **kwargs):
super().__init__(client_type="sagemaker", *args, **kwargs)
@@ -309,7 +310,7 @@ class SageMakerHook(AwsBaseHook):
self.check_training_status_with_log(
config["TrainingJobName"],
self.non_terminal_states,
- self.failed_states,
+ self.training_failed_states,
wait_for_completion,
check_interval,
max_ingestion_time,
diff --git a/airflow/providers/amazon/aws/sensors/sagemaker.py b/airflow/providers/amazon/aws/sensors/sagemaker.py
index b01e24cd5b..af07c504aa 100644
--- a/airflow/providers/amazon/aws/sensors/sagemaker.py
+++ b/airflow/providers/amazon/aws/sensors/sagemaker.py
@@ -238,7 +238,7 @@ class SageMakerTrainingSensor(SageMakerBaseSensor):
return SageMakerHook.non_terminal_states
def failed_states(self):
- return SageMakerHook.failed_states
+ return SageMakerHook.training_failed_states
def get_sagemaker_response(self):
if self.print_log: