This is an automated email from the ASF dual-hosted git repository.

shahar pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 68f05b3b2cf Add transport parameter to CloudRunHook and CloudRunExecuteJobOperator (#60394)
68f05b3b2cf is described below

commit 68f05b3b2cf68366bd955a47912b3de98835b1da
Author: Arjav Patel <[email protected]>
AuthorDate: Tue Jan 27 00:31:29 2026 +0530

    Add transport parameter to CloudRunHook and CloudRunExecuteJobOperator (#60394)
---
 .../providers/google/cloud/hooks/cloud_run.py      | 30 +++++++++++++++++++---
 .../providers/google/cloud/operators/cloud_run.py  | 19 +++++++++++---
 .../providers/google/cloud/triggers/cloud_run.py   |  9 ++++++-
 .../unit/google/cloud/hooks/test_cloud_run.py      | 16 ++++++++++++
 .../unit/google/cloud/operators/test_cloud_run.py  | 22 ++++++++++++++++
 .../unit/google/cloud/triggers/test_cloud_run.py   |  2 ++
 6 files changed, 90 insertions(+), 8 deletions(-)

diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_run.py b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_run.py
index ca67373ec64..f2f3ba417b1 100644
--- a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_run.py
+++ b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_run.py
@@ -19,7 +19,7 @@ from __future__ import annotations
 
 import itertools
 from collections.abc import Iterable, Sequence
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Literal
 
 from google.cloud.run_v2 import (
     CreateJobRequest,
@@ -67,16 +67,21 @@ class CloudRunHook(GoogleBaseHook):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account.
+    :param transport: Optional. The transport to use for API requests. Can be 'rest' or 'grpc'.
+        If set to None, a transport is chosen automatically. Use 'rest' if gRPC is not available
+        or fails in your environment (e.g., Docker containers with certain network configurations).
     """
 
     def __init__(
         self,
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
+        transport: Literal["rest", "grpc"] | None = None,
         **kwargs,
     ) -> None:
         super().__init__(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain, **kwargs)
         self._client: JobsClient | None = None
+        self.transport = transport
 
     def get_conn(self):
         """
@@ -85,7 +90,12 @@ class CloudRunHook(GoogleBaseHook):
         :return: Cloud Run Jobs client object.
         """
         if self._client is None:
-            self._client = JobsClient(credentials=self.get_credentials(), client_info=CLIENT_INFO)
+            client_kwargs = {
+                "credentials": self.get_credentials(),
+                "client_info": CLIENT_INFO,
+                "transport": self.transport,
+            }
+            self._client = JobsClient(**client_kwargs)
         return self._client
 
     @GoogleBaseHook.fallback_to_default_project_id
@@ -176,6 +186,9 @@ class CloudRunAsyncHook(GoogleBaseAsyncHook):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account.
+    :param transport: Optional. The transport to use for API requests. Can be 'rest' or 'grpc'.
+        If set to None, a transport is chosen automatically. Use 'rest' if gRPC is not available
+        or fails in your environment (e.g., Docker containers with certain network configurations).
     """
 
     sync_hook_class = CloudRunHook
@@ -184,15 +197,24 @@ class CloudRunAsyncHook(GoogleBaseAsyncHook):
         self,
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
+        transport: Literal["rest", "grpc"] | None = None,
         **kwargs,
     ):
         self._client: JobsAsyncClient | None = None
-        super().__init__(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain, **kwargs)
+        self.transport = transport
+        super().__init__(
+            gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain, transport=transport, **kwargs
+        )
 
     async def get_conn(self):
         if self._client is None:
             sync_hook = await self.get_sync_hook()
-            self._client = JobsAsyncClient(credentials=sync_hook.get_credentials(), client_info=CLIENT_INFO)
+            client_kwargs = {
+                "credentials": sync_hook.get_credentials(),
+                "client_info": CLIENT_INFO,
+                "transport": self.transport,
+            }
+            self._client = JobsAsyncClient(**client_kwargs)
 
         return self._client
 
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/cloud_run.py b/providers/google/src/airflow/providers/google/cloud/operators/cloud_run.py
index 7e7b5faf1b4..5c12dd4d7da 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/cloud_run.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/cloud_run.py
@@ -18,7 +18,7 @@
 from __future__ import annotations
 
 from collections.abc import Sequence
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Literal
 
 import google.cloud.exceptions
 from google.api_core.exceptions import AlreadyExists
@@ -263,6 +263,9 @@ class CloudRunExecuteJobOperator(GoogleCloudBaseOperator):
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
     :param deferrable: Run the operator in deferrable mode.
+    :param transport: Optional. The transport to use for API requests. Can be 'rest' or 'grpc'.
+        If set to None, a transport is chosen automatically. Use 'rest' if gRPC is not available
+        or fails in your environment (e.g., Docker containers with certain network configurations).
     """
 
     operator_extra_links = (CloudRunJobLoggingLink(),)
@@ -275,6 +278,7 @@ class CloudRunExecuteJobOperator(GoogleCloudBaseOperator):
         "overrides",
         "polling_period_seconds",
         "timeout_seconds",
+        "transport",
     )
 
     def __init__(
@@ -288,6 +292,7 @@ class CloudRunExecuteJobOperator(GoogleCloudBaseOperator):
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        transport: Literal["rest", "grpc"] | None = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -300,11 +305,14 @@ class CloudRunExecuteJobOperator(GoogleCloudBaseOperator):
         self.polling_period_seconds = polling_period_seconds
         self.timeout_seconds = timeout_seconds
         self.deferrable = deferrable
+        self.transport = transport
         self.operation: operation.Operation | None = None
 
     def execute(self, context: Context):
         hook: CloudRunHook = CloudRunHook(
-            gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+            transport=self.transport,
         )
         self.operation = hook.execute_job(
             region=self.region, project_id=self.project_id, job_name=self.job_name, overrides=self.overrides
@@ -333,6 +341,7 @@ class CloudRunExecuteJobOperator(GoogleCloudBaseOperator):
                 gcp_conn_id=self.gcp_conn_id,
                 impersonation_chain=self.impersonation_chain,
                 polling_period_seconds=self.polling_period_seconds,
+                transport=self.transport,
             ),
             method_name="execute_complete",
         )
@@ -350,7 +359,11 @@ class CloudRunExecuteJobOperator(GoogleCloudBaseOperator):
                 f"Operation failed with error code [{error_code}] and error message [{error_message}]"
             )
 
-        hook: CloudRunHook = CloudRunHook(self.gcp_conn_id, self.impersonation_chain)
+        hook: CloudRunHook = CloudRunHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+            transport=self.transport,
+        )
 
         job = hook.get_job(job_name=event["job_name"], region=self.region, project_id=self.project_id)
         return Job.to_dict(job)
diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/cloud_run.py b/providers/google/src/airflow/providers/google/cloud/triggers/cloud_run.py
index b5547f45bac..8261edd416a 100644
--- a/providers/google/src/airflow/providers/google/cloud/triggers/cloud_run.py
+++ b/providers/google/src/airflow/providers/google/cloud/triggers/cloud_run.py
@@ -19,7 +19,7 @@ from __future__ import annotations
 import asyncio
 from collections.abc import AsyncIterator, Sequence
 from enum import Enum
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Literal
 
 from airflow.providers.common.compat.sdk import AirflowException
 from airflow.providers.google.cloud.hooks.cloud_run import CloudRunAsyncHook
@@ -59,6 +59,9 @@ class CloudRunJobFinishedTrigger(BaseTrigger):
         account from the list granting this role to the originating account (templated).
     :param poll_sleep: Polling period in seconds to check for the status.
     :timeout: The time to wait before failing the operation.
+    :param transport: Optional. The transport to use for API requests. Can be 'rest' or 'grpc'.
+        Defaults to 'grpc'. Use 'rest' if gRPC is not available or fails in your environment
+        (e.g., Docker containers with certain network configurations).
     """
 
     def __init__(
@@ -71,6 +74,7 @@ class CloudRunJobFinishedTrigger(BaseTrigger):
         impersonation_chain: str | Sequence[str] | None = None,
         polling_period_seconds: float = 10,
         timeout: float | None = None,
+        transport: Literal["rest", "grpc"] | None = None,
     ):
         super().__init__()
         self.project_id = project_id
@@ -81,6 +85,7 @@ class CloudRunJobFinishedTrigger(BaseTrigger):
         self.polling_period_seconds = polling_period_seconds
         self.timeout = timeout
         self.impersonation_chain = impersonation_chain
+        self.transport = transport
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
         """Serialize class arguments and classpath."""
@@ -95,6 +100,7 @@ class CloudRunJobFinishedTrigger(BaseTrigger):
                 "polling_period_seconds": self.polling_period_seconds,
                 "timeout": self.timeout,
                 "impersonation_chain": self.impersonation_chain,
+                "transport": self.transport,
             },
         )
 
@@ -143,4 +149,5 @@ class CloudRunJobFinishedTrigger(BaseTrigger):
         return CloudRunAsyncHook(
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
+            transport=self.transport or "grpc",
         )
diff --git a/providers/google/tests/unit/google/cloud/hooks/test_cloud_run.py b/providers/google/tests/unit/google/cloud/hooks/test_cloud_run.py
index bccea8c3e34..4a8150459c4 100644
--- a/providers/google/tests/unit/google/cloud/hooks/test_cloud_run.py
+++ b/providers/google/tests/unit/google/cloud/hooks/test_cloud_run.py
@@ -259,6 +259,22 @@ class TestCloudRunHook:
         cloud_run_hook.delete_job(job_name=JOB_NAME, region=REGION, project_id=PROJECT_ID)
         cloud_run_hook._client.delete_job.assert_called_once_with(delete_request)
 
+    @mock.patch(
+        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+        new=mock_base_gcp_hook_default_project_id,
+    )
+    @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsClient")
+    @pytest.mark.parametrize(("transport", "expected_transport"), [("rest", "rest"), (None, None)])
+    def test_get_conn_with_transport(self, mock_jobs_client, transport, expected_transport):
+        """Test that transport parameter is passed to JobsClient."""
+        hook = CloudRunHook(transport=transport)
+        hook.get_credentials = self.dummy_get_credentials
+        hook.get_conn()
+
+        mock_jobs_client.assert_called_once()
+        call_kwargs = mock_jobs_client.call_args[1]
+        assert call_kwargs["transport"] == expected_transport
+
     def _mock_pager(self, number_of_jobs):
         mock_pager = []
         for i in range(number_of_jobs):
diff --git a/providers/google/tests/unit/google/cloud/operators/test_cloud_run.py b/providers/google/tests/unit/google/cloud/operators/test_cloud_run.py
index d4431877121..3c389713b1a 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_cloud_run.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_cloud_run.py
@@ -102,6 +102,28 @@ class TestCloudRunExecuteJobOperator:
         assert "overrides" in operator.template_fields
         assert "polling_period_seconds" in operator.template_fields
         assert "timeout_seconds" in operator.template_fields
+        assert "transport" in operator.template_fields
+
+    @mock.patch(CLOUD_RUN_HOOK_PATH)
+    def test_execute_with_transport(self, hook_mock):
+        """Test that transport parameter is passed to CloudRunHook."""
+        hook_mock.return_value.get_job.return_value = JOB
+        hook_mock.return_value.execute_job.return_value = self._mock_operation(3, 3, 0)
+
+        operator = CloudRunExecuteJobOperator(
+            task_id=TASK_ID,
+            project_id=PROJECT_ID,
+            region=REGION,
+            job_name=JOB_NAME,
+            transport="rest",
+        )
+
+        operator.execute(context=mock.MagicMock())
+
+        # Verify that CloudRunHook was instantiated with transport parameter
+        hook_mock.assert_called_once()
+        call_kwargs = hook_mock.call_args[1]
+        assert call_kwargs["transport"] == "rest"
 
     @mock.patch(CLOUD_RUN_HOOK_PATH)
     def test_execute_success(self, hook_mock):
diff --git a/providers/google/tests/unit/google/cloud/triggers/test_cloud_run.py b/providers/google/tests/unit/google/cloud/triggers/test_cloud_run.py
index 7a526d590c2..3902a17885e 100644
--- a/providers/google/tests/unit/google/cloud/triggers/test_cloud_run.py
+++ b/providers/google/tests/unit/google/cloud/triggers/test_cloud_run.py
@@ -49,6 +49,7 @@ def trigger():
         polling_period_seconds=POLL_SLEEP,
         timeout=TIMEOUT,
         impersonation_chain=IMPERSONATION_CHAIN,
+        transport=None,
     )
 
 
@@ -65,6 +66,7 @@ class TestCloudBatchJobFinishedTrigger:
             "polling_period_seconds": POLL_SLEEP,
             "timeout": TIMEOUT,
             "impersonation_chain": IMPERSONATION_CHAIN,
+            "transport": None,
         }
 
     @pytest.mark.asyncio

Reply via email to