This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new afb686c95e Implement deferrable mode for GKEStartJobOperator (#38454)
afb686c95e is described below

commit afb686c95ef276ac8d9d473b74303fd1551d00fd
Author: max <[email protected]>
AuthorDate: Tue Mar 26 21:28:22 2024 +0100

    Implement deferrable mode for GKEStartJobOperator (#38454)
    
    * Implement deferrable mode for GKEStartJobOperator
    
    * Specify trigger return type
    
    Co-authored-by: Wei Lee <[email protected]>
    
    * Fix f-string
    
    Co-authored-by: Wei Lee <[email protected]>
    
    * Refactor trigger event yielding
    
    Co-authored-by: Wei Lee <[email protected]>
    
    * Fix typo
    
    ---------
    
    Co-authored-by: Wei Lee <[email protected]>
---
 airflow/providers/cncf/kubernetes/triggers/job.py  |   2 +-
 .../google/cloud/operators/kubernetes_engine.py    |  41 +++++-
 .../google/cloud/triggers/kubernetes_engine.py     |  72 ++++++++++-
 .../operators/cloud/kubernetes_engine.rst          |   9 ++
 .../cloud/operators/test_kubernetes_engine.py      |  76 +++++++++++
 .../cloud/triggers/test_kubernetes_engine.py       | 144 ++++++++++++++++++---
 .../example_kubernetes_engine_job.py               |  44 ++++++-
 7 files changed, 365 insertions(+), 23 deletions(-)

diff --git a/airflow/providers/cncf/kubernetes/triggers/job.py b/airflow/providers/cncf/kubernetes/triggers/job.py
index 94f4667691..f229017df1 100644
--- a/airflow/providers/cncf/kubernetes/triggers/job.py
+++ b/airflow/providers/cncf/kubernetes/triggers/job.py
@@ -45,7 +45,7 @@ class KubernetesJobTrigger(BaseTrigger):
         job_name: str,
         job_namespace: str,
         kubernetes_conn_id: str | None = None,
-        poll_interval: float = 2,
+        poll_interval: float = 10.0,
         cluster_context: str | None = None,
         config_file: str | None = None,
         in_cluster: bool | None = None,
diff --git a/airflow/providers/google/cloud/operators/kubernetes_engine.py b/airflow/providers/google/cloud/operators/kubernetes_engine.py
index 438ed7e8ac..8729335c5a 100644
--- a/airflow/providers/google/cloud/operators/kubernetes_engine.py
+++ b/airflow/providers/google/cloud/operators/kubernetes_engine.py
@@ -31,6 +31,7 @@ from google.api_core.exceptions import AlreadyExists
 from google.cloud.container_v1.types import Cluster
 from kubernetes.client import V1JobList
 from kubernetes.utils.create_from_yaml import FailToCreateError
+from packaging.version import parse as parse_version
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
@@ -55,7 +56,12 @@ from airflow.providers.google.cloud.links.kubernetes_engine import (
     KubernetesEngineWorkloadsLink,
 )
 from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
-from airflow.providers.google.cloud.triggers.kubernetes_engine import GKEOperationTrigger, GKEStartPodTrigger
+from airflow.providers.google.cloud.triggers.kubernetes_engine import (
+    GKEJobTrigger,
+    GKEOperationTrigger,
+    GKEStartPodTrigger,
+)
+from airflow.providers_manager import ProvidersManager
 from airflow.utils.timezone import utcnow
 
 if TYPE_CHECKING:
@@ -834,6 +840,9 @@ class GKEStartJobOperator(KubernetesJobOperator):
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
     :param location: The location param is the region name.
+    :param deferrable: Run operator in deferrable mode.
+    :param job_poll_interval: (Deferrable mode only) polling period in seconds to
+        check the job status.
     """
 
     template_fields: Sequence[str] = tuple(
@@ -850,6 +859,8 @@ class GKEStartJobOperator(KubernetesJobOperator):
         project_id: str | None = None,
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        job_poll_interval: float = 10.0,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -859,6 +870,8 @@ class GKEStartJobOperator(KubernetesJobOperator):
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
         self.use_internal_ip = use_internal_ip
+        self.deferrable = deferrable
+        self.job_poll_interval = job_poll_interval
 
         self.job: V1Job | None = None
         self._ssl_ca_cert: str | None = None
@@ -900,6 +913,18 @@ class GKEStartJobOperator(KubernetesJobOperator):
 
     def execute(self, context: Context):
         """Execute process of creating Job."""
+        if self.deferrable:
+            kubernetes_provider = ProvidersManager().providers["apache-airflow-providers-cncf-kubernetes"]
+            kubernetes_provider_name = kubernetes_provider.data["package-name"]
+            kubernetes_provider_version = kubernetes_provider.version
+            min_version = "8.0.1"
+            if parse_version(kubernetes_provider_version) <= parse_version(min_version):
+                raise AirflowException(
+                    "You are trying to use `GKEStartJobOperator` in deferrable 
mode with the provider "
+                    f"package 
{kubernetes_provider_name}=={kubernetes_provider_version} which doesn't "
+                    f"support this feature. Please upgrade it to version 
higher than {min_version}."
+                )
+
         self._cluster_url, self._ssl_ca_cert = GKEClusterAuthDetails(
             cluster_name=self.cluster_name,
             project_id=self.project_id,
@@ -909,6 +934,20 @@ class GKEStartJobOperator(KubernetesJobOperator):
 
         return super().execute(context)
 
+    def execute_deferrable(self):
+        self.defer(
+            trigger=GKEJobTrigger(
+                cluster_url=self._cluster_url,
+                ssl_ca_cert=self._ssl_ca_cert,
+                job_name=self.job.metadata.name,  # type: ignore[union-attr]
+                job_namespace=self.job.metadata.namespace,  # type: ignore[union-attr]
+                gcp_conn_id=self.gcp_conn_id,
+                poll_interval=self.job_poll_interval,
+                impersonation_chain=self.impersonation_chain,
+            ),
+            method_name="execute_complete",
+        )
+
 
 class GKEDescribeJobOperator(GoogleCloudBaseOperator):
     """
diff --git a/airflow/providers/google/cloud/triggers/kubernetes_engine.py b/airflow/providers/google/cloud/triggers/kubernetes_engine.py
index 85776b8a20..c0d8fef97a 100644
--- a/airflow/providers/google/cloud/triggers/kubernetes_engine.py
+++ b/airflow/providers/google/cloud/triggers/kubernetes_engine.py
@@ -27,12 +27,18 @@ from google.cloud.container_v1.types import Operation
 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.cncf.kubernetes.triggers.pod import KubernetesPodTrigger
 from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction
-from airflow.providers.google.cloud.hooks.kubernetes_engine import GKEAsyncHook, GKEPodAsyncHook
+from airflow.providers.google.cloud.hooks.kubernetes_engine import (
+    GKEAsyncHook,
+    GKEKubernetesAsyncHook,
+    GKEPodAsyncHook,
+)
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 
 if TYPE_CHECKING:
     from datetime import datetime
 
+    from kubernetes_asyncio.client import V1Job
+
 
 class GKEStartPodTrigger(KubernetesPodTrigger):
     """
@@ -237,3 +243,67 @@ class GKEOperationTrigger(BaseTrigger):
                 impersonation_chain=self.impersonation_chain,
             )
         return self._hook
+
+
+class GKEJobTrigger(BaseTrigger):
+    """GKEJobTrigger run on the trigger worker to check the state of Job."""
+
+    def __init__(
+        self,
+        cluster_url: str,
+        ssl_ca_cert: str,
+        job_name: str,
+        job_namespace: str,
+        gcp_conn_id: str = "google_cloud_default",
+        poll_interval: float = 2,
+        impersonation_chain: str | Sequence[str] | None = None,
+    ) -> None:
+        super().__init__()
+        self.cluster_url = cluster_url
+        self.ssl_ca_cert = ssl_ca_cert
+        self.job_name = job_name
+        self.job_namespace = job_namespace
+        self.gcp_conn_id = gcp_conn_id
+        self.poll_interval = poll_interval
+        self.impersonation_chain = impersonation_chain
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serialize KubernetesCreateJobTrigger arguments and classpath."""
+        return (
+            "airflow.providers.google.cloud.triggers.kubernetes_engine.GKEJobTrigger",
+            {
+                "cluster_url": self.cluster_url,
+                "ssl_ca_cert": self.ssl_ca_cert,
+                "job_name": self.job_name,
+                "job_namespace": self.job_namespace,
+                "gcp_conn_id": self.gcp_conn_id,
+                "poll_interval": self.poll_interval,
+                "impersonation_chain": self.impersonation_chain,
+            },
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
+        """Get current job status and yield a TriggerEvent."""
+        job: V1Job = await self.hook.wait_until_job_complete(name=self.job_name, namespace=self.job_namespace)
+        job_dict = job.to_dict()
+        error_message = self.hook.is_job_failed(job=job)
+        status = "error" if error_message else "success"
+        message = f"Job failed with error: {error_message}" if error_message 
else "Job completed successfully"
+        yield TriggerEvent(
+            {
+                "name": job.metadata.name,
+                "namespace": job.metadata.namespace,
+                "status": status,
+                "message": message,
+                "job": job_dict,
+            }
+        )
+
+    @cached_property
+    def hook(self) -> GKEKubernetesAsyncHook:
+        return GKEKubernetesAsyncHook(
+            cluster_url=self.cluster_url,
+            ssl_ca_cert=self.ssl_ca_cert,
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
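
The operator defers with method_name="execute_complete", a method
GKEStartJobOperator inherits from KubernetesJobOperator rather than defining
in this diff. A hedged sketch of a handler consuming the event payload
yielded above (an illustration of the event contract, not the inherited
implementation):

    from airflow.exceptions import AirflowException

    def execute_complete(self, context, event: dict) -> None:
        # GKEJobTrigger yields exactly one event with the keys
        # "name", "namespace", "status", "message" and "job" (see run() above).
        if event["status"] == "error":
            raise AirflowException(event["message"])
        self.log.info(event["message"])
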
diff --git a/docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst b/docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst
index 53681ddcee..7d5cc44826 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst
@@ -213,6 +213,15 @@ All Kubernetes parameters (except ``config_file``) are also valid for the ``GKES
     :start-after: [START howto_operator_gke_start_job]
     :end-before: [END howto_operator_gke_start_job]
 
+``GKEStartJobOperator`` also supports deferrable mode. Note that it makes sense only if the ``wait_until_job_complete``
+parameter is set to ``True``.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_job.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_gke_start_job_def]
+    :end-before: [END howto_operator_gke_start_job_def]
+
 To run a Job on a GKE cluster with Kueue enabled, use ``GKEStartKueueJobOperator``.
 
 .. exampleinclude:: /../../tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_kueue.py
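
A condensed, self-contained version of the pairing the note above requires;
the DAG id, project, and cluster names here are illustrative, and the job
mirrors the system test included further down in this commit:

    from datetime import datetime

    from airflow.models.dag import DAG
    from airflow.providers.google.cloud.operators.kubernetes_engine import GKEStartJobOperator

    with DAG(dag_id="gke_job_deferrable_demo", start_date=datetime(2024, 1, 1), schedule=None):
        GKEStartJobOperator(
            task_id="pi_job_def",
            project_id="my-gcp-project",  # illustrative
            location="us-central1-c",
            cluster_name="my-cluster",  # illustrative
            namespace="default",
            image="perl:5.34.0",
            cmds=["perl", "-Mbignum=bpi", "-wle", "print bpi(2000)"],
            name="test-pi-def",
            wait_until_job_complete=True,  # without this the deferral has nothing to wait for
            deferrable=True,
        )
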
diff --git a/tests/providers/google/cloud/operators/test_kubernetes_engine.py b/tests/providers/google/cloud/operators/test_kubernetes_engine.py
index 62ccb238fb..d752c3780d 100644
--- a/tests/providers/google/cloud/operators/test_kubernetes_engine.py
+++ b/tests/providers/google/cloud/operators/test_kubernetes_engine.py
@@ -71,6 +71,7 @@ TASK_NAME = "test-task-name"
 JOB_NAME = "test-job"
 NAMESPACE = ("default",)
 IMAGE = "bash"
+JOB_POLL_INTERVAL = 20.0
 
 GCLOUD_COMMAND = "gcloud container clusters get-credentials {} --zone {} --project {}"
 KUBE_ENV_VAR = "KUBECONFIG"
@@ -708,6 +709,81 @@ class TestGKEStartJobOperator:
         self.gke_op.execute(context=mock.MagicMock())
         fetch_cluster_info_mock.assert_called_once()
 
+    @mock.patch(KUB_JOB_OPERATOR_EXEC)
+    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
+    @mock.patch(GKE_HOOK_PATH)
+    @mock.patch(f"{GKE_HOOK_MODULE_PATH}.ProvidersManager")
+    def test_execute_in_deferrable_mode(
+        self, mock_providers_manager, mock_hook, fetch_cluster_info_mock, exec_mock
+    ):
+        kubernetes_package_name = "apache-airflow-providers-cncf-kubernetes"
+        mock_providers_manager.return_value.providers = {
+            kubernetes_package_name: mock.MagicMock(
+                data={
+                    "package-name": kubernetes_package_name,
+                },
+                version="8.0.2",
+            )
+        }
+        self.gke_op.deferrable = True
+
+        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
+        self.gke_op.execute(context=mock.MagicMock())
+        fetch_cluster_info_mock.assert_called_once()
+
+    @mock.patch(f"{GKE_HOOK_MODULE_PATH}.ProvidersManager")
+    def test_execute_in_deferrable_mode_exception(self, mock_providers_manager):
+        kubernetes_package_name = "apache-airflow-providers-cncf-kubernetes"
+        mock_providers_manager.return_value.providers = {
+            kubernetes_package_name: mock.MagicMock(
+                data={
+                    "package-name": kubernetes_package_name,
+                },
+                version="8.0.1",
+            )
+        }
+        self.gke_op.deferrable = True
+        with pytest.raises(AirflowException):
+            self.gke_op.execute({})
+
+    @mock.patch(f"{GKE_HOOK_MODULE_PATH}.GKEJobTrigger")
+    def test_execute_deferrable(self, mock_trigger):
+        mock_trigger_instance = mock_trigger.return_value
+
+        op = GKEStartJobOperator(
+            project_id=TEST_GCP_PROJECT_ID,
+            location=PROJECT_LOCATION,
+            cluster_name=CLUSTER_NAME,
+            task_id=PROJECT_TASK_ID,
+            name=TASK_NAME,
+            namespace=NAMESPACE,
+            image=IMAGE,
+            job_poll_interval=JOB_POLL_INTERVAL,
+        )
+        op._ssl_ca_cert = SSL_CA_CERT
+        op._cluster_url = CLUSTER_URL
+
+        with mock.patch.object(op, "job") as mock_job:
+            mock_metadata = mock_job.metadata
+            mock_metadata.name = TASK_NAME
+            mock_metadata.namespace = NAMESPACE
+            with mock.patch.object(op, "defer") as mock_defer:
+                op.execute_deferrable()
+
+        mock_trigger.assert_called_once_with(
+            cluster_url=CLUSTER_URL,
+            ssl_ca_cert=SSL_CA_CERT,
+            job_name=TASK_NAME,
+            job_namespace=NAMESPACE,
+            gcp_conn_id="google_cloud_default",
+            poll_interval=JOB_POLL_INTERVAL,
+            impersonation_chain=None,
+        )
+        mock_defer.assert_called_once_with(
+            trigger=mock_trigger_instance,
+            method_name="execute_complete",
+        )
+
     def test_config_file_throws_error(self):
         with pytest.raises(AirflowException):
             GKEStartJobOperator(
diff --git a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py
index c25251886b..8a18dfcb90 100644
--- a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py
+++ b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py
@@ -28,13 +28,20 @@ from google.cloud.container_v1.types import Operation
 from kubernetes.client import models as k8s
 
 from airflow.providers.cncf.kubernetes.triggers.kubernetes_pod import ContainerState
-from airflow.providers.google.cloud.triggers.kubernetes_engine import GKEOperationTrigger, GKEStartPodTrigger
+from airflow.providers.google.cloud.triggers.kubernetes_engine import (
+    GKEJobTrigger,
+    GKEOperationTrigger,
+    GKEStartPodTrigger,
+)
 from airflow.triggers.base import TriggerEvent
 
-TRIGGER_GKE_PATH = "airflow.providers.google.cloud.triggers.kubernetes_engine.GKEStartPodTrigger"
-TRIGGER_KUB_PATH = "airflow.providers.cncf.kubernetes.triggers.pod.KubernetesPodTrigger"
+GKE_TRIGGERS_PATH = "airflow.providers.google.cloud.triggers.kubernetes_engine"
+TRIGGER_GKE_POD_PATH = GKE_TRIGGERS_PATH + ".GKEStartPodTrigger"
+TRIGGER_GKE_JOB_PATH = GKE_TRIGGERS_PATH + ".GKEJobTrigger"
+TRIGGER_KUB_POD_PATH = "airflow.providers.cncf.kubernetes.triggers.pod.KubernetesPodTrigger"
 HOOK_PATH = "airflow.providers.google.cloud.hooks.kubernetes_engine.GKEPodAsyncHook"
 POD_NAME = "test-pod-name"
+JOB_NAME = "test-job-name"
 NAMESPACE = "default"
 POLL_INTERVAL = 2
 CLUSTER_CONTEXT = "test-context"
@@ -54,7 +61,7 @@ PROJECT_ID = "test-project-id"
 LOCATION = "us-central1-c"
 GCP_CONN_ID = "test-non-existing-project-id"
 IMPERSONATION_CHAIN = ["impersonate", "this", "test"]
-TRIGGER_PATH = "airflow.providers.google.cloud.triggers.kubernetes_engine.GKEOperationTrigger"
+TRIGGER_PATH = f"{GKE_TRIGGERS_PATH}.GKEOperationTrigger"
 EXC_MSG = "test error msg"
 
 
@@ -78,6 +85,19 @@ def trigger():
     )
 
 
[email protected]
+def job_trigger():
+    return GKEJobTrigger(
+        cluster_url=CLUSTER_URL,
+        ssl_ca_cert=SSL_CA_CERT,
+        job_name=JOB_NAME,
+        job_namespace=NAMESPACE,
+        gcp_conn_id=GCP_CONN_ID,
+        poll_interval=POLL_INTERVAL,
+        impersonation_chain=IMPERSONATION_CHAIN,
+    )
+
+
 class TestGKEStartPodTrigger:
     @staticmethod
     def _mock_pod_result(result_to_mock):
@@ -88,7 +108,7 @@ class TestGKEStartPodTrigger:
     def test_serialize_should_execute_successfully(self, trigger):
         classpath, kwargs_dict = trigger.serialize()
 
-        assert classpath == TRIGGER_GKE_PATH
+        assert classpath == TRIGGER_GKE_POD_PATH
         assert kwargs_dict == {
             "pod_name": POD_NAME,
             "pod_namespace": NAMESPACE,
@@ -108,8 +128,8 @@ class TestGKEStartPodTrigger:
         }
 
     @pytest.mark.asyncio
-    @mock.patch(f"{TRIGGER_KUB_PATH}._wait_for_pod_start")
-    @mock.patch(f"{TRIGGER_GKE_PATH}.hook")
+    @mock.patch(f"{TRIGGER_KUB_POD_PATH}._wait_for_pod_start")
+    @mock.patch(f"{TRIGGER_GKE_POD_PATH}.hook")
     async def test_run_loop_return_success_event_should_execute_successfully(
         self, mock_hook, mock_wait_pod, trigger
     ):
@@ -129,8 +149,8 @@ class TestGKEStartPodTrigger:
         assert actual_event == expected_event
 
     @pytest.mark.asyncio
-    @mock.patch(f"{TRIGGER_KUB_PATH}._wait_for_pod_start")
-    @mock.patch(f"{TRIGGER_GKE_PATH}.hook")
+    @mock.patch(f"{TRIGGER_KUB_POD_PATH}._wait_for_pod_start")
+    @mock.patch(f"{TRIGGER_GKE_POD_PATH}.hook")
     async def test_run_loop_return_failed_event_should_execute_successfully(
         self, mock_hook, mock_wait_pod, trigger
     ):
@@ -156,9 +176,9 @@ class TestGKEStartPodTrigger:
         assert actual_event == expected_event
 
     @pytest.mark.asyncio
-    @mock.patch(f"{TRIGGER_KUB_PATH}._wait_for_pod_start")
-    @mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")
-    @mock.patch(f"{TRIGGER_GKE_PATH}.hook")
+    @mock.patch(f"{TRIGGER_KUB_POD_PATH}._wait_for_pod_start")
+    @mock.patch(f"{TRIGGER_KUB_POD_PATH}.define_container_state")
+    @mock.patch(f"{TRIGGER_GKE_POD_PATH}.hook")
     async def test_run_loop_return_waiting_event_should_execute_successfully(
         self, mock_hook, mock_method, mock_wait_pod, trigger, caplog
     ):
@@ -175,9 +195,9 @@ class TestGKEStartPodTrigger:
         assert f"Sleeping for {POLL_INTERVAL} seconds."
 
     @pytest.mark.asyncio
-    @mock.patch(f"{TRIGGER_KUB_PATH}._wait_for_pod_start")
-    @mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")
-    @mock.patch(f"{TRIGGER_GKE_PATH}.hook")
+    @mock.patch(f"{TRIGGER_KUB_POD_PATH}._wait_for_pod_start")
+    @mock.patch(f"{TRIGGER_KUB_POD_PATH}.define_container_state")
+    @mock.patch(f"{TRIGGER_GKE_POD_PATH}.hook")
     async def test_run_loop_return_running_event_should_execute_successfully(
         self, mock_hook, mock_method, mock_wait_pod, trigger, caplog
     ):
@@ -194,8 +214,8 @@ class TestGKEStartPodTrigger:
         assert f"Sleeping for {POLL_INTERVAL} seconds."
 
     @pytest.mark.asyncio
-    @mock.patch(f"{TRIGGER_KUB_PATH}._wait_for_pod_start")
-    @mock.patch(f"{TRIGGER_GKE_PATH}.hook")
+    @mock.patch(f"{TRIGGER_KUB_POD_PATH}._wait_for_pod_start")
+    @mock.patch(f"{TRIGGER_GKE_POD_PATH}.hook")
     async def test_logging_in_trigger_when_exception_should_execute_successfully(
         self, mock_hook, mock_wait_pod, trigger, caplog
     ):
@@ -216,8 +236,8 @@ class TestGKEStartPodTrigger:
         assert actual_stack_trace.startswith("Traceback (most recent call last):")
 
     @pytest.mark.asyncio
-    @mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")
-    @mock.patch(f"{TRIGGER_GKE_PATH}.hook")
+    @mock.patch(f"{TRIGGER_KUB_POD_PATH}.define_container_state")
+    @mock.patch(f"{TRIGGER_GKE_POD_PATH}.hook")
     async def test_logging_in_trigger_when_fail_should_execute_successfully(
         self, mock_hook, mock_method, trigger, caplog
     ):
@@ -438,3 +458,89 @@ class TestGKEOperationTrigger:
         assert not task.done()
         assert "Operation is still running."
         assert f"Sleeping for {POLL_INTERVAL}s..."
+
+
+class TestGKEStartJobTrigger:
+    def test_serialize(self, job_trigger):
+        classpath, kwargs_dict = job_trigger.serialize()
+
+        assert classpath == TRIGGER_GKE_JOB_PATH
+        assert kwargs_dict == {
+            "cluster_url": CLUSTER_URL,
+            "ssl_ca_cert": SSL_CA_CERT,
+            "job_name": JOB_NAME,
+            "job_namespace": NAMESPACE,
+            "gcp_conn_id": GCP_CONN_ID,
+            "poll_interval": POLL_INTERVAL,
+            "impersonation_chain": IMPERSONATION_CHAIN,
+        }
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{TRIGGER_GKE_JOB_PATH}.hook")
+    async def test_run_success(self, mock_hook, job_trigger):
+        mock_job = mock.MagicMock()
+        mock_job.metadata.name = JOB_NAME
+        mock_job.metadata.namespace = NAMESPACE
+        mock_hook.wait_until_job_complete.side_effect = mock.AsyncMock(return_value=mock_job)
+
+        mock_is_job_failed = mock_hook.is_job_failed
+        mock_is_job_failed.return_value = False
+
+        mock_job_dict = mock_job.to_dict.return_value
+
+        event_actual = await job_trigger.run().asend(None)
+
+        mock_hook.wait_until_job_complete.assert_called_once_with(name=JOB_NAME, namespace=NAMESPACE)
+        mock_job.to_dict.assert_called_once()
+        mock_is_job_failed.assert_called_once_with(job=mock_job)
+        assert event_actual == TriggerEvent(
+            {
+                "name": JOB_NAME,
+                "namespace": NAMESPACE,
+                "status": "success",
+                "message": "Job completed successfully",
+                "job": mock_job_dict,
+            }
+        )
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{TRIGGER_GKE_JOB_PATH}.hook")
+    async def test_run_fail(self, mock_hook, job_trigger):
+        mock_job = mock.MagicMock()
+        mock_job.metadata.name = JOB_NAME
+        mock_job.metadata.namespace = NAMESPACE
+        mock_hook.wait_until_job_complete.side_effect = mock.AsyncMock(return_value=mock_job)
+
+        mock_is_job_failed = mock_hook.is_job_failed
+        mock_is_job_failed.return_value = "Error"
+
+        mock_job_dict = mock_job.to_dict.return_value
+
+        event_actual = await job_trigger.run().asend(None)
+
+        mock_hook.wait_until_job_complete.assert_called_once_with(name=JOB_NAME, namespace=NAMESPACE)
+        mock_job.to_dict.assert_called_once()
+        mock_is_job_failed.assert_called_once_with(job=mock_job)
+        assert event_actual == TriggerEvent(
+            {
+                "name": JOB_NAME,
+                "namespace": NAMESPACE,
+                "status": "error",
+                "message": "Job failed with error: Error",
+                "job": mock_job_dict,
+            }
+        )
+
+    @mock.patch(f"{GKE_TRIGGERS_PATH}.GKEKubernetesAsyncHook")
+    def test_hook(self, mock_hook, job_trigger):
+        hook_expected = mock_hook.return_value
+
+        hook_actual = job_trigger.hook
+
+        mock_hook.assert_called_once_with(
+            cluster_url=CLUSTER_URL,
+            ssl_ca_cert=SSL_CA_CERT,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        assert hook_actual == hook_expected
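
The new tests drive the trigger by awaiting run().asend(None), which starts
the async generator and returns its first yielded event. A minimal standalone
illustration of that pattern (the trigger class here is a stand-in, not part
of this commit):

    import asyncio
    from typing import Any, AsyncIterator

    from airflow.triggers.base import BaseTrigger, TriggerEvent

    class OneShotTrigger(BaseTrigger):
        """Stand-in trigger that yields a single event, like GKEJobTrigger."""

        def serialize(self) -> tuple[str, dict[str, Any]]:
            return ("example.OneShotTrigger", {})

        async def run(self) -> AsyncIterator[TriggerEvent]:
            yield TriggerEvent({"status": "success"})

    async def main() -> None:
        # asend(None) advances the generator to its first yield.
        event = await OneShotTrigger().run().asend(None)
        assert event == TriggerEvent({"status": "success"})

    asyncio.run(main())
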
diff --git a/tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_job.py b/tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_job.py
index 46fbcfcfbc..97b5cd406b 100644
--- a/tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_job.py
+++ b/tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_job.py
@@ -24,6 +24,7 @@ from __future__ import annotations
 import os
 from datetime import datetime
 
+from airflow.models.baseoperator import chain
 from airflow.models.dag import DAG
 from airflow.providers.google.cloud.operators.kubernetes_engine import (
     GKECreateClusterOperator,
@@ -44,6 +45,7 @@ CLUSTER_NAME = f"gke-job-{ENV_ID}".replace("_", "-")
 CLUSTER = {"name": CLUSTER_NAME, "initial_node_count": 1}
 
 JOB_NAME = "test-pi"
+JOB_NAME_DEF = "test-pi-def"
 JOB_NAMESPACE = "default"
 
 with DAG(
@@ -73,6 +75,21 @@ with DAG(
     )
     # [END howto_operator_gke_start_job]
 
+    # [START howto_operator_gke_start_job_def]
+    job_task_def = GKEStartJobOperator(
+        task_id="job_task_def",
+        project_id=GCP_PROJECT_ID,
+        location=GCP_LOCATION,
+        cluster_name=CLUSTER_NAME,
+        namespace=JOB_NAMESPACE,
+        image="perl:5.34.0",
+        cmds=["perl", "-Mbignum=bpi", "-wle", "print bpi(2000)"],
+        name=JOB_NAME_DEF,
+        wait_until_job_complete=True,
+        deferrable=True,
+    )
+    # [END howto_operator_gke_start_job_def]
+
     # [START howto_operator_gke_list_jobs]
     list_job_task = GKEListJobsOperator(
         task_id="list_job_task", project_id=GCP_PROJECT_ID, 
location=GCP_LOCATION, cluster_name=CLUSTER_NAME
@@ -90,6 +107,15 @@ with DAG(
     )
     # [END howto_operator_gke_describe_job]
 
+    describe_job_task_def = GKEDescribeJobOperator(
+        task_id="describe_job_task_def",
+        project_id=GCP_PROJECT_ID,
+        location=GCP_LOCATION,
+        job_name=job_task_def.output["job_name"],
+        namespace="default",
+        cluster_name=CLUSTER_NAME,
+    )
+
     # [START howto_operator_gke_delete_job]
     delete_job = GKEDeleteJobOperator(
         task_id="delete_job",
@@ -101,6 +127,15 @@ with DAG(
     )
     # [END howto_operator_gke_delete_job]
 
+    delete_job_def = GKEDeleteJobOperator(
+        task_id="delete_job_def",
+        project_id=GCP_PROJECT_ID,
+        location=GCP_LOCATION,
+        cluster_name=CLUSTER_NAME,
+        name=JOB_NAME_DEF,
+        namespace=JOB_NAMESPACE,
+    )
+
     delete_cluster = GKEDeleteClusterOperator(
         task_id="delete_cluster",
         name=CLUSTER_NAME,
@@ -109,7 +144,14 @@ with DAG(
         trigger_rule=TriggerRule.ALL_DONE,
     )
 
-    create_cluster >> job_task >> [list_job_task, describe_job_task] >> delete_job >> delete_cluster
+    chain(
+        create_cluster,
+        [job_task, job_task_def],
+        list_job_task,
+        [describe_job_task, describe_job_task_def],
+        [delete_job, delete_job_def],
+        delete_cluster,
+    )
 
     from tests.system.utils.watcher import watcher
 
