This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new da8f133053 Use AsyncClient for Composer Operators in deferrable mode (#25951)
da8f133053 is described below

commit da8f133053f7483cfe45109142943a7ded1ed1a2
Author: Ɓukasz Wyszomirski <[email protected]>
AuthorDate: Mon Aug 29 12:35:45 2022 +0200

    Use AsyncClient for Composer Operators in deferrable mode (#25951)
---
 .../providers/google/cloud/hooks/cloud_composer.py | 127 ++++++++++++++++++++-
 .../google/cloud/triggers/cloud_composer.py        |   6 +-
 .../google/cloud/hooks/test_cloud_composer.py      |  81 ++++++++++++-
 .../google/cloud/operators/test_cloud_composer.py  |   6 +-
 4 files changed, 212 insertions(+), 8 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/cloud_composer.py b/airflow/providers/google/cloud/hooks/cloud_composer.py
index 51704c3938..21765dfe8b 100644
--- a/airflow/providers/google/cloud/hooks/cloud_composer.py
+++ b/airflow/providers/google/cloud/hooks/cloud_composer.py
@@ -21,8 +21,13 @@ from typing import Dict, Optional, Sequence, Tuple, Union
 from google.api_core.client_options import ClientOptions
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
 from google.api_core.operation import Operation
+from google.api_core.operation_async import AsyncOperation
 from google.api_core.retry import Retry
-from google.cloud.orchestration.airflow.service_v1 import EnvironmentsClient, ImageVersionsClient
+from google.cloud.orchestration.airflow.service_v1 import (
+    EnvironmentsAsyncClient,
+    EnvironmentsClient,
+    ImageVersionsClient,
+)
 from google.cloud.orchestration.airflow.service_v1.services.environments.pagers import ListEnvironmentsPager
 from google.cloud.orchestration.airflow.service_v1.services.image_versions.pagers import (
     ListImageVersionsPager,
@@ -275,3 +280,123 @@ class CloudComposerHook(GoogleBaseHook):
             metadata=metadata,
         )
         return result
+
+
+class CloudComposerAsyncHook(GoogleBaseHook):
+    """Hook for Google Cloud Composer async APIs."""
+
+    client_options = ClientOptions(api_endpoint='composer.googleapis.com:443')
+
+    def get_environment_client(self) -> EnvironmentsAsyncClient:
+        """Retrieves client library object that allow access Environments service."""
+        return EnvironmentsAsyncClient(
+            credentials=self.get_credentials(),
+            client_info=CLIENT_INFO,
+            client_options=self.client_options,
+        )
+
+    def get_environment_name(self, project_id, region, environment_id):
+        return f'projects/{project_id}/locations/{region}/environments/{environment_id}'
+
+    def get_parent(self, project_id, region):
+        return f'projects/{project_id}/locations/{region}'
+
+    async def get_operation(self, operation_name):
+        return await self.get_environment_client().transport.operations_client.get_operation(
+            name=operation_name
+        )
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    async def create_environment(
+        self,
+        project_id: str,
+        region: str,
+        environment: Union[Environment, Dict],
+        retry: Union[Retry, _MethodDefault] = DEFAULT,
+        timeout: Optional[float] = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> AsyncOperation:
+        """
+        Create a new environment.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param region: Required. The ID of the Google Cloud region that the service belongs to.
+        :param environment:  The environment to create. This corresponds to the ``environment`` field on the
+            ``request`` instance; if ``request`` is provided, this should not be set.
+        :param retry: Designation of what errors, if any, should be retried.
+        :param timeout: The timeout for this request.
+        :param metadata: Strings which should be sent along with the request as metadata.
+        """
+        client = self.get_environment_client()
+        return await client.create_environment(
+            request={'parent': self.get_parent(project_id, region), 'environment': environment},
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    async def delete_environment(
+        self,
+        project_id: str,
+        region: str,
+        environment_id: str,
+        retry: Union[Retry, _MethodDefault] = DEFAULT,
+        timeout: Optional[float] = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> AsyncOperation:
+        """
+        Delete an environment.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param region: Required. The ID of the Google Cloud region that the service belongs to.
+        :param environment_id: Required. The ID of the Google Cloud environment that the service belongs to.
+        :param retry: Designation of what errors, if any, should be retried.
+        :param timeout: The timeout for this request.
+        :param metadata: Strings which should be sent along with the request as metadata.
+        """
+        client = self.get_environment_client()
+        name = self.get_environment_name(project_id, region, environment_id)
+        return await client.delete_environment(
+            request={"name": name}, retry=retry, timeout=timeout, metadata=metadata
+        )
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    async def update_environment(
+        self,
+        project_id: str,
+        region: str,
+        environment_id: str,
+        environment: Union[Environment, Dict],
+        update_mask: Union[Dict, FieldMask],
+        retry: Union[Retry, _MethodDefault] = DEFAULT,
+        timeout: Optional[float] = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> AsyncOperation:
+        r"""
+        Update an environment.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param region: Required. The ID of the Google Cloud region that the service belongs to.
+        :param environment_id: Required. The ID of the Google Cloud environment that the service belongs to.
+        :param environment:  A patch environment. Fields specified by the ``updateMask`` will be copied from
+            the patch environment into the environment under update.
+
+            This corresponds to the ``environment`` field on the ``request`` instance; if ``request`` is
+            provided, this should not be set.
+        :param update_mask:  Required. A comma-separated list of paths, relative to ``Environment``, of fields
+            to update. If a dict is provided, it must be of the same form as the protobuf message
+            :class:`~google.protobuf.field_mask_pb2.FieldMask`
+        :param retry: Designation of what errors, if any, should be retried.
+        :param timeout: The timeout for this request.
+        :param metadata: Strings which should be sent along with the request as metadata.
+        """
+        client = self.get_environment_client()
+        name = self.get_environment_name(project_id, region, environment_id)
+
+        return await client.update_environment(
+            request={"name": name, "environment": environment, "update_mask": update_mask},
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
diff --git a/airflow/providers/google/cloud/triggers/cloud_composer.py b/airflow/providers/google/cloud/triggers/cloud_composer.py
index e1e5e009a5..87bf233ac3 100644
--- a/airflow/providers/google/cloud/triggers/cloud_composer.py
+++ b/airflow/providers/google/cloud/triggers/cloud_composer.py
@@ -21,7 +21,7 @@ import logging
 from typing import Any, Dict, Optional, Sequence, Tuple, Union
 
 from airflow import AirflowException
-from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerHook
+from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerAsyncHook
 
 try:
     from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -58,7 +58,7 @@ class CloudComposerExecutionTrigger(BaseTrigger):
 
         self.pooling_period_seconds = pooling_period_seconds
 
-        self.gcp_hook = CloudComposerHook(
+        self.gcp_hook = CloudComposerAsyncHook(
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
             delegate_to=self.delegate_to,
@@ -80,7 +80,7 @@ class CloudComposerExecutionTrigger(BaseTrigger):
 
     async def run(self):
         while True:
-            operation = self.gcp_hook.get_operation(operation_name=self.operation_name)
+            operation = await self.gcp_hook.get_operation(operation_name=self.operation_name)
             if operation.done:
                 break
             elif operation.error.message:
diff --git a/tests/providers/google/cloud/hooks/test_cloud_composer.py b/tests/providers/google/cloud/hooks/test_cloud_composer.py
index 4d54c9397b..3bec308c03 100644
--- a/tests/providers/google/cloud/hooks/test_cloud_composer.py
+++ b/tests/providers/google/cloud/hooks/test_cloud_composer.py
@@ -20,9 +20,10 @@
 import unittest
 from unittest import mock
 
+import pytest
 from google.api_core.gapic_v1.method import DEFAULT
 
-from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerHook
+from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerAsyncHook, CloudComposerHook
 
 TEST_GCP_REGION = "global"
 TEST_GCP_PROJECT = "test-project"
@@ -193,3 +194,81 @@ class TestCloudComposerHook(unittest.TestCase):
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
         )
+
+
+class TestCloudComposerAsyncHook(unittest.TestCase):
+    def setUp(
+        self,
+    ) -> None:
+        with mock.patch(BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_init):
+            self.hook = CloudComposerAsyncHook(gcp_conn_id="test")
+
+    @pytest.mark.asyncio
+    @mock.patch(COMPOSER_STRING.format("CloudComposerAsyncHook.get_environment_client"))
+    async def test_create_environment(self, mock_client) -> None:
+        await self.hook.create_environment(
+            project_id=TEST_GCP_PROJECT,
+            region=TEST_GCP_REGION,
+            environment=TEST_ENVIRONMENT,
+            retry=TEST_RETRY,
+            timeout=TEST_TIMEOUT,
+            metadata=TEST_METADATA,
+        )
+        mock_client.assert_called_once()
+        mock_client.return_value.create_environment.assert_called_once_with(
+            request={
+                'parent': self.hook.get_parent(TEST_GCP_PROJECT, TEST_GCP_REGION),
+                'environment': TEST_ENVIRONMENT,
+            },
+            retry=TEST_RETRY,
+            timeout=TEST_TIMEOUT,
+            metadata=TEST_METADATA,
+        )
+
+    @pytest.mark.asyncio
+    @mock.patch(COMPOSER_STRING.format("CloudComposerAsyncHook.get_environment_client"))
+    async def test_delete_environment(self, mock_client) -> None:
+        await self.hook.delete_environment(
+            project_id=TEST_GCP_PROJECT,
+            region=TEST_GCP_REGION,
+            environment_id=TEST_ENVIRONMENT_ID,
+            retry=TEST_RETRY,
+            timeout=TEST_TIMEOUT,
+            metadata=TEST_METADATA,
+        )
+        mock_client.assert_called_once()
+        mock_client.return_value.delete_environment.assert_called_once_with(
+            request={
+                "name": self.hook.get_environment_name(TEST_GCP_PROJECT, TEST_GCP_REGION, TEST_ENVIRONMENT_ID)
+            },
+            retry=TEST_RETRY,
+            timeout=TEST_TIMEOUT,
+            metadata=TEST_METADATA,
+        )
+
+    @pytest.mark.asyncio
+    @mock.patch(COMPOSER_STRING.format("CloudComposerAsyncHook.get_environment_client"))
+    async def test_update_environment(self, mock_client) -> None:
+        await self.hook.update_environment(
+            project_id=TEST_GCP_PROJECT,
+            region=TEST_GCP_REGION,
+            environment_id=TEST_ENVIRONMENT_ID,
+            environment=TEST_UPDATED_ENVIRONMENT,
+            update_mask=TEST_UPDATE_MASK,
+            retry=TEST_RETRY,
+            timeout=TEST_TIMEOUT,
+            metadata=TEST_METADATA,
+        )
+        mock_client.assert_called_once()
+        mock_client.return_value.update_environment.assert_called_once_with(
+            request={
+                "name": self.hook.get_environment_name(
+                    TEST_GCP_PROJECT, TEST_GCP_REGION, TEST_ENVIRONMENT_ID
+                ),
+                "environment": TEST_UPDATED_ENVIRONMENT,
+                "update_mask": TEST_UPDATE_MASK,
+            },
+            retry=TEST_RETRY,
+            timeout=TEST_TIMEOUT,
+            metadata=TEST_METADATA,
+        )
diff --git a/tests/providers/google/cloud/operators/test_cloud_composer.py b/tests/providers/google/cloud/operators/test_cloud_composer.py
index ca2c1b171b..9f513a7f3a 100644
--- a/tests/providers/google/cloud/operators/test_cloud_composer.py
+++ b/tests/providers/google/cloud/operators/test_cloud_composer.py
@@ -94,7 +94,7 @@ class TestCloudComposerCreateEnvironmentOperator:
 
     @mock.patch(COMPOSER_STRING.format("Environment.to_dict"))
     @mock.patch(COMPOSER_STRING.format("CloudComposerHook"))
-    @mock.patch(COMPOSER_TRIGGERS_STRING.format("CloudComposerHook"))
+    @mock.patch(COMPOSER_TRIGGERS_STRING.format("CloudComposerAsyncHook"))
     def test_execute_deferrable(self, mock_trigger_hook, mock_hook, to_dict_mode):
         op = CloudComposerCreateEnvironmentOperator(
             task_id=TASK_ID,
@@ -145,7 +145,7 @@ class TestCloudComposerDeleteEnvironmentOperator:
         )
 
     @mock.patch(COMPOSER_STRING.format("CloudComposerHook"))
-    @mock.patch(COMPOSER_TRIGGERS_STRING.format("CloudComposerHook"))
+    @mock.patch(COMPOSER_TRIGGERS_STRING.format("CloudComposerAsyncHook"))
     def test_execute_deferrable(self, mock_trigger_hook, mock_hook):
         op = CloudComposerDeleteEnvironmentOperator(
             task_id=TASK_ID,
@@ -200,7 +200,7 @@ class TestCloudComposerUpdateEnvironmentOperator:
 
     @mock.patch(COMPOSER_STRING.format("Environment.to_dict"))
     @mock.patch(COMPOSER_STRING.format("CloudComposerHook"))
-    @mock.patch(COMPOSER_TRIGGERS_STRING.format("CloudComposerHook"))
+    @mock.patch(COMPOSER_TRIGGERS_STRING.format("CloudComposerAsyncHook"))
     def test_execute_deferrable(self, mock_trigger_hook, mock_hook, to_dict_mode):
         op = CloudComposerUpdateEnvironmentOperator(
             task_id=TASK_ID,

Reply via email to