This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 48abf571bec483b0198802e79fb9b948ba41fdd1
Author: cont-m-nakagawa <[email protected]>
AuthorDate: Tue Apr 19 15:15:36 2022 +0900

    Add `endpoint_id` arg to `google.cloud.operators.vertex_ai.CreateEndpointOperator`
---
 airflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py   | 3 +++
 .../providers/google/cloud/operators/vertex_ai/endpoint_service.py   | 3 +++
 .../providers/google/cloud/hooks/vertex_ai/test_endpoint_service.py  | 5 +++++
 tests/providers/google/cloud/operators/test_vertex_ai.py             | 2 ++
 4 files changed, 13 insertions(+)

diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py b/airflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py
index 9cbe470f97..dcf521a978 100644
--- a/airflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py
+++ b/airflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py
@@ -81,6 +81,7 @@ class EndpointServiceHook(GoogleBaseHook):
         project_id: str,
         region: str,
         endpoint: Union[Endpoint, Dict],
+        endpoint_id: Optional[str] = None,
         retry: Union[Retry, _MethodDefault] = DEFAULT,
         timeout: Optional[float] = None,
         metadata: Sequence[Tuple[str, str]] = (),
@@ -91,6 +92,7 @@ class EndpointServiceHook(GoogleBaseHook):
         :param project_id: Required. The ID of the Google Cloud project that 
the service belongs to.
         :param region: Required. The ID of the Google Cloud region that the 
service belongs to.
         :param endpoint: Required. The Endpoint to create.
+        :param endpoint_id: The ID of Endpoint. If not provided, Vertex AI 
will generate a value for this ID.
         :param retry: Designation of what errors, if any, should be retried.
         :param timeout: The timeout for this request.
         :param metadata: Strings which should be sent along with the request 
as metadata.
@@ -102,6 +104,7 @@ class EndpointServiceHook(GoogleBaseHook):
             request={
                 'parent': parent,
                 'endpoint': endpoint,
+                'endpoint_id': endpoint_id,
             },
             retry=retry,
             timeout=timeout,
diff --git a/airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py b/airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py
index 67d40dd1b6..bb7aa45dd4 100644
--- a/airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py
+++ b/airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py
@@ -81,6 +81,7 @@ class CreateEndpointOperator(BaseOperator):
         region: str,
         project_id: str,
         endpoint: Union[Endpoint, Dict],
+        endpoint_id: Optional[str] = None,
         retry: Union[Retry, _MethodDefault] = DEFAULT,
         timeout: Optional[float] = None,
         metadata: Sequence[Tuple[str, str]] = (),
@@ -93,6 +94,7 @@ class CreateEndpointOperator(BaseOperator):
         self.region = region
         self.project_id = project_id
         self.endpoint = endpoint
+        self.endpoint_id = endpoint_id
         self.retry = retry
         self.timeout = timeout
         self.metadata = metadata
@@ -112,6 +114,7 @@ class CreateEndpointOperator(BaseOperator):
             project_id=self.project_id,
             region=self.region,
             endpoint=self.endpoint,
+            endpoint_id=self.endpoint_id,
             retry=self.retry,
             timeout=self.timeout,
             metadata=self.metadata,
diff --git a/tests/providers/google/cloud/hooks/vertex_ai/test_endpoint_service.py b/tests/providers/google/cloud/hooks/vertex_ai/test_endpoint_service.py
index 0c6cf6c54e..146a1a555e 100644
--- a/tests/providers/google/cloud/hooks/vertex_ai/test_endpoint_service.py
+++ b/tests/providers/google/cloud/hooks/vertex_ai/test_endpoint_service.py
@@ -31,6 +31,7 @@ TEST_GCP_CONN_ID: str = "test-gcp-conn-id"
 TEST_REGION: str = "test-region"
 TEST_PROJECT_ID: str = "test-project-id"
 TEST_ENDPOINT: dict = {}
+TEST_ENDPOINT_ID: str = "test_endpoint_id"
 TEST_ENDPOINT_NAME: str = "test_endpoint_name"
 TEST_DEPLOYED_MODEL: dict = {}
 TEST_DEPLOYED_MODEL_ID: str = "test-deployed-model-id"
@@ -54,12 +55,14 @@ class TestEndpointServiceWithDefaultProjectIdHook(TestCase):
             project_id=TEST_PROJECT_ID,
             region=TEST_REGION,
             endpoint=TEST_ENDPOINT,
+            endpoint_id=TEST_ENDPOINT_ID,
         )
         mock_client.assert_called_once_with(TEST_REGION)
         mock_client.return_value.create_endpoint.assert_called_once_with(
             request=dict(
                 
parent=mock_client.return_value.common_location_path.return_value,
                 endpoint=TEST_ENDPOINT,
+                endpoint_id=TEST_ENDPOINT_ID,
             ),
             metadata=(),
             retry=DEFAULT,
@@ -223,12 +226,14 @@ class TestEndpointServiceWithoutDefaultProjectIdHook(TestCase):
             project_id=TEST_PROJECT_ID,
             region=TEST_REGION,
             endpoint=TEST_ENDPOINT,
+            endpoint_id=TEST_ENDPOINT_ID,
         )
         mock_client.assert_called_once_with(TEST_REGION)
         mock_client.return_value.create_endpoint.assert_called_once_with(
             request=dict(
                 
parent=mock_client.return_value.common_location_path.return_value,
                 endpoint=TEST_ENDPOINT,
+                endpoint_id=TEST_ENDPOINT_ID,
             ),
             metadata=(),
             retry=DEFAULT,
diff --git a/tests/providers/google/cloud/operators/test_vertex_ai.py b/tests/providers/google/cloud/operators/test_vertex_ai.py
index f5fca7d3d8..e90cc3e2ee 100644
--- a/tests/providers/google/cloud/operators/test_vertex_ai.py
+++ b/tests/providers/google/cloud/operators/test_vertex_ai.py
@@ -1137,6 +1137,7 @@ class TestVertexAICreateEndpointOperator:
             region=GCP_LOCATION,
             project_id=GCP_PROJECT,
             endpoint=TEST_ENDPOINT,
+            endpoint_id=TEST_ENDPOINT_ID,
             retry=RETRY,
             timeout=TIMEOUT,
             metadata=METADATA,
@@ -1149,6 +1150,7 @@ class TestVertexAICreateEndpointOperator:
             region=GCP_LOCATION,
             project_id=GCP_PROJECT,
             endpoint=TEST_ENDPOINT,
+            endpoint_id=TEST_ENDPOINT_ID,
             retry=RETRY,
             timeout=TIMEOUT,
             metadata=METADATA,

Reply via email to