This is an automated email from the ASF dual-hosted git repository.

husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b555ed6f35 Fix stripping tags when falling back to update in `SageMakerEndpointOperator` (#33487)
b555ed6f35 is described below

commit b555ed6f358f738e2484db77b0782755440c8c8d
Author: Raphaël Vandon <[email protected]>
AuthorDate: Fri Aug 18 11:15:59 2023 -0700

    Fix stripping tags when falling back to update in `SageMakerEndpointOperator` (#33487)
    
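    Also fixed the fallback condition so that the operator only retries as an
    update when a create fails because the endpoint already exists (retrying
    on any other error is useless), and added a warning on fallback to make
    the behavior more obvious to users.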
---
 .../providers/amazon/aws/operators/sagemaker.py    | 34 ++++++++++++++------
 .../aws/operators/test_sagemaker_endpoint.py       | 37 ++++++++++++++++++++++
 2 files changed, 62 insertions(+), 9 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py 
b/airflow/providers/amazon/aws/operators/sagemaker.py
index ce0fa6f7c5..1547d2203c 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -494,16 +494,32 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
         try:
             response = sagemaker_operation(
                 endpoint_info,
-                wait_for_completion=False,
-            )
-            # waiting for completion is handled here in the operator
-        except ClientError:
-            self.operation = "update"
-            sagemaker_operation = self.hook.update_endpoint
-            response = sagemaker_operation(
-                endpoint_info,
-                wait_for_completion=False,
+                wait_for_completion=False,  # waiting for completion is 
handled here in the operator
             )
+        except ClientError as ce:
+            if self.operation == "create" and 
ce.response["Error"]["Message"].startswith(
+                "Cannot create already existing endpoint"
+            ):
+                # if we get an error because the endpoint already exists, we 
try to update it instead
+                self.operation = "update"
+                sagemaker_operation = self.hook.update_endpoint
+                self.log.warning(
+                    "cannot create already existing endpoint %s, "
+                    "updating it with the given config instead",
+                    endpoint_info["EndpointName"],
+                )
+                if "Tags" in endpoint_info:
+                    self.log.warning(
+                        "Provided tags will be ignored in the update operation 
"
+                        "(tags on the existing endpoint will be unchanged)"
+                    )
+                    endpoint_info.pop("Tags")
+                response = sagemaker_operation(
+                    endpoint_info,
+                    wait_for_completion=False,
+                )
+            else:
+                raise
 
         if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
             raise AirflowException(f"Sagemaker endpoint creation failed: 
{response}")
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py 
b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py
index 8a566535b9..d31556f9bf 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py
@@ -122,6 +122,43 @@ class TestSageMakerEndpointOperator:
         }
         self.sagemaker.execute(None)
 
+    @mock.patch.object(SageMakerHook, "get_conn")
+    @mock.patch.object(SageMakerHook, "create_model")
+    @mock.patch.object(SageMakerHook, "create_endpoint_config")
+    @mock.patch.object(SageMakerHook, "create_endpoint")
+    @mock.patch.object(SageMakerHook, "update_endpoint")
+    @mock.patch.object(sagemaker, "serialize", return_value="")
+    def test_execute_with_duplicate_endpoint_removes_tags(
+        self,
+        serialize,
+        mock_endpoint_update,
+        mock_endpoint_create,
+        mock_endpoint_config,
+        mock_model,
+        mock_client,
+    ):
+        mock_endpoint_create.side_effect = ClientError(
+            error_response={
+                "Error": {
+                    "Code": "ValidationException",
+                    "Message": "Cannot create already existing endpoint.",
+                }
+            },
+            operation_name="CreateEndpoint",
+        )
+
+        def _check_no_tags(config, wait_for_completion):
+            assert "Tags" not in config
+            return {
+                "EndpointArn": "test_arn",
+                "ResponseMetadata": {"HTTPStatusCode": 200},
+            }
+
+        mock_endpoint_update.side_effect = _check_no_tags
+
+        self.sagemaker.config["Endpoint"]["Tags"] = {"Key": "k", "Value": "v"}
+        self.sagemaker.execute(None)
+
     @mock.patch.object(SageMakerHook, "create_model")
     @mock.patch.object(SageMakerHook, "create_endpoint_config")
     @mock.patch.object(SageMakerHook, "create_endpoint")

Reply via email to