This is an automated email from the ASF dual-hosted git repository. eladkal pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 866a1f0f6d2 Fix CloudComposerAsyncHook to work correctly with Airflow 3. (#54976) 866a1f0f6d2 is described below commit 866a1f0f6d2dc585dfdc0188e93990cc97e6a327 Author: Nitochkin <62333822+crowi...@users.noreply.github.com> AuthorDate: Sat Sep 6 10:44:22 2025 +0200 Fix CloudComposerAsyncHook to work correctly with Airflow 3. (#54976) Co-authored-by: Anton Nitochkin <nitoch...@google.com> Co-authored-by: VladaZakharova <ula...@google.com> --- .../providers/google/cloud/hooks/cloud_composer.py | 26 ++++++++-------- .../google/cloud/triggers/cloud_composer.py | 36 +++++++++++++--------- .../unit/google/cloud/hooks/test_cloud_composer.py | 2 +- 3 files changed, 36 insertions(+), 28 deletions(-) diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_composer.py b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_composer.py index 9c963e47b0f..daf6a06a927 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_composer.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_composer.py @@ -36,7 +36,7 @@ from google.cloud.orchestration.airflow.service_v1 import ( from airflow.exceptions import AirflowException from airflow.providers.google.common.consts import CLIENT_INFO -from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook if TYPE_CHECKING: from google.api_core.operation import Operation @@ -473,15 +473,18 @@ class CloudComposerHook(GoogleBaseHook, OperationHelper): return response.json() -class CloudComposerAsyncHook(GoogleBaseHook): +class CloudComposerAsyncHook(GoogleBaseAsyncHook): """Hook for Google Cloud Composer async APIs.""" + sync_hook_class = CloudComposerHook + client_options = ClientOptions(api_endpoint="composer.googleapis.com:443") - def get_environment_client(self) -> EnvironmentsAsyncClient: + async def get_environment_client(self) -> EnvironmentsAsyncClient: """Retrieve client library object that allow access Environments service.""" + sync_hook = await self.get_sync_hook() return EnvironmentsAsyncClient( - credentials=self.get_credentials(), + credentials=sync_hook.get_credentials(), client_info=CLIENT_INFO, client_options=self.client_options, ) @@ -493,9 +496,8 @@ class CloudComposerAsyncHook(GoogleBaseHook): return f"projects/{project_id}/locations/{region}" async def get_operation(self, operation_name): - return await self.get_environment_client().transport.operations_client.get_operation( - name=operation_name - ) + client = await self.get_environment_client() + return await client.transport.operations_client.get_operation(name=operation_name) @GoogleBaseHook.fallback_to_default_project_id async def create_environment( @@ -518,7 +520,7 @@ class CloudComposerAsyncHook(GoogleBaseHook): :param timeout: The timeout for this request. :param metadata: Strings which should be sent along with the request as metadata. """ - client = self.get_environment_client() + client = await self.get_environment_client() return await client.create_environment( request={"parent": self.get_parent(project_id, region), "environment": environment}, retry=retry, @@ -546,7 +548,7 @@ class CloudComposerAsyncHook(GoogleBaseHook): :param timeout: The timeout for this request. :param metadata: Strings which should be sent along with the request as metadata. """ - client = self.get_environment_client() + client = await self.get_environment_client() name = self.get_environment_name(project_id, region, environment_id) return await client.delete_environment( request={"name": name}, retry=retry, timeout=timeout, metadata=metadata @@ -582,7 +584,7 @@ class CloudComposerAsyncHook(GoogleBaseHook): :param timeout: The timeout for this request. :param metadata: Strings which should be sent along with the request as metadata. """ - client = self.get_environment_client() + client = await self.get_environment_client() name = self.get_environment_name(project_id, region, environment_id) return await client.update_environment( @@ -620,7 +622,7 @@ class CloudComposerAsyncHook(GoogleBaseHook): :param timeout: The timeout for this request. :param metadata: Strings which should be sent along with the request as metadata. """ - client = self.get_environment_client() + client = await self.get_environment_client() return await client.execute_airflow_command( request={ @@ -662,7 +664,7 @@ class CloudComposerAsyncHook(GoogleBaseHook): :param timeout: The timeout for this request. :param metadata: Strings which should be sent along with the request as metadata. """ - client = self.get_environment_client() + client = await self.get_environment_client() return await client.poll_airflow_command( request={ diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py b/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py index a2840dcffe2..f6654a39351 100644 --- a/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py +++ b/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py @@ -52,11 +52,6 @@ class CloudComposerExecutionTrigger(BaseTrigger): self.impersonation_chain = impersonation_chain self.pooling_period_seconds = pooling_period_seconds - self.gcp_hook = CloudComposerAsyncHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) - def serialize(self) -> tuple[str, dict[str, Any]]: return ( "airflow.providers.google.cloud.triggers.cloud_composer.CloudComposerExecutionTrigger", @@ -70,7 +65,14 @@ class CloudComposerExecutionTrigger(BaseTrigger): }, ) + def _get_async_hook(self) -> CloudComposerAsyncHook: + return CloudComposerAsyncHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + async def run(self): + self.gcp_hook = self._get_async_hook() while True: operation = await self.gcp_hook.get_operation(operation_name=self.operation_name) if operation.done: @@ -108,11 +110,6 @@ class CloudComposerAirflowCLICommandTrigger(BaseTrigger): self.impersonation_chain = impersonation_chain self.poll_interval = poll_interval - self.gcp_hook = CloudComposerAsyncHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) - def serialize(self) -> tuple[str, dict[str, Any]]: return ( "airflow.providers.google.cloud.triggers.cloud_composer.CloudComposerAirflowCLICommandTrigger", @@ -127,7 +124,14 @@ class CloudComposerAirflowCLICommandTrigger(BaseTrigger): }, ) + def _get_async_hook(self) -> CloudComposerAsyncHook: + return CloudComposerAsyncHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + async def run(self): + self.gcp_hook = self._get_async_hook() try: result = await self.gcp_hook.wait_command_execution_result( project_id=self.project_id, @@ -199,11 +203,6 @@ class CloudComposerDAGRunTrigger(BaseTrigger): self.poll_interval = poll_interval self.composer_airflow_version = composer_airflow_version - self.gcp_hook = CloudComposerAsyncHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) - def serialize(self) -> tuple[str, dict[str, Any]]: return ( "airflow.providers.google.cloud.triggers.cloud_composer.CloudComposerDAGRunTrigger", @@ -264,6 +263,12 @@ class CloudComposerDAGRunTrigger(BaseTrigger): return False return True + def _get_async_hook(self) -> CloudComposerAsyncHook: + return CloudComposerAsyncHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + def _check_composer_dag_run_id_states(self, dag_runs: list[dict]) -> bool: for dag_run in dag_runs: if dag_run["run_id"] == self.composer_dag_run_id and dag_run["state"] in self.allowed_states: @@ -271,6 +276,7 @@ class CloudComposerDAGRunTrigger(BaseTrigger): return False async def run(self): + self.gcp_hook: CloudComposerAsyncHook = self._get_async_hook() try: while True: if datetime.now(self.end_date.tzinfo).timestamp() > self.end_date.timestamp(): diff --git a/providers/google/tests/unit/google/cloud/hooks/test_cloud_composer.py b/providers/google/tests/unit/google/cloud/hooks/test_cloud_composer.py index 4a371794793..cd62056e36e 100644 --- a/providers/google/tests/unit/google/cloud/hooks/test_cloud_composer.py +++ b/providers/google/tests/unit/google/cloud/hooks/test_cloud_composer.py @@ -286,7 +286,7 @@ class TestCloudComposerHook: class TestCloudComposerAsyncHook: def setup_method(self, method): - with mock.patch(BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_init): + with mock.patch(BASE_STRING.format("GoogleBaseAsyncHook.__init__"), new=mock_init): self.hook = CloudComposerAsyncHook(gcp_conn_id="test") @pytest.mark.asyncio