This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new da8f133053 Use AsyncClient for Composer Operators in deferrable mode (#25951)
da8f133053 is described below
commit da8f133053f7483cfe45109142943a7ded1ed1a2
Author: Ćukasz Wyszomirski <[email protected]>
AuthorDate: Mon Aug 29 12:35:45 2022 +0200
Use AsyncClient for Composer Operators in deferrable mode (#25951)
---
.../providers/google/cloud/hooks/cloud_composer.py | 127 ++++++++++++++++++++-
.../google/cloud/triggers/cloud_composer.py | 6 +-
.../google/cloud/hooks/test_cloud_composer.py | 81 ++++++++++++-
.../google/cloud/operators/test_cloud_composer.py | 6 +-
4 files changed, 212 insertions(+), 8 deletions(-)
diff --git a/airflow/providers/google/cloud/hooks/cloud_composer.py
b/airflow/providers/google/cloud/hooks/cloud_composer.py
index 51704c3938..21765dfe8b 100644
--- a/airflow/providers/google/cloud/hooks/cloud_composer.py
+++ b/airflow/providers/google/cloud/hooks/cloud_composer.py
@@ -21,8 +21,13 @@ from typing import Dict, Optional, Sequence, Tuple, Union
from google.api_core.client_options import ClientOptions
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
from google.api_core.operation import Operation
+from google.api_core.operation_async import AsyncOperation
from google.api_core.retry import Retry
-from google.cloud.orchestration.airflow.service_v1 import EnvironmentsClient,
ImageVersionsClient
+from google.cloud.orchestration.airflow.service_v1 import (
+ EnvironmentsAsyncClient,
+ EnvironmentsClient,
+ ImageVersionsClient,
+)
from
google.cloud.orchestration.airflow.service_v1.services.environments.pagers
import ListEnvironmentsPager
from
google.cloud.orchestration.airflow.service_v1.services.image_versions.pagers
import (
ListImageVersionsPager,
@@ -275,3 +280,130 @@ class CloudComposerHook(GoogleBaseHook):
             metadata=metadata,
         )
         return result
+
+
+class CloudComposerAsyncHook(GoogleBaseHook):
+    """Hook for Google Cloud Composer async APIs."""
+
+    client_options = ClientOptions(api_endpoint='composer.googleapis.com:443')
+
+    def get_environment_client(self) -> EnvironmentsAsyncClient:
+        """Retrieves client library object that allows access to the Environments service."""
+        return EnvironmentsAsyncClient(
+            credentials=self.get_credentials(),
+            client_info=CLIENT_INFO,
+            client_options=self.client_options,
+        )
+
+    def get_environment_name(self, project_id, region, environment_id):
+        """Build the full resource name of an environment."""
+        return f'projects/{project_id}/locations/{region}/environments/{environment_id}'
+
+    def get_parent(self, project_id, region):
+        """Build the parent resource name (project and region) that environments are created under."""
+        return f'projects/{project_id}/locations/{region}'
+
+    async def get_operation(self, operation_name):
+        """
+        Fetch a long-running operation by its full resource name.
+
+        A new ``EnvironmentsAsyncClient`` is created on every call.
+        """
+        return await self.get_environment_client().transport.operations_client.get_operation(
+            name=operation_name
+        )
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    async def create_environment(
+        self,
+        project_id: str,
+        region: str,
+        environment: Union[Environment, Dict],
+        retry: Union[Retry, _MethodDefault] = DEFAULT,
+        timeout: Optional[float] = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> AsyncOperation:
+        """
+        Create a new environment.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param region: Required. The ID of the Google Cloud region that the service belongs to.
+        :param environment: Required. The environment to create, given as an ``Environment`` message or
+            as a dict.
+        :param retry: Designation of what errors, if any, should be retried.
+        :param timeout: The timeout for this request.
+        :param metadata: Strings which should be sent along with the request as metadata.
+        :return: The long-running operation returned by the async client.
+        """
+        client = self.get_environment_client()
+        return await client.create_environment(
+            request={'parent': self.get_parent(project_id, region), 'environment': environment},
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    async def delete_environment(
+        self,
+        project_id: str,
+        region: str,
+        environment_id: str,
+        retry: Union[Retry, _MethodDefault] = DEFAULT,
+        timeout: Optional[float] = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> AsyncOperation:
+        """
+        Delete an environment.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param region: Required. The ID of the Google Cloud region that the service belongs to.
+        :param environment_id: Required. The ID of the Composer environment to delete.
+        :param retry: Designation of what errors, if any, should be retried.
+        :param timeout: The timeout for this request.
+        :param metadata: Strings which should be sent along with the request as metadata.
+        :return: The long-running operation returned by the async client.
+        """
+        client = self.get_environment_client()
+        name = self.get_environment_name(project_id, region, environment_id)
+        return await client.delete_environment(
+            request={"name": name}, retry=retry, timeout=timeout, metadata=metadata
+        )
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    async def update_environment(
+        self,
+        project_id: str,
+        region: str,
+        environment_id: str,
+        environment: Union[Environment, Dict],
+        update_mask: Union[Dict, FieldMask],
+        retry: Union[Retry, _MethodDefault] = DEFAULT,
+        timeout: Optional[float] = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> AsyncOperation:
+        r"""
+        Update an environment.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param region: Required. The ID of the Google Cloud region that the service belongs to.
+        :param environment_id: Required. The ID of the Composer environment to update.
+        :param environment: A patch environment. Fields specified by ``update_mask`` will be copied from
+            the patch environment into the environment under update.
+        :param update_mask: Required. A comma-separated list of paths, relative to ``Environment``, of fields
+            to update. If a dict is provided, it must be of the same form as the protobuf message
+            :class:`~google.protobuf.field_mask_pb2.FieldMask`
+        :param retry: Designation of what errors, if any, should be retried.
+        :param timeout: The timeout for this request.
+        :param metadata: Strings which should be sent along with the request as metadata.
+        :return: The long-running operation returned by the async client.
+        """
+        client = self.get_environment_client()
+        name = self.get_environment_name(project_id, region, environment_id)
+
+        return await client.update_environment(
+            request={"name": name, "environment": environment, "update_mask": update_mask},
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
diff --git a/airflow/providers/google/cloud/triggers/cloud_composer.py
b/airflow/providers/google/cloud/triggers/cloud_composer.py
index e1e5e009a5..87bf233ac3 100644
--- a/airflow/providers/google/cloud/triggers/cloud_composer.py
+++ b/airflow/providers/google/cloud/triggers/cloud_composer.py
@@ -21,7 +21,7 @@ import logging
from typing import Any, Dict, Optional, Sequence, Tuple, Union
from airflow import AirflowException
-from airflow.providers.google.cloud.hooks.cloud_composer import
CloudComposerHook
+from airflow.providers.google.cloud.hooks.cloud_composer import
CloudComposerAsyncHook
try:
from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -58,7 +58,7 @@ class CloudComposerExecutionTrigger(BaseTrigger):
self.pooling_period_seconds = pooling_period_seconds
- self.gcp_hook = CloudComposerHook(
+ self.gcp_hook = CloudComposerAsyncHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
delegate_to=self.delegate_to,
@@ -80,7 +80,7 @@ class CloudComposerExecutionTrigger(BaseTrigger):
async def run(self):
while True:
- operation =
self.gcp_hook.get_operation(operation_name=self.operation_name)
+ operation = await
self.gcp_hook.get_operation(operation_name=self.operation_name)
if operation.done:
break
elif operation.error.message:
diff --git a/tests/providers/google/cloud/hooks/test_cloud_composer.py
b/tests/providers/google/cloud/hooks/test_cloud_composer.py
index 4d54c9397b..3bec308c03 100644
--- a/tests/providers/google/cloud/hooks/test_cloud_composer.py
+++ b/tests/providers/google/cloud/hooks/test_cloud_composer.py
@@ -20,9 +20,10 @@
import unittest
from unittest import mock
+import pytest
from google.api_core.gapic_v1.method import DEFAULT
-from airflow.providers.google.cloud.hooks.cloud_composer import
CloudComposerHook
+from airflow.providers.google.cloud.hooks.cloud_composer import
CloudComposerAsyncHook, CloudComposerHook
TEST_GCP_REGION = "global"
TEST_GCP_PROJECT = "test-project"
@@ -193,3 +194,94 @@ class TestCloudComposerHook(unittest.TestCase):
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
         )
+
+
+async def _async_return_none(*args, **kwargs):
+    """Async stand-in for a client method: calling it returns a coroutine resolving to ``None``."""
+    return None
+
+
+class TestCloudComposerAsyncHook:
+    """Tests for ``CloudComposerAsyncHook``.
+
+    Deliberately a plain pytest class: ``unittest.TestCase`` does not await ``async def`` test
+    methods, so under it these tests would pass without ever running their bodies.
+    """
+
+    def setup_method(self, method):
+        with mock.patch(BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_init):
+            self.hook = CloudComposerAsyncHook(gcp_conn_id="test")
+
+    @pytest.mark.asyncio
+    @mock.patch(COMPOSER_STRING.format("CloudComposerAsyncHook.get_environment_client"))
+    async def test_create_environment(self, mock_client) -> None:
+        # The hook awaits the client call, so the mocked method must return an awaitable.
+        mock_client.return_value.create_environment.side_effect = _async_return_none
+        await self.hook.create_environment(
+            project_id=TEST_GCP_PROJECT,
+            region=TEST_GCP_REGION,
+            environment=TEST_ENVIRONMENT,
+            retry=TEST_RETRY,
+            timeout=TEST_TIMEOUT,
+            metadata=TEST_METADATA,
+        )
+        mock_client.assert_called_once()
+        mock_client.return_value.create_environment.assert_called_once_with(
+            request={
+                'parent': self.hook.get_parent(TEST_GCP_PROJECT, TEST_GCP_REGION),
+                'environment': TEST_ENVIRONMENT,
+            },
+            retry=TEST_RETRY,
+            timeout=TEST_TIMEOUT,
+            metadata=TEST_METADATA,
+        )
+
+    @pytest.mark.asyncio
+    @mock.patch(COMPOSER_STRING.format("CloudComposerAsyncHook.get_environment_client"))
+    async def test_delete_environment(self, mock_client) -> None:
+        mock_client.return_value.delete_environment.side_effect = _async_return_none
+        await self.hook.delete_environment(
+            project_id=TEST_GCP_PROJECT,
+            region=TEST_GCP_REGION,
+            environment_id=TEST_ENVIRONMENT_ID,
+            retry=TEST_RETRY,
+            timeout=TEST_TIMEOUT,
+            metadata=TEST_METADATA,
+        )
+        mock_client.assert_called_once()
+        mock_client.return_value.delete_environment.assert_called_once_with(
+            request={
+                "name": self.hook.get_environment_name(TEST_GCP_PROJECT, TEST_GCP_REGION, TEST_ENVIRONMENT_ID)
+            },
+            retry=TEST_RETRY,
+            timeout=TEST_TIMEOUT,
+            metadata=TEST_METADATA,
+        )
+
+    @pytest.mark.asyncio
+    @mock.patch(COMPOSER_STRING.format("CloudComposerAsyncHook.get_environment_client"))
+    async def test_update_environment(self, mock_client) -> None:
+        mock_client.return_value.update_environment.side_effect = _async_return_none
+        await self.hook.update_environment(
+            project_id=TEST_GCP_PROJECT,
+            region=TEST_GCP_REGION,
+            environment_id=TEST_ENVIRONMENT_ID,
+            environment=TEST_UPDATED_ENVIRONMENT,
+            update_mask=TEST_UPDATE_MASK,
+            retry=TEST_RETRY,
+            timeout=TEST_TIMEOUT,
+            metadata=TEST_METADATA,
+        )
+        mock_client.assert_called_once()
+        mock_client.return_value.update_environment.assert_called_once_with(
+            request={
+                "name": self.hook.get_environment_name(
+                    TEST_GCP_PROJECT, TEST_GCP_REGION, TEST_ENVIRONMENT_ID
+                ),
+                "environment": TEST_UPDATED_ENVIRONMENT,
+                "update_mask": TEST_UPDATE_MASK,
+            },
+            retry=TEST_RETRY,
+            timeout=TEST_TIMEOUT,
+            metadata=TEST_METADATA,
+        )
diff --git a/tests/providers/google/cloud/operators/test_cloud_composer.py
b/tests/providers/google/cloud/operators/test_cloud_composer.py
index ca2c1b171b..9f513a7f3a 100644
--- a/tests/providers/google/cloud/operators/test_cloud_composer.py
+++ b/tests/providers/google/cloud/operators/test_cloud_composer.py
@@ -94,7 +94,7 @@ class TestCloudComposerCreateEnvironmentOperator:
@mock.patch(COMPOSER_STRING.format("Environment.to_dict"))
@mock.patch(COMPOSER_STRING.format("CloudComposerHook"))
- @mock.patch(COMPOSER_TRIGGERS_STRING.format("CloudComposerHook"))
+ @mock.patch(COMPOSER_TRIGGERS_STRING.format("CloudComposerAsyncHook"))
def test_execute_deferrable(self, mock_trigger_hook, mock_hook,
to_dict_mode):
op = CloudComposerCreateEnvironmentOperator(
task_id=TASK_ID,
@@ -145,7 +145,7 @@ class TestCloudComposerDeleteEnvironmentOperator:
)
@mock.patch(COMPOSER_STRING.format("CloudComposerHook"))
- @mock.patch(COMPOSER_TRIGGERS_STRING.format("CloudComposerHook"))
+ @mock.patch(COMPOSER_TRIGGERS_STRING.format("CloudComposerAsyncHook"))
def test_execute_deferrable(self, mock_trigger_hook, mock_hook):
op = CloudComposerDeleteEnvironmentOperator(
task_id=TASK_ID,
@@ -200,7 +200,7 @@ class TestCloudComposerUpdateEnvironmentOperator:
@mock.patch(COMPOSER_STRING.format("Environment.to_dict"))
@mock.patch(COMPOSER_STRING.format("CloudComposerHook"))
- @mock.patch(COMPOSER_TRIGGERS_STRING.format("CloudComposerHook"))
+ @mock.patch(COMPOSER_TRIGGERS_STRING.format("CloudComposerAsyncHook"))
def test_execute_deferrable(self, mock_trigger_hook, mock_hook,
to_dict_mode):
op = CloudComposerUpdateEnvironmentOperator(
task_id=TASK_ID,