This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f1758fdd7d fix: respect connection ID and impersonation in GKEStartPodOperator (#36861)
f1758fdd7d is described below

commit f1758fdd7da8e933a701ab1a8df96c43288e8d0d
Author: Cedrik Neumann <[email protected]>
AuthorDate: Sat Jan 20 09:45:20 2024 +0100

    fix: respect connection ID and impersonation in GKEStartPodOperator (#36861)
    
    The GKEStartPodOperator accepts `gcp_conn_id` and `impersonation_chain`
    as parameters.
    
    This PR ensures that those values are passed to and supported by the corresponding
    hooks and triggers in deferrable and non-deferrable mode.
---
 .../google/cloud/hooks/kubernetes_engine.py        | 26 ++++++++++++++++++----
 .../google/cloud/operators/kubernetes_engine.py    |  3 +++
 .../google/cloud/triggers/kubernetes_engine.py     |  8 +++++++
 .../google/cloud/hooks/test_kubernetes_engine.py   |  9 +++++++-
 .../cloud/triggers/test_kubernetes_engine.py       |  4 ++++
 5 files changed, 45 insertions(+), 5 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/kubernetes_engine.py b/airflow/providers/google/cloud/hooks/kubernetes_engine.py
index 0e62b990da..30be61dc1f 100644
--- a/airflow/providers/google/cloud/hooks/kubernetes_engine.py
+++ b/airflow/providers/google/cloud/hooks/kubernetes_engine.py
@@ -352,10 +352,15 @@ class GKEPodHook(GoogleBaseHook, PodOperatorHookProtocol):
         self,
         cluster_url: str,
         ssl_ca_cert: str,
-        *args,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
         **kwargs,
     ):
-        super().__init__(*args, **kwargs)
+        super().__init__(
+            gcp_conn_id=gcp_conn_id,
+            impersonation_chain=impersonation_chain,
+            **kwargs,
+        )
         self._cluster_url = cluster_url
         self._ssl_ca_cert = ssl_ca_cert
 
@@ -440,10 +445,23 @@ class GKEPodAsyncHook(GoogleBaseAsyncHook):
     sync_hook_class = GKEPodHook
     scopes = ["https://www.googleapis.com/auth/cloud-platform"]
 
-    def __init__(self, cluster_url: str, ssl_ca_cert: str, **kwargs) -> None:
+    def __init__(
+        self,
+        cluster_url: str,
+        ssl_ca_cert: str,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ) -> None:
         self._cluster_url = cluster_url
         self._ssl_ca_cert = ssl_ca_cert
-        super().__init__(cluster_url=cluster_url, ssl_ca_cert=ssl_ca_cert, **kwargs)
+        super().__init__(
+            cluster_url=cluster_url,
+            ssl_ca_cert=ssl_ca_cert,
+            gcp_conn_id=gcp_conn_id,
+            impersonation_chain=impersonation_chain,
+            **kwargs,
+        )
 
     @contextlib.asynccontextmanager
     async def get_conn(self, token: Token) -> async_client.ApiClient:  # type: ignore[override]
diff --git a/airflow/providers/google/cloud/operators/kubernetes_engine.py b/airflow/providers/google/cloud/operators/kubernetes_engine.py
index e5ca3c271b..2d2bf7337d 100644
--- a/airflow/providers/google/cloud/operators/kubernetes_engine.py
+++ b/airflow/providers/google/cloud/operators/kubernetes_engine.py
@@ -538,6 +538,7 @@ class GKEStartPodOperator(KubernetesPodOperator):
             gcp_conn_id=self.gcp_conn_id,
             cluster_url=self._cluster_url,
             ssl_ca_cert=self._ssl_ca_cert,
+            impersonation_chain=self.impersonation_chain,
         )
         return hook
 
@@ -577,6 +578,8 @@ class GKEStartPodOperator(KubernetesPodOperator):
                 in_cluster=self.in_cluster,
                 base_container_name=self.base_container_name,
                 on_finish_action=self.on_finish_action,
+                gcp_conn_id=self.gcp_conn_id,
+                impersonation_chain=self.impersonation_chain,
             ),
             method_name="execute_complete",
             kwargs={"cluster_url": self._cluster_url, "ssl_ca_cert": self._ssl_ca_cert},
diff --git a/airflow/providers/google/cloud/triggers/kubernetes_engine.py b/airflow/providers/google/cloud/triggers/kubernetes_engine.py
index da068dcfc3..0167ce5a6a 100644
--- a/airflow/providers/google/cloud/triggers/kubernetes_engine.py
+++ b/airflow/providers/google/cloud/triggers/kubernetes_engine.py
@@ -76,6 +76,8 @@ class GKEStartPodTrigger(KubernetesPodTrigger):
         startup_timeout: int = 120,
         on_finish_action: str = "delete_pod",
         should_delete_pod: bool | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
         *args,
         **kwargs,
     ):
@@ -96,6 +98,8 @@ class GKEStartPodTrigger(KubernetesPodTrigger):
         self.in_cluster = in_cluster
         self.get_logs = get_logs
         self.startup_timeout = startup_timeout
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
 
         if should_delete_pod is not None:
             warnings.warn(
@@ -131,6 +135,8 @@ class GKEStartPodTrigger(KubernetesPodTrigger):
                 "base_container_name": self.base_container_name,
                 "should_delete_pod": self.should_delete_pod,
                 "on_finish_action": self.on_finish_action.value,
+                "gcp_conn_id": self.gcp_conn_id,
+                "impersonation_chain": self.impersonation_chain,
             },
         )
 
@@ -139,6 +145,8 @@ class GKEStartPodTrigger(KubernetesPodTrigger):
         return GKEPodAsyncHook(
             cluster_url=self._cluster_url,
             ssl_ca_cert=self._ssl_ca_cert,
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
         )
 
 
diff --git a/tests/providers/google/cloud/hooks/test_kubernetes_engine.py b/tests/providers/google/cloud/hooks/test_kubernetes_engine.py
index c226c0b98e..69189a8ae9 100644
--- a/tests/providers/google/cloud/hooks/test_kubernetes_engine.py
+++ b/tests/providers/google/cloud/hooks/test_kubernetes_engine.py
@@ -312,6 +312,8 @@ class TestGKEPodAsyncHook:
         return GKEPodAsyncHook(
             cluster_url=CLUSTER_URL,
             ssl_ca_cert=SSL_CA_CERT,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATE_CHAIN,
         )
 
     @pytest.mark.asyncio
@@ -405,7 +407,12 @@ class TestGKEPodHook:
         with mock.patch(
             BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_base_gcp_hook_default_project_id
         ):
-            self.gke_hook = GKEPodHook(gcp_conn_id="test", ssl_ca_cert=None, cluster_url=None)
+            self.gke_hook = GKEPodHook(
+                gcp_conn_id="test",
+                impersonation_chain=IMPERSONATE_CHAIN,
+                ssl_ca_cert=None,
+                cluster_url=None,
+            )
         self.gke_hook._client = mock.Mock()
 
         def refresh_token(request):
diff --git a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py
index 65bc45c415..ec31b2bcc4 100644
--- a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py
+++ b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py
@@ -73,6 +73,8 @@ def trigger():
         cluster_url=CLUSTER_URL,
         ssl_ca_cert=SSL_CA_CERT,
         base_container_name=BASE_CONTAINER_NAME,
+        gcp_conn_id=GCP_CONN_ID,
+        impersonation_chain=IMPERSONATION_CHAIN,
     )
 
 
@@ -101,6 +103,8 @@ class TestGKEStartPodTrigger:
             "base_container_name": BASE_CONTAINER_NAME,
             "on_finish_action": ON_FINISH_ACTION,
             "should_delete_pod": SHOULD_DELETE_POD,
+            "gcp_conn_id": GCP_CONN_ID,
+            "impersonation_chain": IMPERSONATION_CHAIN,
         }
 
     @pytest.mark.asyncio

Reply via email to