This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 48abf571bec483b0198802e79fb9b948ba41fdd1 Author: cont-m-nakagawa <[email protected]> AuthorDate: Tue Apr 19 15:15:36 2022 +0900 Add `endpoint_id` arg to `google.cloud.operators.vertex_ai.CreateEndpointOperator` --- airflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py | 3 +++ .../providers/google/cloud/operators/vertex_ai/endpoint_service.py | 3 +++ .../providers/google/cloud/hooks/vertex_ai/test_endpoint_service.py | 5 +++++ tests/providers/google/cloud/operators/test_vertex_ai.py | 2 ++ 4 files changed, 13 insertions(+) diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py b/airflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py index 9cbe470f97..dcf521a978 100644 --- a/airflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py +++ b/airflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py @@ -81,6 +81,7 @@ class EndpointServiceHook(GoogleBaseHook): project_id: str, region: str, endpoint: Union[Endpoint, Dict], + endpoint_id: Optional[str] = None, retry: Union[Retry, _MethodDefault] = DEFAULT, timeout: Optional[float] = None, metadata: Sequence[Tuple[str, str]] = (), @@ -91,6 +92,7 @@ class EndpointServiceHook(GoogleBaseHook): :param project_id: Required. The ID of the Google Cloud project that the service belongs to. :param region: Required. The ID of the Google Cloud region that the service belongs to. :param endpoint: Required. The Endpoint to create. + :param endpoint_id: The ID of Endpoint. If not provided, Vertex AI will generate a value for this ID. :param retry: Designation of what errors, if any, should be retried. :param timeout: The timeout for this request. :param metadata: Strings which should be sent along with the request as metadata. 
@@ -102,6 +104,7 @@ class EndpointServiceHook(GoogleBaseHook): request={ 'parent': parent, 'endpoint': endpoint, + 'endpoint_id': endpoint_id, }, retry=retry, timeout=timeout, diff --git a/airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py b/airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py index 67d40dd1b6..bb7aa45dd4 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py @@ -81,6 +81,7 @@ class CreateEndpointOperator(BaseOperator): region: str, project_id: str, endpoint: Union[Endpoint, Dict], + endpoint_id: Optional[str] = None, retry: Union[Retry, _MethodDefault] = DEFAULT, timeout: Optional[float] = None, metadata: Sequence[Tuple[str, str]] = (), @@ -93,6 +94,7 @@ class CreateEndpointOperator(BaseOperator): self.region = region self.project_id = project_id self.endpoint = endpoint + self.endpoint_id = endpoint_id self.retry = retry self.timeout = timeout self.metadata = metadata @@ -112,6 +114,7 @@ class CreateEndpointOperator(BaseOperator): project_id=self.project_id, region=self.region, endpoint=self.endpoint, + endpoint_id=self.endpoint_id, retry=self.retry, timeout=self.timeout, metadata=self.metadata, diff --git a/tests/providers/google/cloud/hooks/vertex_ai/test_endpoint_service.py b/tests/providers/google/cloud/hooks/vertex_ai/test_endpoint_service.py index 0c6cf6c54e..146a1a555e 100644 --- a/tests/providers/google/cloud/hooks/vertex_ai/test_endpoint_service.py +++ b/tests/providers/google/cloud/hooks/vertex_ai/test_endpoint_service.py @@ -31,6 +31,7 @@ TEST_GCP_CONN_ID: str = "test-gcp-conn-id" TEST_REGION: str = "test-region" TEST_PROJECT_ID: str = "test-project-id" TEST_ENDPOINT: dict = {} +TEST_ENDPOINT_ID: str = "test_endpoint_id" TEST_ENDPOINT_NAME: str = "test_endpoint_name" TEST_DEPLOYED_MODEL: dict = {} TEST_DEPLOYED_MODEL_ID: str = "test-deployed-model-id" @@ -54,12 +55,14 @@ class 
TestEndpointServiceWithDefaultProjectIdHook(TestCase): project_id=TEST_PROJECT_ID, region=TEST_REGION, endpoint=TEST_ENDPOINT, + endpoint_id=TEST_ENDPOINT_ID, ) mock_client.assert_called_once_with(TEST_REGION) mock_client.return_value.create_endpoint.assert_called_once_with( request=dict( parent=mock_client.return_value.common_location_path.return_value, endpoint=TEST_ENDPOINT, + endpoint_id=TEST_ENDPOINT_ID, ), metadata=(), retry=DEFAULT, @@ -223,12 +226,14 @@ class TestEndpointServiceWithoutDefaultProjectIdHook(TestCase): project_id=TEST_PROJECT_ID, region=TEST_REGION, endpoint=TEST_ENDPOINT, + endpoint_id=TEST_ENDPOINT_ID, ) mock_client.assert_called_once_with(TEST_REGION) mock_client.return_value.create_endpoint.assert_called_once_with( request=dict( parent=mock_client.return_value.common_location_path.return_value, endpoint=TEST_ENDPOINT, + endpoint_id=TEST_ENDPOINT_ID, ), metadata=(), retry=DEFAULT, diff --git a/tests/providers/google/cloud/operators/test_vertex_ai.py b/tests/providers/google/cloud/operators/test_vertex_ai.py index f5fca7d3d8..e90cc3e2ee 100644 --- a/tests/providers/google/cloud/operators/test_vertex_ai.py +++ b/tests/providers/google/cloud/operators/test_vertex_ai.py @@ -1137,6 +1137,7 @@ class TestVertexAICreateEndpointOperator: region=GCP_LOCATION, project_id=GCP_PROJECT, endpoint=TEST_ENDPOINT, + endpoint_id=TEST_ENDPOINT_ID, retry=RETRY, timeout=TIMEOUT, metadata=METADATA, @@ -1149,6 +1150,7 @@ class TestVertexAICreateEndpointOperator: region=GCP_LOCATION, project_id=GCP_PROJECT, endpoint=TEST_ENDPOINT, + endpoint_id=TEST_ENDPOINT_ID, retry=RETRY, timeout=TIMEOUT, metadata=METADATA,
