This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f1758fdd7d fix: respect connection ID and impersonation in GKEStartPodOperator (#36861)
f1758fdd7d is described below
commit f1758fdd7da8e933a701ab1a8df96c43288e8d0d
Author: Cedrik Neumann <[email protected]>
AuthorDate: Sat Jan 20 09:45:20 2024 +0100
fix: respect connection ID and impersonation in GKEStartPodOperator (#36861)
The GKEStartPodOperator accepts `gcp_conn_id` and `impersonation_chain`
as parameters.
This PR ensures that those values are passed to, and supported by, the
corresponding hooks and triggers in both deferrable and non-deferrable mode.
---
.../google/cloud/hooks/kubernetes_engine.py | 26 ++++++++++++++++++----
.../google/cloud/operators/kubernetes_engine.py | 3 +++
.../google/cloud/triggers/kubernetes_engine.py | 8 +++++++
.../google/cloud/hooks/test_kubernetes_engine.py | 9 +++++++-
.../cloud/triggers/test_kubernetes_engine.py | 4 ++++
5 files changed, 45 insertions(+), 5 deletions(-)
diff --git a/airflow/providers/google/cloud/hooks/kubernetes_engine.py
b/airflow/providers/google/cloud/hooks/kubernetes_engine.py
index 0e62b990da..30be61dc1f 100644
--- a/airflow/providers/google/cloud/hooks/kubernetes_engine.py
+++ b/airflow/providers/google/cloud/hooks/kubernetes_engine.py
@@ -352,10 +352,15 @@ class GKEPodHook(GoogleBaseHook, PodOperatorHookProtocol):
self,
cluster_url: str,
ssl_ca_cert: str,
- *args,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
**kwargs,
):
- super().__init__(*args, **kwargs)
+ super().__init__(
+ gcp_conn_id=gcp_conn_id,
+ impersonation_chain=impersonation_chain,
+ **kwargs,
+ )
self._cluster_url = cluster_url
self._ssl_ca_cert = ssl_ca_cert
@@ -440,10 +445,23 @@ class GKEPodAsyncHook(GoogleBaseAsyncHook):
sync_hook_class = GKEPodHook
scopes = ["https://www.googleapis.com/auth/cloud-platform"]
- def __init__(self, cluster_url: str, ssl_ca_cert: str, **kwargs) -> None:
+ def __init__(
+ self,
+ cluster_url: str,
+ ssl_ca_cert: str,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
self._cluster_url = cluster_url
self._ssl_ca_cert = ssl_ca_cert
- super().__init__(cluster_url=cluster_url, ssl_ca_cert=ssl_ca_cert,
**kwargs)
+ super().__init__(
+ cluster_url=cluster_url,
+ ssl_ca_cert=ssl_ca_cert,
+ gcp_conn_id=gcp_conn_id,
+ impersonation_chain=impersonation_chain,
+ **kwargs,
+ )
@contextlib.asynccontextmanager
async def get_conn(self, token: Token) -> async_client.ApiClient: # type:
ignore[override]
diff --git a/airflow/providers/google/cloud/operators/kubernetes_engine.py
b/airflow/providers/google/cloud/operators/kubernetes_engine.py
index e5ca3c271b..2d2bf7337d 100644
--- a/airflow/providers/google/cloud/operators/kubernetes_engine.py
+++ b/airflow/providers/google/cloud/operators/kubernetes_engine.py
@@ -538,6 +538,7 @@ class GKEStartPodOperator(KubernetesPodOperator):
gcp_conn_id=self.gcp_conn_id,
cluster_url=self._cluster_url,
ssl_ca_cert=self._ssl_ca_cert,
+ impersonation_chain=self.impersonation_chain,
)
return hook
@@ -577,6 +578,8 @@ class GKEStartPodOperator(KubernetesPodOperator):
in_cluster=self.in_cluster,
base_container_name=self.base_container_name,
on_finish_action=self.on_finish_action,
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
),
method_name="execute_complete",
kwargs={"cluster_url": self._cluster_url, "ssl_ca_cert":
self._ssl_ca_cert},
diff --git a/airflow/providers/google/cloud/triggers/kubernetes_engine.py
b/airflow/providers/google/cloud/triggers/kubernetes_engine.py
index da068dcfc3..0167ce5a6a 100644
--- a/airflow/providers/google/cloud/triggers/kubernetes_engine.py
+++ b/airflow/providers/google/cloud/triggers/kubernetes_engine.py
@@ -76,6 +76,8 @@ class GKEStartPodTrigger(KubernetesPodTrigger):
startup_timeout: int = 120,
on_finish_action: str = "delete_pod",
should_delete_pod: bool | None = None,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
*args,
**kwargs,
):
@@ -96,6 +98,8 @@ class GKEStartPodTrigger(KubernetesPodTrigger):
self.in_cluster = in_cluster
self.get_logs = get_logs
self.startup_timeout = startup_timeout
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
if should_delete_pod is not None:
warnings.warn(
@@ -131,6 +135,8 @@ class GKEStartPodTrigger(KubernetesPodTrigger):
"base_container_name": self.base_container_name,
"should_delete_pod": self.should_delete_pod,
"on_finish_action": self.on_finish_action.value,
+ "gcp_conn_id": self.gcp_conn_id,
+ "impersonation_chain": self.impersonation_chain,
},
)
@@ -139,6 +145,8 @@ class GKEStartPodTrigger(KubernetesPodTrigger):
return GKEPodAsyncHook(
cluster_url=self._cluster_url,
ssl_ca_cert=self._ssl_ca_cert,
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
)
diff --git a/tests/providers/google/cloud/hooks/test_kubernetes_engine.py
b/tests/providers/google/cloud/hooks/test_kubernetes_engine.py
index c226c0b98e..69189a8ae9 100644
--- a/tests/providers/google/cloud/hooks/test_kubernetes_engine.py
+++ b/tests/providers/google/cloud/hooks/test_kubernetes_engine.py
@@ -312,6 +312,8 @@ class TestGKEPodAsyncHook:
return GKEPodAsyncHook(
cluster_url=CLUSTER_URL,
ssl_ca_cert=SSL_CA_CERT,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATE_CHAIN,
)
@pytest.mark.asyncio
@@ -405,7 +407,12 @@ class TestGKEPodHook:
with mock.patch(
BASE_STRING.format("GoogleBaseHook.__init__"),
new=mock_base_gcp_hook_default_project_id
):
- self.gke_hook = GKEPodHook(gcp_conn_id="test", ssl_ca_cert=None,
cluster_url=None)
+ self.gke_hook = GKEPodHook(
+ gcp_conn_id="test",
+ impersonation_chain=IMPERSONATE_CHAIN,
+ ssl_ca_cert=None,
+ cluster_url=None,
+ )
self.gke_hook._client = mock.Mock()
def refresh_token(request):
diff --git a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py
b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py
index 65bc45c415..ec31b2bcc4 100644
--- a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py
+++ b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py
@@ -73,6 +73,8 @@ def trigger():
cluster_url=CLUSTER_URL,
ssl_ca_cert=SSL_CA_CERT,
base_container_name=BASE_CONTAINER_NAME,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
)
@@ -101,6 +103,8 @@ class TestGKEStartPodTrigger:
"base_container_name": BASE_CONTAINER_NAME,
"on_finish_action": ON_FINISH_ACTION,
"should_delete_pod": SHOULD_DELETE_POD,
+ "gcp_conn_id": GCP_CONN_ID,
+ "impersonation_chain": IMPERSONATION_CHAIN,
}
@pytest.mark.asyncio