This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 094d6bf01b Add deferrable mode to dataflow operators (#27776)
094d6bf01b is described below

commit 094d6bf01b9d8b1a5d358dc10fd561cf3a04c51b
Author: George <[email protected]>
AuthorDate: Mon Jan 30 18:55:21 2023 +0100

    Add deferrable mode to dataflow operators (#27776)
    
    * Add deferrable mode to DataflowTemplatedJobStartOperator and 
DataflowStartFlexTemplateOperator operators
    
    * Change project_id param to be optional, add fixes for tests and docs build
    
    * Add comment about upper-bound for google-cloud-dataflow-client lib and
    change warning message
    
    ---------
    
    Co-authored-by: Heorhi Parkhomenka <[email protected]>
---
 .../google/cloud/example_dags/example_dataflow.py  |   3 +
 .../example_dags/example_dataflow_flex_template.py |   1 +
 airflow/providers/google/cloud/hooks/dataflow.py   | 166 +++++++++++++++----
 airflow/providers/google/cloud/links/dataflow.py   |   2 +-
 .../providers/google/cloud/operators/dataflow.py   | 175 +++++++++++++++++----
 .../providers/google/cloud/triggers/dataflow.py    | 150 ++++++++++++++++++
 airflow/providers/google/provider.yaml             |   3 +
 generated/provider_dependencies.json               |   1 +
 .../providers/google/cloud/hooks/test_dataflow.py  |  55 ++++++-
 .../google/cloud/operators/test_dataflow.py        | 135 +++++++++++++---
 .../google/cloud/triggers/test_dataflow.py         | 168 ++++++++++++++++++++
 .../cloud/dataflow/example_dataflow_template.py    |   2 +
 12 files changed, 769 insertions(+), 92 deletions(-)

diff --git a/airflow/providers/google/cloud/example_dags/example_dataflow.py 
b/airflow/providers/google/cloud/example_dags/example_dataflow.py
index f2a0860fc0..2ab4c04f5b 100644
--- a/airflow/providers/google/cloud/example_dags/example_dataflow.py
+++ b/airflow/providers/google/cloud/example_dags/example_dataflow.py
@@ -52,6 +52,7 @@ GCS_STAGING = os.environ.get("GCP_DATAFLOW_GCS_STAGING", 
"gs://INVALID BUCKET NA
 GCS_OUTPUT = os.environ.get("GCP_DATAFLOW_GCS_OUTPUT", "gs://INVALID BUCKET 
NAME/output")
 GCS_JAR = os.environ.get("GCP_DATAFLOW_JAR", "gs://INVALID BUCKET 
NAME/word-count-beam-bundled-0.1.jar")
 GCS_PYTHON = os.environ.get("GCP_DATAFLOW_PYTHON", "gs://INVALID BUCKET 
NAME/wordcount_debugging.py")
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
 
 GCS_JAR_PARTS = urlsplit(GCS_JAR)
 GCS_JAR_BUCKET_NAME = GCS_JAR_PARTS.netloc
@@ -257,6 +258,7 @@ with models.DAG(
     # [START howto_operator_start_template_job]
     start_template_job = DataflowTemplatedJobStartOperator(
         task_id="start-template-job",
+        project_id=PROJECT_ID,
         template="gs://dataflow-templates/latest/Word_Count",
         parameters={"inputFile": 
"gs://dataflow-samples/shakespeare/kinglear.txt", "output": GCS_OUTPUT},
         location="europe-west3",
@@ -279,6 +281,7 @@ with models.DAG(
     # [END howto_operator_stop_dataflow_job]
     start_template_job = DataflowTemplatedJobStartOperator(
         task_id="start-template-job",
+        project_id=PROJECT_ID,
         template="gs://dataflow-templates/latest/Word_Count",
         parameters={"inputFile": 
"gs://dataflow-samples/shakespeare/kinglear.txt", "output": GCS_OUTPUT},
         location="europe-west3",
diff --git 
a/airflow/providers/google/cloud/example_dags/example_dataflow_flex_template.py 
b/airflow/providers/google/cloud/example_dags/example_dataflow_flex_template.py
index 86de7014c5..8af748e8a7 100644
--- 
a/airflow/providers/google/cloud/example_dags/example_dataflow_flex_template.py
+++ 
b/airflow/providers/google/cloud/example_dags/example_dataflow_flex_template.py
@@ -52,6 +52,7 @@ with models.DAG(
     # [START howto_operator_start_template_job]
     start_flex_template = DataflowStartFlexTemplateOperator(
         task_id="start_flex_template_streaming_beam_sql",
+        project_id=GCP_PROJECT_ID,
         body={
             "launchParameter": {
                 "containerSpecGcsPath": GCS_FLEX_TEMPLATE_TEMPLATE_PATH,
diff --git a/airflow/providers/google/cloud/hooks/dataflow.py 
b/airflow/providers/google/cloud/hooks/dataflow.py
index b034dbd5f9..3f7fd4b54e 100644
--- a/airflow/providers/google/cloud/hooks/dataflow.py
+++ b/airflow/providers/google/cloud/hooks/dataflow.py
@@ -29,11 +29,16 @@ import warnings
 from copy import deepcopy
 from typing import Any, Callable, Generator, Sequence, TypeVar, cast
 
+from google.cloud.dataflow_v1beta3 import GetJobRequest, Job, JobState, 
JobsV1Beta3AsyncClient, JobView
 from googleapiclient.discovery import build
 
 from airflow.exceptions import AirflowException
 from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType, 
beam_options_to_args
-from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+from airflow.providers.google.common.hooks.base_google import (
+    PROVIDE_PROJECT_ID,
+    GoogleBaseAsyncHook,
+    GoogleBaseHook,
+)
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.timeout import timeout
 
@@ -645,36 +650,10 @@ class DataflowHook(GoogleBaseHook):
         """
         name = self.build_dataflow_job_name(job_name, append_job_name)
 
-        environment = environment or {}
-        # available keys for runtime environment are listed here:
-        # 
https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment
-        environment_keys = [
-            "numWorkers",
-            "maxWorkers",
-            "zone",
-            "serviceAccountEmail",
-            "tempLocation",
-            "bypassTempDirValidation",
-            "machineType",
-            "additionalExperiments",
-            "network",
-            "subnetwork",
-            "additionalUserLabels",
-            "kmsKeyName",
-            "ipConfiguration",
-            "workerRegion",
-            "workerZone",
-        ]
-
-        for key in variables:
-            if key in environment_keys:
-                if key in environment:
-                    self.log.warning(
-                        "'%s' parameter in 'variables' will override of "
-                        "the same one passed in 'environment'!",
-                        key,
-                    )
-                environment.update({key: variables[key]})
+        environment = self._update_environment(
+            variables=variables,
+            environment=environment,
+        )
 
         service = self.get_conn()
 
@@ -723,6 +702,40 @@ class DataflowHook(GoogleBaseHook):
         jobs_controller.wait_for_done()
         return response["job"]
 
+    def _update_environment(self, variables: dict, environment: dict | None = 
None) -> dict:
+        environment = environment or {}
+        # available keys for runtime environment are listed here:
+        # 
https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment
+        environment_keys = {
+            "numWorkers",
+            "maxWorkers",
+            "zone",
+            "serviceAccountEmail",
+            "tempLocation",
+            "bypassTempDirValidation",
+            "machineType",
+            "additionalExperiments",
+            "network",
+            "subnetwork",
+            "additionalUserLabels",
+            "kmsKeyName",
+            "ipConfiguration",
+            "workerRegion",
+            "workerZone",
+        }
+
+        def _check_one(key, val):
+            if key in environment:
+                self.log.warning(
+                    "%r parameter in 'variables' will override the same one 
passed in 'environment'!",
+                    key,
+                )
+            return key, val
+
+        environment.update(_check_one(key, val) for key, val in 
variables.items() if key in environment_keys)
+
+        return environment
+
     @GoogleBaseHook.fallback_to_default_project_id
     def start_flex_template(
         self,
@@ -731,9 +744,9 @@ class DataflowHook(GoogleBaseHook):
         project_id: str,
         on_new_job_id_callback: Callable[[str], None] | None = None,
         on_new_job_callback: Callable[[dict], None] | None = None,
-    ):
+    ) -> dict:
         """
-        Starts flex templates with the Dataflow  pipeline.
+        Starts flex templates with the Dataflow pipeline.
 
         :param body: The request body. See:
             
https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.locations.flexTemplates/launch#request-body
@@ -1041,7 +1054,7 @@ class DataflowHook(GoogleBaseHook):
     def get_job(
         self,
         job_id: str,
-        project_id: str,
+        project_id: str = PROVIDE_PROJECT_ID,
         location: str = DEFAULT_DATAFLOW_LOCATION,
     ) -> dict:
         """
@@ -1169,3 +1182,88 @@ class DataflowHook(GoogleBaseHook):
             wait_until_finished=self.wait_until_finished,
         )
         job_controller.wait_for_done()
+
+
+class AsyncDataflowHook(GoogleBaseAsyncHook):
+    """Async hook class for dataflow service."""
+
+    sync_hook_class = DataflowHook
+
+    async def initialize_client(self, client_class):
+        """
+        Initialize object of the given class.
+
+        Method is used to initialize asynchronous client. Because of the big 
amount of the classes which are
+        used for Dataflow service it was decided to initialize them the same 
way with credentials which are
+        received from the method of the GoogleBaseHook class.
+        :param client_class: Class of the Google cloud SDK
+        """
+        credentials = (await self.get_sync_hook()).get_credentials()
+        return client_class(
+            credentials=credentials,
+        )
+
+    async def get_project_id(self) -> str:
+        project_id = (await self.get_sync_hook()).project_id
+        return project_id
+
+    async def get_job(
+        self,
+        job_id: str,
+        project_id: str = PROVIDE_PROJECT_ID,
+        job_view: int = JobView.JOB_VIEW_SUMMARY,
+        location: str = DEFAULT_DATAFLOW_LOCATION,
+    ) -> Job:
+        """
+        Gets the job with the specified Job ID.
+
+        :param job_id: Job ID to get.
+        :param project_id: the Google Cloud project ID in which to start a job.
+            If set to None or missing, the default project_id from the Google 
Cloud connection is used.
+        :param job_view: Optional. JobView object which determines 
representation of the returned data
+        :param location: Optional. The location of the Dataflow job (for 
example europe-west1). See:
+            https://cloud.google.com/dataflow/docs/concepts/regional-endpoints
+        """
+        project_id = project_id or (await self.get_project_id())
+        client = await self.initialize_client(JobsV1Beta3AsyncClient)
+
+        request = GetJobRequest(
+            dict(
+                project_id=project_id,
+                job_id=job_id,
+                view=job_view,
+                location=location,
+            )
+        )
+
+        job = await client.get_job(
+            request=request,
+        )
+
+        return job
+
+    async def get_job_status(
+        self,
+        job_id: str,
+        project_id: str = PROVIDE_PROJECT_ID,
+        job_view: int = JobView.JOB_VIEW_SUMMARY,
+        location: str = DEFAULT_DATAFLOW_LOCATION,
+    ) -> JobState:
+        """
+        Gets the job status with the specified Job ID.
+
+        :param job_id: Job ID to get.
+        :param project_id: the Google Cloud project ID in which to start a job.
+            If set to None or missing, the default project_id from the Google 
Cloud connection is used.
+        :param job_view: Optional. JobView object which determines 
representation of the returned data
+        :param location: Optional. The location of the Dataflow job (for 
example europe-west1). See:
+            https://cloud.google.com/dataflow/docs/concepts/regional-endpoints
+        """
+        job = await self.get_job(
+            project_id=project_id,
+            job_id=job_id,
+            job_view=job_view,
+            location=location,
+        )
+        state = job.current_state
+        return state
diff --git a/airflow/providers/google/cloud/links/dataflow.py 
b/airflow/providers/google/cloud/links/dataflow.py
index 1f0f6f87e8..b62e29041d 100644
--- a/airflow/providers/google/cloud/links/dataflow.py
+++ b/airflow/providers/google/cloud/links/dataflow.py
@@ -48,5 +48,5 @@ class DataflowJobLink(BaseGoogleLink):
         operator_instance.xcom_push(
             context,
             key=DataflowJobLink.key,
-            value={"project_id": project_id, "location": region, "job_id": 
job_id},
+            value={"project_id": project_id, "region": region, "job_id": 
job_id},
         )
diff --git a/airflow/providers/google/cloud/operators/dataflow.py 
b/airflow/providers/google/cloud/operators/dataflow.py
index bdff866dd3..66a5da85f3 100644
--- a/airflow/providers/google/cloud/operators/dataflow.py
+++ b/airflow/providers/google/cloud/operators/dataflow.py
@@ -20,11 +20,14 @@ from __future__ import annotations
 
 import copy
 import re
+import uuid
 import warnings
 from contextlib import ExitStack
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Sequence
 
+from airflow import AirflowException
+from airflow.compat.functools import cached_property
 from airflow.models import BaseOperator
 from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType
 from airflow.providers.google.cloud.hooks.dataflow import (
@@ -34,6 +37,7 @@ from airflow.providers.google.cloud.hooks.dataflow import (
 )
 from airflow.providers.google.cloud.hooks.gcs import GCSHook
 from airflow.providers.google.cloud.links.dataflow import DataflowJobLink
+from airflow.providers.google.cloud.triggers.dataflow import 
TemplateJobStartTrigger
 from airflow.version import version
 
 if TYPE_CHECKING:
@@ -55,8 +59,8 @@ class CheckJobRunning(Enum):
 
 class DataflowConfiguration:
     """Dataflow configuration that can be passed to
-    
:py:class:`~airflow.providers.apache.beam.operators.beam.BeamRunJavaPipelineOperator`
 and
-    
:py:class:`~airflow.providers.apache.beam.operators.beam.BeamRunPythonPipelineOperator`.
+    
:class:`~airflow.providers.apache.beam.operators.beam.BeamRunJavaPipelineOperator`
 and
+    
:class:`~airflow.providers.apache.beam.operators.beam.BeamRunPythonPipelineOperator`.
 
     :param job_name: The 'jobName' to use when executing the Dataflow job
         (templated). This ends up being set in the pipeline options, so any 
entry
@@ -66,9 +70,8 @@ class DataflowConfiguration:
         If set to None or missing, the default project_id from the Google 
Cloud connection is used.
     :param location: Job location.
     :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
-    :param delegate_to: The account to impersonate using domain-wide 
delegation of authority,
-        if any. For this to work, the service account making the request must 
have
-        domain-wide delegation enabled.
+    :param delegate_to: The account to impersonate using domain-wide 
delegation of authority, if any.
+        For this to work, the service account making the request must have 
domain-wide delegation enabled.
     :param poll_sleep: The time in seconds to sleep between polling Google
         Cloud Platform for the dataflow job status while the job is in the
         JOB_STATE_RUNNING state.
@@ -82,7 +85,6 @@ class DataflowConfiguration:
         account from the list granting this role to the originating account 
(templated).
 
         .. warning::
-
             This option requires Apache Beam 2.39.0 or newer.
 
     :param drain_pipeline: Optional, set to True if want to stop streaming job 
by draining it
@@ -101,7 +103,6 @@ class DataflowConfiguration:
         * for the batch pipeline, wait for the jobs to complete.
 
         .. warning::
-
             You cannot call ``PipelineResult.wait_until_finish`` method in 
your pipeline code for the operator
             to work properly. i. e. you must use asynchronous execution. 
Otherwise, your pipeline will
             always wait until finished. For more information, look at:
@@ -109,10 +110,8 @@ class DataflowConfiguration:
             
<https://cloud.google.com/dataflow/docs/guides/specifying-exec-params#python_10>`__
 
         The process of starting the Dataflow job in Airflow consists of two 
steps:
-
         * running a subprocess and reading the stderr/stderr log for the job 
id.
-        * loop waiting for the end of the job ID from the previous step.
-          This loop checks the status of the job.
+        * loop waiting for the end of the job ID from the previous step by 
checking its status.
 
         Step two is started just after step one has finished, so if you have 
wait_until_finished in your
         pipeline code, step two will not start until the process stops. When 
this process stops,
@@ -124,13 +123,10 @@ class DataflowConfiguration:
         If you in your pipeline do not call the wait_for_pipeline method, and 
pass wait_until_finish=False
         to the operator, the second loop will check once is job not in 
terminal state and exit the loop.
     :param multiple_jobs: If pipeline creates multiple jobs then monitor all 
jobs. Supported only by
-        
:py:class:`~airflow.providers.apache.beam.operators.beam.BeamRunJavaPipelineOperator`
+        
:class:`~airflow.providers.apache.beam.operators.beam.BeamRunJavaPipelineOperator`.
     :param check_if_running: Before running job, validate that a previous run 
is not in process.
-        IgnoreJob = do not check if running.
-        FinishIfRunning = if job is running finish with nothing.
-        WaitForRun = wait until job finished and the run job.
         Supported only by:
-        
:py:class:`~airflow.providers.apache.beam.operators.beam.BeamRunJavaPipelineOperator`
+        
:class:`~airflow.providers.apache.beam.operators.beam.BeamRunJavaPipelineOperator`.
     :param service_account: Run the job as a specific service account, instead 
of the default GCE robot.
     """
 
@@ -230,6 +226,7 @@ class DataflowCreateJavaJobOperator(BaseOperator):
         with key ``'jobName'`` in ``options`` will be overwritten.
     :param dataflow_default_options: Map of default job options.
     :param options: Map of job specific options.The key must be a dictionary.
+
         The value can contain different types:
 
         * If the value is None, the single option - ``--key`` (without value) 
will be added.
@@ -241,6 +238,7 @@ class DataflowCreateJavaJobOperator(BaseOperator):
         * Other value types will be replaced with the Python textual 
representation.
 
         When defining labels (``labels`` option), you can also provide a 
dictionary.
+
     :param project_id: Optional, the Google Cloud project ID in which to start 
a job.
         If set to None or missing, the default project_id from the Google 
Cloud connection is used.
     :param location: Job location.
@@ -583,7 +581,7 @@ class DataflowTemplatedJobStartOperator(BaseOperator):
        )
 
     ``template``, ``dataflow_default_options``, ``parameters``, and 
``job_name`` are
-    templated so you can use variables in them.
+    templated, so you can use variables in them.
 
     Note that ``dataflow_default_options`` is expected to save high-level 
options
     for project information, which apply to all dataflow operators in the DAG.
@@ -594,6 +592,8 @@ class DataflowTemplatedJobStartOperator(BaseOperator):
             
https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment
             For more detail on job template execution have a look at the 
reference:
             
https://cloud.google.com/dataflow/docs/templates/executing-templates
+
+    :param deferrable: Run operator in the deferrable mode.
     """
 
     template_fields: Sequence[str] = (
@@ -615,11 +615,11 @@ class DataflowTemplatedJobStartOperator(BaseOperator):
         self,
         *,
         template: str,
+        project_id: str | None = None,
         job_name: str = "{{task.task_id}}",
         options: dict[str, Any] | None = None,
         dataflow_default_options: dict[str, Any] | None = None,
         parameters: dict[str, str] | None = None,
-        project_id: str | None = None,
         location: str = DEFAULT_DATAFLOW_LOCATION,
         gcp_conn_id: str = "google_cloud_default",
         delegate_to: str | None = None,
@@ -629,9 +629,11 @@ class DataflowTemplatedJobStartOperator(BaseOperator):
         cancel_timeout: int | None = 10 * 60,
         wait_until_finished: bool | None = None,
         append_job_name: bool = True,
+        deferrable: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
+
         self.template = template
         self.job_name = job_name
         self.options = options or {}
@@ -646,16 +648,31 @@ class DataflowTemplatedJobStartOperator(BaseOperator):
             )
         self.delegate_to = delegate_to
         self.poll_sleep = poll_sleep
-        self.job = None
-        self.hook: DataflowHook | None = None
         self.impersonation_chain = impersonation_chain
         self.environment = environment
         self.cancel_timeout = cancel_timeout
         self.wait_until_finished = wait_until_finished
         self.append_job_name = append_job_name
+        self.deferrable = deferrable
 
-    def execute(self, context: Context) -> dict:
-        self.hook = DataflowHook(
+        self.job: dict | None = None
+
+        self._validate_deferrable_params()
+
+    def _validate_deferrable_params(self):
+        if self.deferrable and self.wait_until_finished:
+            raise ValueError(
+                "Conflict between deferrable and wait_until_finished 
parameters "
+                "because it makes operator as blocking when it requires to be 
deferred. "
+                "It should be True as deferrable parameter or True as 
wait_until_finished."
+            )
+
+        if self.deferrable and self.wait_until_finished is None:
+            self.wait_until_finished = False
+
+    @cached_property
+    def hook(self) -> DataflowHook:
+        hook = DataflowHook(
             gcp_conn_id=self.gcp_conn_id,
             delegate_to=self.delegate_to,
             poll_sleep=self.poll_sleep,
@@ -663,14 +680,17 @@ class DataflowTemplatedJobStartOperator(BaseOperator):
             cancel_timeout=self.cancel_timeout,
             wait_until_finished=self.wait_until_finished,
         )
+        return hook
 
+    def execute(self, context: Context):
         def set_current_job(current_job):
             self.job = current_job
             DataflowJobLink.persist(self, context, self.project_id, 
self.location, self.job.get("id"))
 
         options = self.dataflow_default_options
         options.update(self.options)
-        job = self.hook.start_template_dataflow(
+
+        self.job = self.hook.start_template_dataflow(
             job_name=self.job_name,
             variables=options,
             parameters=self.parameters,
@@ -681,13 +701,48 @@ class DataflowTemplatedJobStartOperator(BaseOperator):
             environment=self.environment,
             append_job_name=self.append_job_name,
         )
+        job_id = self.job.get("id")
 
-        return job
+        if job_id is None:
+            raise AirflowException(
+                "While reading job object after template execution error 
occurred. Job object has no id."
+            )
+
+        if not self.deferrable:
+            return job_id
+
+        context["ti"].xcom_push(key="job_id", value=job_id)
+
+        self.defer(
+            trigger=TemplateJobStartTrigger(
+                project_id=self.project_id,
+                job_id=job_id,
+                location=self.location,
+                gcp_conn_id=self.gcp_conn_id,
+                delegate_to=self.delegate_to,
+                poll_sleep=self.poll_sleep,
+                impersonation_chain=self.impersonation_chain,
+                cancel_timeout=self.cancel_timeout,
+            ),
+            method_name="execute_complete",
+        )
+
+    def execute_complete(self, context: Context, event: dict[str, Any]):
+        """Method which executes after trigger finishes its work."""
+        if event["status"] == "error" or event["status"] == "stopped":
+            self.log.info("status: %s, msg: %s", event["status"], 
event["message"])
+            raise AirflowException(event["message"])
+
+        job_id = event["job_id"]
+        self.log.info("Task %s completed with response %s", self.task_id, 
event["message"])
+        return job_id
 
     def on_kill(self) -> None:
         self.log.info("On kill.")
-        if self.job:
+        if self.job is not None:
+            self.log.info("Cancelling job %s", self.job_name)
             self.hook.cancel_job(
+                job_name=self.job_name,
                 job_id=self.job.get("id"),
                 project_id=self.job.get("projectId"),
                 location=self.job.get("location"),
@@ -706,7 +761,6 @@ class DataflowStartFlexTemplateOperator(BaseOperator):
         
https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.locations.flexTemplates/launch#request-body
     :param location: The location of the Dataflow job (for example 
europe-west1)
     :param project_id: The ID of the GCP project that owns the job.
-        If set to ``None`` or missing, the default project_id from the GCP 
connection is used.
     :param gcp_conn_id: The connection ID to use connecting to Google Cloud
         Platform.
     :param delegate_to: The account to impersonate, if any.
@@ -758,6 +812,7 @@ class DataflowStartFlexTemplateOperator(BaseOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding 
identity, with first
         account from the list granting this role to the originating account 
(templated).
+    :param deferrable: Run operator in the deferrable mode.
     """
 
     template_fields: Sequence[str] = ("body", "location", "project_id", 
"gcp_conn_id")
@@ -774,6 +829,7 @@ class DataflowStartFlexTemplateOperator(BaseOperator):
         cancel_timeout: int | None = 10 * 60,
         wait_until_finished: bool | None = None,
         impersonation_chain: str | Sequence[str] | None = None,
+        deferrable: bool = False,
         *args,
         **kwargs,
     ) -> None:
@@ -790,12 +846,26 @@ class DataflowStartFlexTemplateOperator(BaseOperator):
         self.drain_pipeline = drain_pipeline
         self.cancel_timeout = cancel_timeout
         self.wait_until_finished = wait_until_finished
-        self.job = None
-        self.hook: DataflowHook | None = None
+        self.job: dict | None = None
         self.impersonation_chain = impersonation_chain
+        self.deferrable = deferrable
 
-    def execute(self, context: Context):
-        self.hook = DataflowHook(
+        self._validate_deferrable_params()
+
+    def _validate_deferrable_params(self):
+        if self.deferrable and self.wait_until_finished:
+            raise ValueError(
+                "Conflict between deferrable and wait_until_finished 
parameters "
+                "because it makes operator as blocking when it requires to be 
deferred. "
+                "It should be True as deferrable parameter or True as 
wait_until_finished."
+            )
+
+        if self.deferrable and self.wait_until_finished is None:
+            self.wait_until_finished = False
+
+    @cached_property
+    def hook(self) -> DataflowHook:
+        hook = DataflowHook(
             gcp_conn_id=self.gcp_conn_id,
             delegate_to=self.delegate_to,
             drain_pipeline=self.drain_pipeline,
@@ -803,23 +873,66 @@ class DataflowStartFlexTemplateOperator(BaseOperator):
             wait_until_finished=self.wait_until_finished,
             impersonation_chain=self.impersonation_chain,
         )
+        return hook
+
+    def execute(self, context: Context):
+        self._append_uuid_to_job_name()
 
         def set_current_job(current_job):
             self.job = current_job
             DataflowJobLink.persist(self, context, self.project_id, 
self.location, self.job.get("id"))
 
-        job = self.hook.start_flex_template(
+        self.job = self.hook.start_flex_template(
             body=self.body,
             location=self.location,
             project_id=self.project_id,
             on_new_job_callback=set_current_job,
         )
 
+        job_id = self.job.get("id")
+        if job_id is None:
+            raise AirflowException(
+                "While reading job object after template execution error 
occurred. Job object has no id."
+            )
+
+        if not self.deferrable:
+            return self.job
+
+        self.defer(
+            trigger=TemplateJobStartTrigger(
+                project_id=self.project_id,
+                job_id=job_id,
+                location=self.location,
+                gcp_conn_id=self.gcp_conn_id,
+                delegate_to=self.delegate_to,
+                impersonation_chain=self.impersonation_chain,
+                cancel_timeout=self.cancel_timeout,
+            ),
+            method_name="execute_complete",
+        )
+
+    def _append_uuid_to_job_name(self):
+        job_body = self.body.get("launch_parameter") or 
self.body.get("launchParameter")
+        job_name = job_body.get("jobName")
+        if job_name:
+            job_name += f"-{str(uuid.uuid4())[:8]}"
+            job_body["jobName"] = job_name
+            self.log.info("Job name was changed to %s", job_name)
+
+    def execute_complete(self, context: Context, event: dict):
+        """Method which executes after trigger finishes its work."""
+        if event["status"] == "error" or event["status"] == "stopped":
+            self.log.info("status: %s, msg: %s", event["status"], 
event["message"])
+            raise AirflowException(event["message"])
+
+        job_id = event["job_id"]
+        self.log.info("Task %s completed with response %s", job_id, 
event["message"])
+        job = self.hook.get_job(job_id=job_id, project_id=self.project_id, 
location=self.location)
         return job
 
     def on_kill(self) -> None:
         self.log.info("On kill.")
-        if self.job:
+        if self.job is not None:
             self.hook.cancel_job(
                 job_id=self.job.get("id"),
                 project_id=self.job.get("projectId"),
diff --git a/airflow/providers/google/cloud/triggers/dataflow.py 
b/airflow/providers/google/cloud/triggers/dataflow.py
new file mode 100644
index 0000000000..5167323789
--- /dev/null
+++ b/airflow/providers/google/cloud/triggers/dataflow.py
@@ -0,0 +1,150 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import asyncio
+from typing import Any, Sequence
+
+from google.cloud.dataflow_v1beta3 import JobState
+
+from airflow.providers.google.cloud.hooks.dataflow import AsyncDataflowHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+DEFAULT_DATAFLOW_LOCATION = "us-central1"
+
+
class TemplateJobStartTrigger(BaseTrigger):
    """Trigger that waits for a templated Dataflow job to reach a terminal state.

    :param project_id: Required. The Google Cloud project ID in which the job was started.
    :param job_id: Required. ID of the job to monitor.
    :param location: Optional. The location where the job is executed. If set to None then
        the value of DEFAULT_DATAFLOW_LOCATION will be used.
    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
    :param delegate_to: The account to impersonate using domain-wide delegation of authority,
        if any. For this to work, the service account making the request must have
        domain-wide delegation enabled.
    :param impersonation_chain: Optional. Service account to impersonate using short-term
        credentials, or chained list of accounts required to get the access_token
        of the last account in the list, which will be impersonated in the request.
        If set as a string, the account must grant the originating account
        the Service Account Token Creator IAM role.
        If set as a sequence, the identities from the list must grant
        Service Account Token Creator IAM role to the directly preceding identity, with first
        account from the list granting this role to the originating account (templated).
    :param cancel_timeout: Optional. How long (in seconds) operator should wait for the
        pipeline to be successfully cancelled when task is being killed.
    """

    def __init__(
        self,
        job_id: str,
        project_id: str | None,
        location: str = DEFAULT_DATAFLOW_LOCATION,
        gcp_conn_id: str = "google_cloud_default",
        delegate_to: str | None = None,
        poll_sleep: int = 10,
        impersonation_chain: str | Sequence[str] | None = None,
        cancel_timeout: int | None = 5 * 60,
    ):
        super().__init__()
        self.project_id = project_id
        self.job_id = job_id
        self.location = location
        self.gcp_conn_id = gcp_conn_id
        self.delegate_to = delegate_to
        self.poll_sleep = poll_sleep
        self.impersonation_chain = impersonation_chain
        self.cancel_timeout = cancel_timeout

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serialize the trigger's classpath and constructor arguments."""
        classpath = "airflow.providers.google.cloud.triggers.dataflow.TemplateJobStartTrigger"
        kwargs = {
            "project_id": self.project_id,
            "job_id": self.job_id,
            "location": self.location,
            "gcp_conn_id": self.gcp_conn_id,
            "delegate_to": self.delegate_to,
            "poll_sleep": self.poll_sleep,
            "impersonation_chain": self.impersonation_chain,
            "cancel_timeout": self.cancel_timeout,
        }
        return classpath, kwargs

    async def run(self):
        """Poll the job status until the job reaches a terminal state.

        Yields a single TriggerEvent with status "success" when the job is done,
        "error" when it failed, or "stopped" when it was stopped. For any other
        state the trigger sleeps for self.poll_sleep seconds and polls again.
        Any exception raised while polling is converted into an "error" event.
        """
        hook = self._get_async_hook()
        while True:
            try:
                state = await hook.get_job_status(
                    project_id=self.project_id,
                    job_id=self.job_id,
                    location=self.location,
                )
            except Exception as e:
                self.log.exception("Exception occurred while checking for job completion.")
                yield TriggerEvent({"status": "error", "message": str(e)})
                return

            if state == JobState.JOB_STATE_DONE:
                yield TriggerEvent(
                    {
                        "job_id": self.job_id,
                        "status": "success",
                        "message": "Job completed",
                    }
                )
                return
            if state == JobState.JOB_STATE_FAILED:
                yield TriggerEvent(
                    {
                        "status": "error",
                        "message": f"Dataflow job with id {self.job_id} has failed its execution",
                    }
                )
                return
            if state == JobState.JOB_STATE_STOPPED:
                yield TriggerEvent(
                    {
                        "status": "stopped",
                        "message": f"Dataflow job with id {self.job_id} was stopped",
                    }
                )
                return

            self.log.info("Job is still running...")
            self.log.info("Current job status is: %s", state)
            self.log.info("Sleeping for %s seconds.", self.poll_sleep)
            await asyncio.sleep(self.poll_sleep)

    def _get_async_hook(self) -> AsyncDataflowHook:
        """Build the async hook used for status polling, forwarding all connection settings."""
        return AsyncDataflowHook(
            gcp_conn_id=self.gcp_conn_id,
            delegate_to=self.delegate_to,
            poll_sleep=self.poll_sleep,
            impersonation_chain=self.impersonation_chain,
            cancel_timeout=self.cancel_timeout,
        )
diff --git a/airflow/providers/google/provider.yaml 
b/airflow/providers/google/provider.yaml
index 3269c4f597..faf5bcaadc 100644
--- a/airflow/providers/google/provider.yaml
+++ b/airflow/providers/google/provider.yaml
@@ -85,6 +85,9 @@ dependencies:
   - google-cloud-build>=3.0.0
   - google-cloud-compute>=0.1.0,<2.0.0
   - google-cloud-container>=2.2.0,<3.0.0
+  # google-cloud-dataflow-client of version 0.5.5 requires higher versions of
+  # protobuf and proto-plus libraries which can break other dependencies in 
the current package.
+  - google-cloud-dataflow-client>=0.5.2,<0.5.5
   - google-cloud-dataform>=0.2.0
   - google-cloud-datacatalog>=3.0.0
   - google-cloud-dataplex>=0.1.0
diff --git a/generated/provider_dependencies.json 
b/generated/provider_dependencies.json
index d23e804886..dd67a7c61e 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -339,6 +339,7 @@
       "google-cloud-compute>=0.1.0,<2.0.0",
       "google-cloud-container>=2.2.0,<3.0.0",
       "google-cloud-datacatalog>=3.0.0",
+      "google-cloud-dataflow-client>=0.5.2,<0.5.5",
       "google-cloud-dataform>=0.2.0",
       "google-cloud-dataplex>=0.1.0",
       "google-cloud-dataproc-metastore>=1.2.0,<2.0.0",
diff --git a/tests/providers/google/cloud/hooks/test_dataflow.py 
b/tests/providers/google/cloud/hooks/test_dataflow.py
index df9991fe57..1e48eb6a56 100644
--- a/tests/providers/google/cloud/hooks/test_dataflow.py
+++ b/tests/providers/google/cloud/hooks/test_dataflow.py
@@ -20,17 +20,20 @@ from __future__ import annotations
 import copy
 import re
 import shlex
+import sys
+from asyncio import Future
 from typing import Any
-from unittest import mock
 from unittest.mock import MagicMock
 from uuid import UUID
 
 import pytest
+from google.cloud.dataflow_v1beta3 import GetJobRequest, JobView
 
 from airflow.exceptions import AirflowException
 from airflow.providers.apache.beam.hooks.beam import BeamCommandRunner, 
BeamHook
 from airflow.providers.google.cloud.hooks.dataflow import (
     DEFAULT_DATAFLOW_LOCATION,
+    AsyncDataflowHook,
     DataflowHook,
     DataflowJobStatus,
     DataflowJobType,
@@ -39,6 +42,12 @@ from airflow.providers.google.cloud.hooks.dataflow import (
     process_line_and_extract_dataflow_job_id_callback,
 )
 
+if sys.version_info < (3, 8):
+    from asynctest import mock
+else:
+    from unittest import mock
+
+
 DEFAULT_RUNNER = "DirectRunner"
 BEAM_STRING = "airflow.providers.apache.beam.hooks.beam.{}"
 
@@ -52,6 +61,7 @@ PARAMETERS = {
     "inputFile": "gs://dataflow-samples/shakespeare/kinglear.txt",
     "output": "gs://test/output/my_output",
 }
+TEST_ENVIRONMENT = {}
 PY_FILE = "apache_beam.examples.wordcount"
 JAR_FILE = "unitest.jar"
 JOB_CLASS = "com.example.UnitTest"
@@ -1882,3 +1892,46 @@ class TestDataflow:
         mock_logging.info.assert_called_once_with("Running command: %s", "test 
cmd")
         with pytest.raises(Exception):
             dataflow.wait_for_done()
+
+
@pytest.fixture()
def make_mock_awaitable():
    """Return a helper that makes a mock's return value awaitable.

    The helper wraps *return_value* in an already-resolved Future so that
    ``await mock_obj(...)`` yields it immediately.
    """

    def _set_awaitable_result(mock_obj, return_value):
        future = Future()
        future.set_result(return_value)
        mock_obj.return_value = future

    return _set_awaitable_result
+
+
class TestAsyncHook:
    """Tests for the async Dataflow hook."""

    @pytest.fixture
    def hook(self):
        return AsyncDataflowHook(
            gcp_conn_id=TEST_PROJECT_ID,
        )

    @pytest.mark.asyncio
    @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.initialize_client")
    async def test_get_job(self, initialize_client_mock, hook, make_mock_awaitable):
        """get_job must initialize the client once and send a summary-view GetJobRequest."""
        client = initialize_client_mock.return_value
        make_mock_awaitable(client.get_job, None)

        await hook.get_job(
            project_id=TEST_PROJECT_ID,
            job_id=TEST_JOB_ID,
            location=TEST_LOCATION,
        )

        expected_request = GetJobRequest(
            dict(
                project_id=TEST_PROJECT_ID,
                job_id=TEST_JOB_ID,
                location=TEST_LOCATION,
                view=JobView.JOB_VIEW_SUMMARY,
            )
        )
        initialize_client_mock.assert_called_once()
        client.get_job.assert_called_once_with(request=expected_request)
diff --git a/tests/providers/google/cloud/operators/test_dataflow.py 
b/tests/providers/google/cloud/operators/test_dataflow.py
index 40e37c9487..c9f5364f66 100644
--- a/tests/providers/google/cloud/operators/test_dataflow.py
+++ b/tests/providers/google/cloud/operators/test_dataflow.py
@@ -22,6 +22,8 @@ import unittest
 from copy import deepcopy
 from unittest import mock
 
+import pytest as pytest
+
 import airflow
 from airflow.providers.google.cloud.operators.dataflow import (
     CheckJobRunning,
@@ -93,6 +95,10 @@ FROM
 GROUP BY sales_region;
 """
 TEST_SQL_JOB = {"id": "test-job-id"}
+GCP_CONN_ID = "test_gcp_conn_id"
+DELEGATE_TO = "delegating_to_something"
+IMPERSONATION_CHAIN = ["impersonate", "this"]
+CANCEL_TIMEOUT = 10 * 420
 
 
 class TestDataflowPythonOperator(unittest.TestCase):
@@ -445,9 +451,11 @@ class TestDataflowJavaOperatorWithLocal(unittest.TestCase):
         )
 
 
-class TestDataflowTemplateOperator(unittest.TestCase):
-    def setUp(self):
-        self.dataflow = DataflowTemplatedJobStartOperator(
+class TestDataflowTemplateOperator:
+    @pytest.fixture
+    def sync_operator(self):
+        return DataflowTemplatedJobStartOperator(
+            project_id=TEST_PROJECT,
             task_id=TASK_ID,
             template=TEMPLATE,
             job_name=JOB_NAME,
@@ -459,14 +467,30 @@ class TestDataflowTemplateOperator(unittest.TestCase):
             environment={"maxWorkers": 2},
         )
 
-    
@mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook")
-    def test_exec(self, dataflow_mock):
-        """Test DataflowHook is created and the right args are passed to
-        start_template_workflow.
+    @pytest.fixture
+    def deferrable_operator(self):
+        return DataflowTemplatedJobStartOperator(
+            project_id=TEST_PROJECT,
+            task_id=TASK_ID,
+            template=TEMPLATE,
+            job_name=JOB_NAME,
+            parameters=PARAMETERS,
+            options=DEFAULT_OPTIONS_TEMPLATE,
+            dataflow_default_options={"EXTRA_OPTION": "TEST_A"},
+            poll_sleep=POLL_SLEEP,
+            location=TEST_LOCATION,
+            environment={"maxWorkers": 2},
+            deferrable=True,
+            gcp_conn_id=GCP_CONN_ID,
+            delegate_to=DELEGATE_TO,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            cancel_timeout=CANCEL_TIMEOUT,
+        )
 
-        """
+    
@mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook")
+    def test_exec(self, dataflow_mock, sync_operator):
         start_template_hook = 
dataflow_mock.return_value.start_template_dataflow
-        self.dataflow.execute(None)
+        sync_operator.execute(None)
         assert dataflow_mock.called
         expected_options = {
             "project": "test",
@@ -481,24 +505,66 @@ class TestDataflowTemplateOperator(unittest.TestCase):
             parameters=PARAMETERS,
             dataflow_template=TEMPLATE,
             on_new_job_callback=mock.ANY,
-            project_id=None,
+            project_id=TEST_PROJECT,
             location=TEST_LOCATION,
             environment={"maxWorkers": 2},
             append_job_name=True,
         )
 
+    
@mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowTemplatedJobStartOperator.defer")
+    
@mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowTemplatedJobStartOperator.hook")
+    def test_execute_with_deferrable_mode(self, mock_hook, mock_defer_method, 
deferrable_operator):
+        deferrable_operator.execute(mock.MagicMock())
+        mock_defer_method.assert_called_once()
+
+    def test_validation_deferrable_params_raises_error(self):
+        init_kwargs = {
+            "project_id": TEST_PROJECT,
+            "task_id": TASK_ID,
+            "template": TEMPLATE,
+            "job_name": JOB_NAME,
+            "parameters": PARAMETERS,
+            "options": DEFAULT_OPTIONS_TEMPLATE,
+            "dataflow_default_options": {"EXTRA_OPTION": "TEST_A"},
+            "poll_sleep": POLL_SLEEP,
+            "location": TEST_LOCATION,
+            "environment": {"maxWorkers": 2},
+            "wait_until_finished": True,
+            "deferrable": True,
+            "gcp_conn_id": GCP_CONN_ID,
+            "delegate_to": DELEGATE_TO,
+            "impersonation_chain": IMPERSONATION_CHAIN,
+            "cancel_timeout": CANCEL_TIMEOUT,
+        }
+        with pytest.raises(ValueError):
+            DataflowTemplatedJobStartOperator(**init_kwargs)
 
-class TestDataflowStartFlexTemplateOperator(unittest.TestCase):
-    
@mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook")
-    def test_execute(self, mock_dataflow):
-        start_flex_template = DataflowStartFlexTemplateOperator(
+
+class TestDataflowStartFlexTemplateOperator:
+    @pytest.fixture
+    def sync_operator(self):
+        return DataflowStartFlexTemplateOperator(
             task_id="start_flex_template_streaming_beam_sql",
             body={"launchParameter": TEST_FLEX_PARAMETERS},
             do_xcom_push=True,
             project_id=TEST_PROJECT,
             location=TEST_LOCATION,
         )
-        start_flex_template.execute(mock.MagicMock())
+
+    @pytest.fixture
+    def deferrable_operator(self):
+        return DataflowStartFlexTemplateOperator(
+            task_id="start_flex_template_streaming_beam_sql",
+            body={"launchParameter": TEST_FLEX_PARAMETERS},
+            do_xcom_push=True,
+            project_id=TEST_PROJECT,
+            location=TEST_LOCATION,
+            deferrable=True,
+        )
+
+    
@mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook")
+    def test_execute(self, mock_dataflow, sync_operator):
+        sync_operator.execute(mock.MagicMock())
         mock_dataflow.assert_called_once_with(
             gcp_conn_id="google_cloud_default",
             delegate_to=None,
@@ -514,20 +580,39 @@ class 
TestDataflowStartFlexTemplateOperator(unittest.TestCase):
             on_new_job_callback=mock.ANY,
         )
 
-    def test_on_kill(self):
-        start_flex_template = DataflowStartFlexTemplateOperator(
-            task_id="start_flex_template_streaming_beam_sql",
+    def test_on_kill(self, sync_operator):
+        sync_operator.hook = mock.MagicMock()
+        sync_operator.job = {"id": JOB_ID, "projectId": TEST_PROJECT, 
"location": TEST_LOCATION}
+        sync_operator.on_kill()
+        sync_operator.hook.cancel_job.assert_called_once_with(
+            job_id="test-dataflow-pipeline-id", project_id=TEST_PROJECT, 
location=TEST_LOCATION
+        )
+
+    def test_validation_deferrable_params_raises_error(self):
+        init_kwargs = {
+            "task_id": "start_flex_template_streaming_beam_sql",
+            "body": {"launchParameter": TEST_FLEX_PARAMETERS},
+            "do_xcom_push": True,
+            "location": TEST_LOCATION,
+            "project_id": TEST_PROJECT,
+            "wait_until_finished": True,
+            "deferrable": True,
+        }
+        with pytest.raises(ValueError):
+            DataflowStartFlexTemplateOperator(**init_kwargs)
+
+    
@mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowStartFlexTemplateOperator.defer")
+    
@mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook")
+    def test_execute_with_deferrable_mode(self, mock_hook, mock_defer_method, 
deferrable_operator):
+        deferrable_operator.execute(mock.MagicMock())
+
+        mock_hook.return_value.start_flex_template.assert_called_once_with(
             body={"launchParameter": TEST_FLEX_PARAMETERS},
-            do_xcom_push=True,
             location=TEST_LOCATION,
             project_id=TEST_PROJECT,
+            on_new_job_callback=mock.ANY,
         )
-        start_flex_template.hook = mock.MagicMock()
-        start_flex_template.job = {"id": JOB_ID, "projectId": TEST_PROJECT, 
"location": TEST_LOCATION}
-        start_flex_template.on_kill()
-        start_flex_template.hook.cancel_job.assert_called_once_with(
-            job_id="test-dataflow-pipeline-id", project_id=TEST_PROJECT, 
location=TEST_LOCATION
-        )
+        mock_defer_method.assert_called_once()
 
 
 class TestDataflowSqlOperator(unittest.TestCase):
diff --git a/tests/providers/google/cloud/triggers/test_dataflow.py 
b/tests/providers/google/cloud/triggers/test_dataflow.py
new file mode 100644
index 0000000000..2f055ecfe0
--- /dev/null
+++ b/tests/providers/google/cloud/triggers/test_dataflow.py
@@ -0,0 +1,168 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import asyncio
+import logging
+import sys
+from asyncio import Future
+
+import pytest
+from google.cloud.dataflow_v1beta3 import JobState
+
+from airflow.providers.google.cloud.triggers.dataflow import 
TemplateJobStartTrigger
+from airflow.triggers.base import TriggerEvent
+
+if sys.version_info < (3, 8):
+    from asynctest import mock
+else:
+    from unittest import mock
+
+PROJECT_ID = "test-project-id"
+JOB_ID = "test_job_id_2012-12-23-10:00"
+LOCATION = "us-central1"
+GCP_CONN_ID = "test_gcp_conn_id"
+DELEGATE_TO = "delegating_to_something"
+POLL_SLEEP = 20
+IMPERSONATION_CHAIN = ["impersonate", "this"]
+CANCEL_TIMEOUT = 10 * 420
+
+
@pytest.fixture
def trigger():
    """Provide a fully-parameterized TemplateJobStartTrigger instance."""
    trigger_kwargs = {
        "project_id": PROJECT_ID,
        "job_id": JOB_ID,
        "location": LOCATION,
        "gcp_conn_id": GCP_CONN_ID,
        "delegate_to": DELEGATE_TO,
        "poll_sleep": POLL_SLEEP,
        "impersonation_chain": IMPERSONATION_CHAIN,
        "cancel_timeout": CANCEL_TIMEOUT,
    }
    return TemplateJobStartTrigger(**trigger_kwargs)
+
+
@pytest.fixture()
def make_mock_awaitable():
    """Return a helper adapting a mock so its call result can be awaited.

    On Python >= 3.8 the patched attribute is an AsyncMock, so the value is
    assigned directly; on older versions it is wrapped in a resolved Future.
    """

    def _adapt(mock_obj, return_value):
        if sys.version_info >= (3, 8):
            mock_obj.return_value = return_value
        else:
            resolved = Future()
            resolved.set_result(return_value)
            mock_obj.return_value = resolved
        return mock_obj

    return _adapt
+
+
def test_serialize(trigger):
    """serialize() must return the trigger classpath plus every constructor kwarg."""
    expected_classpath = "airflow.providers.google.cloud.triggers.dataflow.TemplateJobStartTrigger"
    expected_kwargs = {
        "project_id": PROJECT_ID,
        "job_id": JOB_ID,
        "location": LOCATION,
        "gcp_conn_id": GCP_CONN_ID,
        "delegate_to": DELEGATE_TO,
        "poll_sleep": POLL_SLEEP,
        "impersonation_chain": IMPERSONATION_CHAIN,
        "cancel_timeout": CANCEL_TIMEOUT,
    }
    actual_classpath, actual_kwargs = trigger.serialize()
    assert actual_classpath == expected_classpath
    assert actual_kwargs == expected_kwargs
+
+
@pytest.mark.parametrize(
    "attr, expected",
    [
        ("gcp_conn_id", GCP_CONN_ID),
        ("delegate_to", DELEGATE_TO),
        ("poll_sleep", POLL_SLEEP),
        ("impersonation_chain", IMPERSONATION_CHAIN),
        ("cancel_timeout", CANCEL_TIMEOUT),
    ],
)
def test_get_async_hook(trigger, attr, expected):
    """Each connection setting must be forwarded from the trigger to the async hook."""
    forwarded = trigger._get_async_hook()._hook_kwargs.get(attr)
    assert forwarded is not None
    assert forwarded == expected
+
+
@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job_status")
async def test_run_loop_return_success_event(mock_job_status, trigger, make_mock_awaitable):
    """A DONE job must produce a single success TriggerEvent carrying the job id."""
    make_mock_awaitable(mock_job_status, JobState.JOB_STATE_DONE)

    actual_event = await trigger.run().asend(None)

    assert actual_event == TriggerEvent(
        {
            "job_id": JOB_ID,
            "status": "success",
            "message": "Job completed",
        }
    )
+
+
@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job_status")
async def test_run_loop_return_failed_event(mock_job_status, trigger, make_mock_awaitable):
    """A FAILED job must produce a single error TriggerEvent."""
    make_mock_awaitable(mock_job_status, JobState.JOB_STATE_FAILED)

    actual_event = await trigger.run().asend(None)

    assert actual_event == TriggerEvent(
        {
            "status": "error",
            "message": f"Dataflow job with id {JOB_ID} has failed its execution",
        }
    )
+
+
@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job_status")
async def test_run_loop_return_stopped_event(mock_job_status, trigger, make_mock_awaitable):
    """A STOPPED job must produce a single stopped TriggerEvent."""
    make_mock_awaitable(mock_job_status, JobState.JOB_STATE_STOPPED)

    actual_event = await trigger.run().asend(None)

    assert actual_event == TriggerEvent(
        {
            "status": "stopped",
            "message": f"Dataflow job with id {JOB_ID} was stopped",
        }
    )
+
+
@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job_status")
async def test_run_loop_is_still_running(mock_job_status, trigger, caplog, make_mock_awaitable):
    """A RUNNING job must keep the trigger looping and logging, without emitting an event.

    The original version asserted on bare non-empty f-strings, which are always
    truthy and therefore verified nothing; the log lines are now checked against
    caplog.text, and the still-pending task is cancelled to avoid leaking it.
    """
    make_mock_awaitable(mock_job_status, JobState.JOB_STATE_RUNNING)
    caplog.set_level(logging.INFO)

    task = asyncio.create_task(trigger.run().__anext__())
    await asyncio.sleep(0.5)

    # No terminal event yet: the trigger is sleeping between polls.
    assert not task.done()
    assert f"Current job status is: {JobState.JOB_STATE_RUNNING}" in caplog.text
    assert f"Sleeping for {POLL_SLEEP} seconds." in caplog.text
    # Clean up the pending poll task so the event loop shuts down quietly.
    task.cancel()
diff --git 
a/tests/system/providers/google/cloud/dataflow/example_dataflow_template.py 
b/tests/system/providers/google/cloud/dataflow/example_dataflow_template.py
index 9e9bb0b654..1247e1df41 100644
--- a/tests/system/providers/google/cloud/dataflow/example_dataflow_template.py
+++ b/tests/system/providers/google/cloud/dataflow/example_dataflow_template.py
@@ -33,6 +33,7 @@ from airflow.providers.google.cloud.transfers.local_to_gcs 
import LocalFilesyste
 from airflow.utils.trigger_rule import TriggerRule
 
 ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT")
 DAG_ID = "dataflow_template"
 
 BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}"
@@ -71,6 +72,7 @@ with models.DAG(
     # [START howto_operator_start_template_job]
     start_template_job = DataflowTemplatedJobStartOperator(
         task_id="start_template_job",
+        project_id=PROJECT_ID,
         template="gs://dataflow-templates/latest/Word_Count",
         parameters={"inputFile": f"gs://{BUCKET_NAME}/{FILE_NAME}", "output": 
GCS_OUTPUT},
         location=LOCATION,

Reply via email to