This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new afb686c95e Implement deferrable mode for GKEStartJobOperator (#38454)
afb686c95e is described below
commit afb686c95ef276ac8d9d473b74303fd1551d00fd
Author: max <[email protected]>
AuthorDate: Tue Mar 26 21:28:22 2024 +0100
Implement deferrable mode for GKEStartJobOperator (#38454)
* Implement deferrable mode for GKEStartJobOperator
* Specify trigger return type
Co-authored-by: Wei Lee <[email protected]>
* Fix f-string
Co-authored-by: Wei Lee <[email protected]>
* Refactor trigger event yielding
Co-authored-by: Wei Lee <[email protected]>
* Fix typo
---------
Co-authored-by: Wei Lee <[email protected]>
---
airflow/providers/cncf/kubernetes/triggers/job.py | 2 +-
.../google/cloud/operators/kubernetes_engine.py | 41 +++++-
.../google/cloud/triggers/kubernetes_engine.py | 72 ++++++++++-
.../operators/cloud/kubernetes_engine.rst | 9 ++
.../cloud/operators/test_kubernetes_engine.py | 76 +++++++++++
.../cloud/triggers/test_kubernetes_engine.py | 144 ++++++++++++++++++---
.../example_kubernetes_engine_job.py | 44 ++++++-
7 files changed, 365 insertions(+), 23 deletions(-)
diff --git a/airflow/providers/cncf/kubernetes/triggers/job.py
b/airflow/providers/cncf/kubernetes/triggers/job.py
index 94f4667691..f229017df1 100644
--- a/airflow/providers/cncf/kubernetes/triggers/job.py
+++ b/airflow/providers/cncf/kubernetes/triggers/job.py
@@ -45,7 +45,7 @@ class KubernetesJobTrigger(BaseTrigger):
job_name: str,
job_namespace: str,
kubernetes_conn_id: str | None = None,
- poll_interval: float = 2,
+ poll_interval: float = 10.0,
cluster_context: str | None = None,
config_file: str | None = None,
in_cluster: bool | None = None,
diff --git a/airflow/providers/google/cloud/operators/kubernetes_engine.py
b/airflow/providers/google/cloud/operators/kubernetes_engine.py
index 438ed7e8ac..8729335c5a 100644
--- a/airflow/providers/google/cloud/operators/kubernetes_engine.py
+++ b/airflow/providers/google/cloud/operators/kubernetes_engine.py
@@ -31,6 +31,7 @@ from google.api_core.exceptions import AlreadyExists
from google.cloud.container_v1.types import Cluster
from kubernetes.client import V1JobList
from kubernetes.utils.create_from_yaml import FailToCreateError
+from packaging.version import parse as parse_version
from airflow.configuration import conf
from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning
@@ -55,7 +56,12 @@ from airflow.providers.google.cloud.links.kubernetes_engine
import (
KubernetesEngineWorkloadsLink,
)
from airflow.providers.google.cloud.operators.cloud_base import
GoogleCloudBaseOperator
-from airflow.providers.google.cloud.triggers.kubernetes_engine import
GKEOperationTrigger, GKEStartPodTrigger
+from airflow.providers.google.cloud.triggers.kubernetes_engine import (
+ GKEJobTrigger,
+ GKEOperationTrigger,
+ GKEStartPodTrigger,
+)
+from airflow.providers_manager import ProvidersManager
from airflow.utils.timezone import utcnow
if TYPE_CHECKING:
@@ -834,6 +840,9 @@ class GKEStartJobOperator(KubernetesJobOperator):
Service Account Token Creator IAM role to the directly preceding
identity, with first
account from the list granting this role to the originating account
(templated).
:param location: The location param is region name.
+ :param deferrable: Run operator in the deferrable mode.
+ :param poll_interval: (Deferrable mode only) polling period in seconds to
+ check for the status of job.
"""
template_fields: Sequence[str] = tuple(
@@ -850,6 +859,8 @@ class GKEStartJobOperator(KubernetesJobOperator):
project_id: str | None = None,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
+ job_poll_interval: float = 10.0,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -859,6 +870,8 @@ class GKEStartJobOperator(KubernetesJobOperator):
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.use_internal_ip = use_internal_ip
+ self.deferrable = deferrable
+ self.job_poll_interval = job_poll_interval
self.job: V1Job | None = None
self._ssl_ca_cert: str | None = None
@@ -900,6 +913,18 @@ class GKEStartJobOperator(KubernetesJobOperator):
def execute(self, context: Context):
"""Execute process of creating Job."""
+ if self.deferrable:
+ kubernetes_provider =
ProvidersManager().providers["apache-airflow-providers-cncf-kubernetes"]
+ kubernetes_provider_name = kubernetes_provider.data["package-name"]
+ kubernetes_provider_version = kubernetes_provider.version
+ min_version = "8.0.1"
+ if parse_version(kubernetes_provider_version) <=
parse_version(min_version):
+ raise AirflowException(
+ "You are trying to use `GKEStartJobOperator` in deferrable
mode with the provider "
+ f"package
{kubernetes_provider_name}=={kubernetes_provider_version} which doesn't "
+ f"support this feature. Please upgrade it to version
higher than {min_version}."
+ )
+
self._cluster_url, self._ssl_ca_cert = GKEClusterAuthDetails(
cluster_name=self.cluster_name,
project_id=self.project_id,
@@ -909,6 +934,20 @@ class GKEStartJobOperator(KubernetesJobOperator):
return super().execute(context)
+ def execute_deferrable(self):
+ self.defer(
+ trigger=GKEJobTrigger(
+ cluster_url=self._cluster_url,
+ ssl_ca_cert=self._ssl_ca_cert,
+ job_name=self.job.metadata.name, # type: ignore[union-attr]
+ job_namespace=self.job.metadata.namespace, # type:
ignore[union-attr]
+ gcp_conn_id=self.gcp_conn_id,
+ poll_interval=self.job_poll_interval,
+ impersonation_chain=self.impersonation_chain,
+ ),
+ method_name="execute_complete",
+ )
+
class GKEDescribeJobOperator(GoogleCloudBaseOperator):
"""
diff --git a/airflow/providers/google/cloud/triggers/kubernetes_engine.py
b/airflow/providers/google/cloud/triggers/kubernetes_engine.py
index 85776b8a20..c0d8fef97a 100644
--- a/airflow/providers/google/cloud/triggers/kubernetes_engine.py
+++ b/airflow/providers/google/cloud/triggers/kubernetes_engine.py
@@ -27,12 +27,18 @@ from google.cloud.container_v1.types import Operation
from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.cncf.kubernetes.triggers.pod import KubernetesPodTrigger
from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction
-from airflow.providers.google.cloud.hooks.kubernetes_engine import
GKEAsyncHook, GKEPodAsyncHook
+from airflow.providers.google.cloud.hooks.kubernetes_engine import (
+ GKEAsyncHook,
+ GKEKubernetesAsyncHook,
+ GKEPodAsyncHook,
+)
from airflow.triggers.base import BaseTrigger, TriggerEvent
if TYPE_CHECKING:
from datetime import datetime
+ from kubernetes_asyncio.client import V1Job
+
class GKEStartPodTrigger(KubernetesPodTrigger):
"""
@@ -237,3 +243,67 @@ class GKEOperationTrigger(BaseTrigger):
impersonation_chain=self.impersonation_chain,
)
return self._hook
+
+
+class GKEJobTrigger(BaseTrigger):
+ """GKEJobTrigger run on the trigger worker to check the state of Job."""
+
+ def __init__(
+ self,
+ cluster_url: str,
+ ssl_ca_cert: str,
+ job_name: str,
+ job_namespace: str,
+ gcp_conn_id: str = "google_cloud_default",
+ poll_interval: float = 2,
+ impersonation_chain: str | Sequence[str] | None = None,
+ ) -> None:
+ super().__init__()
+ self.cluster_url = cluster_url
+ self.ssl_ca_cert = ssl_ca_cert
+ self.job_name = job_name
+ self.job_namespace = job_namespace
+ self.gcp_conn_id = gcp_conn_id
+ self.poll_interval = poll_interval
+ self.impersonation_chain = impersonation_chain
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ """Serialize KubernetesCreateJobTrigger arguments and classpath."""
+ return (
+
"airflow.providers.google.cloud.triggers.kubernetes_engine.GKEJobTrigger",
+ {
+ "cluster_url": self.cluster_url,
+ "ssl_ca_cert": self.ssl_ca_cert,
+ "job_name": self.job_name,
+ "job_namespace": self.job_namespace,
+ "gcp_conn_id": self.gcp_conn_id,
+ "poll_interval": self.poll_interval,
+ "impersonation_chain": self.impersonation_chain,
+ },
+ )
+
+ async def run(self) -> AsyncIterator[TriggerEvent]: # type:
ignore[override]
+ """Get current job status and yield a TriggerEvent."""
+ job: V1Job = await
self.hook.wait_until_job_complete(name=self.job_name,
namespace=self.job_namespace)
+ job_dict = job.to_dict()
+ error_message = self.hook.is_job_failed(job=job)
+ status = "error" if error_message else "success"
+ message = f"Job failed with error: {error_message}" if error_message
else "Job completed successfully"
+ yield TriggerEvent(
+ {
+ "name": job.metadata.name,
+ "namespace": job.metadata.namespace,
+ "status": status,
+ "message": message,
+ "job": job_dict,
+ }
+ )
+
+ @cached_property
+ def hook(self) -> GKEKubernetesAsyncHook:
+ return GKEKubernetesAsyncHook(
+ cluster_url=self.cluster_url,
+ ssl_ca_cert=self.ssl_ca_cert,
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
diff --git
a/docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst
b/docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst
index 53681ddcee..7d5cc44826 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst
@@ -213,6 +213,15 @@ All Kubernetes parameters (except ``config_file``) are
also valid for the ``GKES
:start-after: [START howto_operator_gke_start_job]
:end-before: [END howto_operator_gke_start_job]
+``GKEStartJobOperator`` also supports deferrable mode. Note that it makes
sense only if the ``wait_until_job_complete``
+parameter is set to ``True``.
+
+.. exampleinclude::
/../../tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_job.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_gke_start_job_def]
+ :end-before: [END howto_operator_gke_start_job_def]
+
To run a Job on a GKE cluster with Kueue enabled, use
``GKEStartKueueJobOperator``.
.. exampleinclude::
/../../tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_kueue.py
diff --git a/tests/providers/google/cloud/operators/test_kubernetes_engine.py
b/tests/providers/google/cloud/operators/test_kubernetes_engine.py
index 62ccb238fb..d752c3780d 100644
--- a/tests/providers/google/cloud/operators/test_kubernetes_engine.py
+++ b/tests/providers/google/cloud/operators/test_kubernetes_engine.py
@@ -71,6 +71,7 @@ TASK_NAME = "test-task-name"
JOB_NAME = "test-job"
NAMESPACE = ("default",)
IMAGE = "bash"
+JOB_POLL_INTERVAL = 20.0
GCLOUD_COMMAND = "gcloud container clusters get-credentials {} --zone {}
--project {}"
KUBE_ENV_VAR = "KUBECONFIG"
@@ -708,6 +709,81 @@ class TestGKEStartJobOperator:
self.gke_op.execute(context=mock.MagicMock())
fetch_cluster_info_mock.assert_called_once()
+ @mock.patch(KUB_JOB_OPERATOR_EXEC)
+ @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
+ @mock.patch(GKE_HOOK_PATH)
+ @mock.patch(f"{GKE_HOOK_MODULE_PATH}.ProvidersManager")
+ def test_execute_in_deferrable_mode(
+ self, mock_providers_manager, mock_hook, fetch_cluster_info_mock,
exec_mock
+ ):
+ kubernetes_package_name = "apache-airflow-providers-cncf-kubernetes"
+ mock_providers_manager.return_value.providers = {
+ kubernetes_package_name: mock.MagicMock(
+ data={
+ "package-name": kubernetes_package_name,
+ },
+ version="8.0.2",
+ )
+ }
+ self.gke_op.deferrable = True
+
+ fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
+ self.gke_op.execute(context=mock.MagicMock())
+ fetch_cluster_info_mock.assert_called_once()
+
+ @mock.patch(f"{GKE_HOOK_MODULE_PATH}.ProvidersManager")
+ def test_execute_in_deferrable_mode_exception(self,
mock_providers_manager):
+ kubernetes_package_name = "apache-airflow-providers-cncf-kubernetes"
+ mock_providers_manager.return_value.providers = {
+ kubernetes_package_name: mock.MagicMock(
+ data={
+ "package-name": kubernetes_package_name,
+ },
+ version="8.0.1",
+ )
+ }
+ self.gke_op.deferrable = True
+ with pytest.raises(AirflowException):
+ self.gke_op.execute({})
+
+ @mock.patch(f"{GKE_HOOK_MODULE_PATH}.GKEJobTrigger")
+ def test_execute_deferrable(self, mock_trigger):
+ mock_trigger_instance = mock_trigger.return_value
+
+ op = GKEStartJobOperator(
+ project_id=TEST_GCP_PROJECT_ID,
+ location=PROJECT_LOCATION,
+ cluster_name=CLUSTER_NAME,
+ task_id=PROJECT_TASK_ID,
+ name=TASK_NAME,
+ namespace=NAMESPACE,
+ image=IMAGE,
+ job_poll_interval=JOB_POLL_INTERVAL,
+ )
+ op._ssl_ca_cert = SSL_CA_CERT
+ op._cluster_url = CLUSTER_URL
+
+ with mock.patch.object(op, "job") as mock_job:
+ mock_metadata = mock_job.metadata
+ mock_metadata.name = TASK_NAME
+ mock_metadata.namespace = NAMESPACE
+ with mock.patch.object(op, "defer") as mock_defer:
+ op.execute_deferrable()
+
+ mock_trigger.assert_called_once_with(
+ cluster_url=CLUSTER_URL,
+ ssl_ca_cert=SSL_CA_CERT,
+ job_name=TASK_NAME,
+ job_namespace=NAMESPACE,
+ gcp_conn_id="google_cloud_default",
+ poll_interval=JOB_POLL_INTERVAL,
+ impersonation_chain=None,
+ )
+ mock_defer.assert_called_once_with(
+ trigger=mock_trigger_instance,
+ method_name="execute_complete",
+ )
+
def test_config_file_throws_error(self):
with pytest.raises(AirflowException):
GKEStartJobOperator(
diff --git a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py
b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py
index c25251886b..8a18dfcb90 100644
--- a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py
+++ b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py
@@ -28,13 +28,20 @@ from google.cloud.container_v1.types import Operation
from kubernetes.client import models as k8s
from airflow.providers.cncf.kubernetes.triggers.kubernetes_pod import
ContainerState
-from airflow.providers.google.cloud.triggers.kubernetes_engine import
GKEOperationTrigger, GKEStartPodTrigger
+from airflow.providers.google.cloud.triggers.kubernetes_engine import (
+ GKEJobTrigger,
+ GKEOperationTrigger,
+ GKEStartPodTrigger,
+)
from airflow.triggers.base import TriggerEvent
-TRIGGER_GKE_PATH =
"airflow.providers.google.cloud.triggers.kubernetes_engine.GKEStartPodTrigger"
-TRIGGER_KUB_PATH =
"airflow.providers.cncf.kubernetes.triggers.pod.KubernetesPodTrigger"
+GKE_TRIGGERS_PATH = "airflow.providers.google.cloud.triggers.kubernetes_engine"
+TRIGGER_GKE_POD_PATH = GKE_TRIGGERS_PATH + ".GKEStartPodTrigger"
+TRIGGER_GKE_JOB_PATH = GKE_TRIGGERS_PATH + ".GKEJobTrigger"
+TRIGGER_KUB_POD_PATH =
"airflow.providers.cncf.kubernetes.triggers.pod.KubernetesPodTrigger"
HOOK_PATH =
"airflow.providers.google.cloud.hooks.kubernetes_engine.GKEPodAsyncHook"
POD_NAME = "test-pod-name"
+JOB_NAME = "test-job-name"
NAMESPACE = "default"
POLL_INTERVAL = 2
CLUSTER_CONTEXT = "test-context"
@@ -54,7 +61,7 @@ PROJECT_ID = "test-project-id"
LOCATION = "us-central1-c"
GCP_CONN_ID = "test-non-existing-project-id"
IMPERSONATION_CHAIN = ["impersonate", "this", "test"]
-TRIGGER_PATH =
"airflow.providers.google.cloud.triggers.kubernetes_engine.GKEOperationTrigger"
+TRIGGER_PATH = f"{GKE_TRIGGERS_PATH}.GKEOperationTrigger"
EXC_MSG = "test error msg"
@@ -78,6 +85,19 @@ def trigger():
)
[email protected]
+def job_trigger():
+ return GKEJobTrigger(
+ cluster_url=CLUSTER_URL,
+ ssl_ca_cert=SSL_CA_CERT,
+ job_name=JOB_NAME,
+ job_namespace=NAMESPACE,
+ gcp_conn_id=GCP_CONN_ID,
+ poll_interval=POLL_INTERVAL,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+
+
class TestGKEStartPodTrigger:
@staticmethod
def _mock_pod_result(result_to_mock):
@@ -88,7 +108,7 @@ class TestGKEStartPodTrigger:
def test_serialize_should_execute_successfully(self, trigger):
classpath, kwargs_dict = trigger.serialize()
- assert classpath == TRIGGER_GKE_PATH
+ assert classpath == TRIGGER_GKE_POD_PATH
assert kwargs_dict == {
"pod_name": POD_NAME,
"pod_namespace": NAMESPACE,
@@ -108,8 +128,8 @@ class TestGKEStartPodTrigger:
}
@pytest.mark.asyncio
- @mock.patch(f"{TRIGGER_KUB_PATH}._wait_for_pod_start")
- @mock.patch(f"{TRIGGER_GKE_PATH}.hook")
+ @mock.patch(f"{TRIGGER_KUB_POD_PATH}._wait_for_pod_start")
+ @mock.patch(f"{TRIGGER_GKE_POD_PATH}.hook")
async def test_run_loop_return_success_event_should_execute_successfully(
self, mock_hook, mock_wait_pod, trigger
):
@@ -129,8 +149,8 @@ class TestGKEStartPodTrigger:
assert actual_event == expected_event
@pytest.mark.asyncio
- @mock.patch(f"{TRIGGER_KUB_PATH}._wait_for_pod_start")
- @mock.patch(f"{TRIGGER_GKE_PATH}.hook")
+ @mock.patch(f"{TRIGGER_KUB_POD_PATH}._wait_for_pod_start")
+ @mock.patch(f"{TRIGGER_GKE_POD_PATH}.hook")
async def test_run_loop_return_failed_event_should_execute_successfully(
self, mock_hook, mock_wait_pod, trigger
):
@@ -156,9 +176,9 @@ class TestGKEStartPodTrigger:
assert actual_event == expected_event
@pytest.mark.asyncio
- @mock.patch(f"{TRIGGER_KUB_PATH}._wait_for_pod_start")
- @mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")
- @mock.patch(f"{TRIGGER_GKE_PATH}.hook")
+ @mock.patch(f"{TRIGGER_KUB_POD_PATH}._wait_for_pod_start")
+ @mock.patch(f"{TRIGGER_KUB_POD_PATH}.define_container_state")
+ @mock.patch(f"{TRIGGER_GKE_POD_PATH}.hook")
async def test_run_loop_return_waiting_event_should_execute_successfully(
self, mock_hook, mock_method, mock_wait_pod, trigger, caplog
):
@@ -175,9 +195,9 @@ class TestGKEStartPodTrigger:
assert f"Sleeping for {POLL_INTERVAL} seconds."
@pytest.mark.asyncio
- @mock.patch(f"{TRIGGER_KUB_PATH}._wait_for_pod_start")
- @mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")
- @mock.patch(f"{TRIGGER_GKE_PATH}.hook")
+ @mock.patch(f"{TRIGGER_KUB_POD_PATH}._wait_for_pod_start")
+ @mock.patch(f"{TRIGGER_KUB_POD_PATH}.define_container_state")
+ @mock.patch(f"{TRIGGER_GKE_POD_PATH}.hook")
async def test_run_loop_return_running_event_should_execute_successfully(
self, mock_hook, mock_method, mock_wait_pod, trigger, caplog
):
@@ -194,8 +214,8 @@ class TestGKEStartPodTrigger:
assert f"Sleeping for {POLL_INTERVAL} seconds."
@pytest.mark.asyncio
- @mock.patch(f"{TRIGGER_KUB_PATH}._wait_for_pod_start")
- @mock.patch(f"{TRIGGER_GKE_PATH}.hook")
+ @mock.patch(f"{TRIGGER_KUB_POD_PATH}._wait_for_pod_start")
+ @mock.patch(f"{TRIGGER_GKE_POD_PATH}.hook")
async def
test_logging_in_trigger_when_exception_should_execute_successfully(
self, mock_hook, mock_wait_pod, trigger, caplog
):
@@ -216,8 +236,8 @@ class TestGKEStartPodTrigger:
assert actual_stack_trace.startswith("Traceback (most recent call
last):")
@pytest.mark.asyncio
- @mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")
- @mock.patch(f"{TRIGGER_GKE_PATH}.hook")
+ @mock.patch(f"{TRIGGER_KUB_POD_PATH}.define_container_state")
+ @mock.patch(f"{TRIGGER_GKE_POD_PATH}.hook")
async def test_logging_in_trigger_when_fail_should_execute_successfully(
self, mock_hook, mock_method, trigger, caplog
):
@@ -438,3 +458,89 @@ class TestGKEOperationTrigger:
assert not task.done()
assert "Operation is still running."
assert f"Sleeping for {POLL_INTERVAL}s..."
+
+
+class TestGKEStartJobTrigger:
+ def test_serialize(self, job_trigger):
+ classpath, kwargs_dict = job_trigger.serialize()
+
+ assert classpath == TRIGGER_GKE_JOB_PATH
+ assert kwargs_dict == {
+ "cluster_url": CLUSTER_URL,
+ "ssl_ca_cert": SSL_CA_CERT,
+ "job_name": JOB_NAME,
+ "job_namespace": NAMESPACE,
+ "gcp_conn_id": GCP_CONN_ID,
+ "poll_interval": POLL_INTERVAL,
+ "impersonation_chain": IMPERSONATION_CHAIN,
+ }
+
+ @pytest.mark.asyncio
+ @mock.patch(f"{TRIGGER_GKE_JOB_PATH}.hook")
+ async def test_run_success(self, mock_hook, job_trigger):
+ mock_job = mock.MagicMock()
+ mock_job.metadata.name = JOB_NAME
+ mock_job.metadata.namespace = NAMESPACE
+ mock_hook.wait_until_job_complete.side_effect =
mock.AsyncMock(return_value=mock_job)
+
+ mock_is_job_failed = mock_hook.is_job_failed
+ mock_is_job_failed.return_value = False
+
+ mock_job_dict = mock_job.to_dict.return_value
+
+ event_actual = await job_trigger.run().asend(None)
+
+
mock_hook.wait_until_job_complete.assert_called_once_with(name=JOB_NAME,
namespace=NAMESPACE)
+ mock_job.to_dict.assert_called_once()
+ mock_is_job_failed.assert_called_once_with(job=mock_job)
+ assert event_actual == TriggerEvent(
+ {
+ "name": JOB_NAME,
+ "namespace": NAMESPACE,
+ "status": "success",
+ "message": "Job completed successfully",
+ "job": mock_job_dict,
+ }
+ )
+
+ @pytest.mark.asyncio
+ @mock.patch(f"{TRIGGER_GKE_JOB_PATH}.hook")
+ async def test_run_fail(self, mock_hook, job_trigger):
+ mock_job = mock.MagicMock()
+ mock_job.metadata.name = JOB_NAME
+ mock_job.metadata.namespace = NAMESPACE
+ mock_hook.wait_until_job_complete.side_effect =
mock.AsyncMock(return_value=mock_job)
+
+ mock_is_job_failed = mock_hook.is_job_failed
+ mock_is_job_failed.return_value = "Error"
+
+ mock_job_dict = mock_job.to_dict.return_value
+
+ event_actual = await job_trigger.run().asend(None)
+
+
mock_hook.wait_until_job_complete.assert_called_once_with(name=JOB_NAME,
namespace=NAMESPACE)
+ mock_job.to_dict.assert_called_once()
+ mock_is_job_failed.assert_called_once_with(job=mock_job)
+ assert event_actual == TriggerEvent(
+ {
+ "name": JOB_NAME,
+ "namespace": NAMESPACE,
+ "status": "error",
+ "message": "Job failed with error: Error",
+ "job": mock_job_dict,
+ }
+ )
+
+ @mock.patch(f"{GKE_TRIGGERS_PATH}.GKEKubernetesAsyncHook")
+ def test_hook(self, mock_hook, job_trigger):
+ hook_expected = mock_hook.return_value
+
+ hook_actual = job_trigger.hook
+
+ mock_hook.assert_called_once_with(
+ cluster_url=CLUSTER_URL,
+ ssl_ca_cert=SSL_CA_CERT,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+ assert hook_actual == hook_expected
diff --git
a/tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_job.py
b/tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_job.py
index 46fbcfcfbc..97b5cd406b 100644
---
a/tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_job.py
+++
b/tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_job.py
@@ -24,6 +24,7 @@ from __future__ import annotations
import os
from datetime import datetime
+from airflow.models.baseoperator import chain
from airflow.models.dag import DAG
from airflow.providers.google.cloud.operators.kubernetes_engine import (
GKECreateClusterOperator,
@@ -44,6 +45,7 @@ CLUSTER_NAME = f"gke-job-{ENV_ID}".replace("_", "-")
CLUSTER = {"name": CLUSTER_NAME, "initial_node_count": 1}
JOB_NAME = "test-pi"
+JOB_NAME_DEF = "test-pi-def"
JOB_NAMESPACE = "default"
with DAG(
@@ -73,6 +75,21 @@ with DAG(
)
# [END howto_operator_gke_start_job]
+ # [START howto_operator_gke_start_job_def]
+ job_task_def = GKEStartJobOperator(
+ task_id="job_task_def",
+ project_id=GCP_PROJECT_ID,
+ location=GCP_LOCATION,
+ cluster_name=CLUSTER_NAME,
+ namespace=JOB_NAMESPACE,
+ image="perl:5.34.0",
+ cmds=["perl", "-Mbignum=bpi", "-wle", "print bpi(2000)"],
+ name=JOB_NAME_DEF,
+ wait_until_job_complete=True,
+ deferrable=True,
+ )
+ # [END howto_operator_gke_start_job_def]
+
# [START howto_operator_gke_list_jobs]
list_job_task = GKEListJobsOperator(
task_id="list_job_task", project_id=GCP_PROJECT_ID,
location=GCP_LOCATION, cluster_name=CLUSTER_NAME
@@ -90,6 +107,15 @@ with DAG(
)
# [END howto_operator_gke_describe_job]
+ describe_job_task_def = GKEDescribeJobOperator(
+ task_id="describe_job_task_def",
+ project_id=GCP_PROJECT_ID,
+ location=GCP_LOCATION,
+ job_name=job_task_def.output["job_name"],
+ namespace="default",
+ cluster_name=CLUSTER_NAME,
+ )
+
# [START howto_operator_gke_delete_job]
delete_job = GKEDeleteJobOperator(
task_id="delete_job",
@@ -101,6 +127,15 @@ with DAG(
)
# [END howto_operator_gke_delete_job]
+ delete_job_def = GKEDeleteJobOperator(
+ task_id="delete_job_def",
+ project_id=GCP_PROJECT_ID,
+ location=GCP_LOCATION,
+ cluster_name=CLUSTER_NAME,
+ name=JOB_NAME,
+ namespace=JOB_NAMESPACE,
+ )
+
delete_cluster = GKEDeleteClusterOperator(
task_id="delete_cluster",
name=CLUSTER_NAME,
@@ -109,7 +144,14 @@ with DAG(
trigger_rule=TriggerRule.ALL_DONE,
)
- create_cluster >> job_task >> [list_job_task, describe_job_task] >>
delete_job >> delete_cluster
+ chain(
+ create_cluster,
+ [job_task, job_task_def],
+ list_job_task,
+ [describe_job_task, describe_job_task_def],
+ [delete_job, delete_job_def],
+ delete_cluster,
+ )
from tests.system.utils.watcher import watcher