This is an automated email from the ASF dual-hosted git repository.
husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b555ed6f35 Fix stripping tags when falling back to update in
`SageMakerEndpointOperator` (#33487)
b555ed6f35 is described below
commit b555ed6f358f738e2484db77b0782755440c8c8d
Author: Raphaël Vandon <[email protected]>
AuthorDate: Fri Aug 18 11:15:59 2023 -0700
Fix stripping tags when falling back to update in
`SageMakerEndpointOperator` (#33487)
Also fixed the fallback condition so that we only retry as an update when
the endpoint already exists (retrying is useless otherwise), and added a
warning on fallback to make the behavior more obvious to users.
---
.../providers/amazon/aws/operators/sagemaker.py | 34 ++++++++++++++------
.../aws/operators/test_sagemaker_endpoint.py | 37 ++++++++++++++++++++++
2 files changed, 62 insertions(+), 9 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py
b/airflow/providers/amazon/aws/operators/sagemaker.py
index ce0fa6f7c5..1547d2203c 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -494,16 +494,32 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
try:
response = sagemaker_operation(
endpoint_info,
- wait_for_completion=False,
- )
- # waiting for completion is handled here in the operator
- except ClientError:
- self.operation = "update"
- sagemaker_operation = self.hook.update_endpoint
- response = sagemaker_operation(
- endpoint_info,
- wait_for_completion=False,
+ wait_for_completion=False, # waiting for completion is
handled here in the operator
)
+ except ClientError as ce:
+ if self.operation == "create" and
ce.response["Error"]["Message"].startswith(
+ "Cannot create already existing endpoint"
+ ):
+ # if we get an error because the endpoint already exists, we
try to update it instead
+ self.operation = "update"
+ sagemaker_operation = self.hook.update_endpoint
+ self.log.warning(
+ "cannot create already existing endpoint %s, "
+ "updating it with the given config instead",
+ endpoint_info["EndpointName"],
+ )
+ if "Tags" in endpoint_info:
+ self.log.warning(
+ "Provided tags will be ignored in the update operation
"
+ "(tags on the existing endpoint will be unchanged)"
+ )
+ endpoint_info.pop("Tags")
+ response = sagemaker_operation(
+ endpoint_info,
+ wait_for_completion=False,
+ )
+ else:
+ raise
if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
raise AirflowException(f"Sagemaker endpoint creation failed:
{response}")
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py
b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py
index 8a566535b9..d31556f9bf 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py
@@ -122,6 +122,43 @@ class TestSageMakerEndpointOperator:
}
self.sagemaker.execute(None)
+ @mock.patch.object(SageMakerHook, "get_conn")
+ @mock.patch.object(SageMakerHook, "create_model")
+ @mock.patch.object(SageMakerHook, "create_endpoint_config")
+ @mock.patch.object(SageMakerHook, "create_endpoint")
+ @mock.patch.object(SageMakerHook, "update_endpoint")
+ @mock.patch.object(sagemaker, "serialize", return_value="")
+ def test_execute_with_duplicate_endpoint_removes_tags(
+ self,
+ serialize,
+ mock_endpoint_update,
+ mock_endpoint_create,
+ mock_endpoint_config,
+ mock_model,
+ mock_client,
+ ):
+ mock_endpoint_create.side_effect = ClientError(
+ error_response={
+ "Error": {
+ "Code": "ValidationException",
+ "Message": "Cannot create already existing endpoint.",
+ }
+ },
+ operation_name="CreateEndpoint",
+ )
+
+ def _check_no_tags(config, wait_for_completion):
+ assert "Tags" not in config
+ return {
+ "EndpointArn": "test_arn",
+ "ResponseMetadata": {"HTTPStatusCode": 200},
+ }
+
+ mock_endpoint_update.side_effect = _check_no_tags
+
+ self.sagemaker.config["Endpoint"]["Tags"] = {"Key": "k", "Value": "v"}
+ self.sagemaker.execute(None)
+
@mock.patch.object(SageMakerHook, "create_model")
@mock.patch.object(SageMakerHook, "create_endpoint_config")
@mock.patch.object(SageMakerHook, "create_endpoint")