This is an automated email from the ASF dual-hosted git repository.

weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 28a240a18f Fix deferrable mode for DataflowTemplatedJobStartOperator and DataflowStartFlexTemplateOperator (#39018)
28a240a18f is described below

commit 28a240a18f7e5958e69732f61d639e1d8f39152f
Author: Eugene <[email protected]>
AuthorDate: Mon Apr 29 02:40:20 2024 +0000

    Fix deferrable mode for DataflowTemplatedJobStartOperator and DataflowStartFlexTemplateOperator (#39018)
---
 airflow/providers/google/cloud/hooks/dataflow.py   | 177 +++++++++++++++++----
 .../providers/google/cloud/operators/dataflow.py   |  86 ++++++----
 .../providers/google/cloud/triggers/dataflow.py    |   2 +-
 .../operators/cloud/dataflow.rst                   |  20 ++-
 .../providers/google/cloud/hooks/test_dataflow.py  |  52 ++++++
 .../google/cloud/operators/test_dataflow.py        |  50 ++++--
 .../cloud/dataflow/example_dataflow_template.py    |  42 ++++-
 7 files changed, 336 insertions(+), 93 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/dataflow.py b/airflow/providers/google/cloud/hooks/dataflow.py
index a9bf802b14..59eee63501 100644
--- a/airflow/providers/google/cloud/hooks/dataflow.py
+++ b/airflow/providers/google/cloud/hooks/dataflow.py
@@ -41,9 +41,13 @@ from google.cloud.dataflow_v1beta3 import (
     MessagesV1Beta3AsyncClient,
     MetricsV1Beta3AsyncClient,
 )
-from google.cloud.dataflow_v1beta3.types import GetJobMetricsRequest, JobMessageImportance, JobMetrics
+from google.cloud.dataflow_v1beta3.types import (
+    GetJobMetricsRequest,
+    JobMessageImportance,
+    JobMetrics,
+)
 from google.cloud.dataflow_v1beta3.types.jobs import ListJobsRequest
-from googleapiclient.discovery import build
+from googleapiclient.discovery import Resource, build
 
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType, beam_options_to_args
@@ -573,7 +577,7 @@ class DataflowHook(GoogleBaseHook):
             impersonation_chain=impersonation_chain,
         )
 
-    def get_conn(self) -> build:
+    def get_conn(self) -> Resource:
         """Return a Google Cloud Dataflow service object."""
         http_authorized = self._authorize()
         return build("dataflow", "v1b3", http=http_authorized, cache_discovery=False)
@@ -653,9 +657,9 @@ class DataflowHook(GoogleBaseHook):
         on_new_job_callback: Callable[[dict], None] | None = None,
         location: str = DEFAULT_DATAFLOW_LOCATION,
         environment: dict | None = None,
-    ) -> dict:
+    ) -> dict[str, str]:
         """
-        Start Dataflow template job.
+        Launch a Dataflow job with a Classic Template and wait for its completion.
 
         :param job_name: The name of the job.
         :param variables: Map of job runtime environment options.
@@ -688,26 +692,14 @@ class DataflowHook(GoogleBaseHook):
             environment=environment,
         )
 
-        service = self.get_conn()
-
-        request = (
-            service.projects()
-            .locations()
-            .templates()
-            .launch(
-                projectId=project_id,
-                location=location,
-                gcsPath=dataflow_template,
-                body={
-                    "jobName": name,
-                    "parameters": parameters,
-                    "environment": environment,
-                },
-            )
+        job: dict[str, str] = self.send_launch_template_request(
+            project_id=project_id,
+            location=location,
+            gcs_path=dataflow_template,
+            job_name=name,
+            parameters=parameters,
+            environment=environment,
         )
-        response = request.execute(num_retries=self.num_retries)
-
-        job = response["job"]
 
         if on_new_job_id_callback:
             warnings.warn(
@@ -715,7 +707,7 @@ class DataflowHook(GoogleBaseHook):
                 AirflowProviderDeprecationWarning,
                 stacklevel=3,
             )
-            on_new_job_id_callback(job.get("id"))
+            on_new_job_id_callback(job["id"])
 
         if on_new_job_callback:
             on_new_job_callback(job)
@@ -734,7 +726,62 @@ class DataflowHook(GoogleBaseHook):
             expected_terminal_state=self.expected_terminal_state,
         )
         jobs_controller.wait_for_done()
-        return response["job"]
+        return job
+
+    @_fallback_to_location_from_variables
+    @_fallback_to_project_id_from_variables
+    @GoogleBaseHook.fallback_to_default_project_id
+    def launch_job_with_template(
+        self,
+        *,
+        job_name: str,
+        variables: dict,
+        parameters: dict,
+        dataflow_template: str,
+        project_id: str,
+        append_job_name: bool = True,
+        location: str = DEFAULT_DATAFLOW_LOCATION,
+        environment: dict | None = None,
+    ) -> dict[str, str]:
+        """
+        Launch a Dataflow job with a Classic Template and exit without waiting for its completion.
+
+        :param job_name: The name of the job.
+        :param variables: Map of job runtime environment options.
+            It will update environment argument if passed.
+
+            .. seealso::
+                For more information on possible configurations, look at the API documentation
+                `https://cloud.google.com/dataflow/pipelines/specifying-exec-params
+                <https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment>`__
+
+        :param parameters: Parameters for the template
+        :param dataflow_template: GCS path to the template.
+        :param project_id: Optional, the Google Cloud project ID in which to start a job.
+            If set to None or missing, the default project_id from the Google Cloud connection is used.
+        :param append_job_name: True if unique suffix has to be appended to job name.
+        :param location: Job location.
+
+            .. seealso::
+                For more information on possible configurations, look at the API documentation
+                `https://cloud.google.com/dataflow/pipelines/specifying-exec-params
+                <https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment>`__
+        :return: the Dataflow job response
+        """
+        name = self.build_dataflow_job_name(job_name, append_job_name)
+        environment = self._update_environment(
+            variables=variables,
+            environment=environment,
+        )
+        job: dict[str, str] = self.send_launch_template_request(
+            project_id=project_id,
+            location=location,
+            gcs_path=dataflow_template,
+            job_name=name,
+            parameters=parameters,
+            environment=environment,
+        )
+        return job
 
     def _update_environment(self, variables: dict, environment: dict | None = None) -> dict:
         environment = environment or {}
@@ -770,6 +817,35 @@ class DataflowHook(GoogleBaseHook):
 
         return environment
 
+    def send_launch_template_request(
+        self,
+        *,
+        project_id: str,
+        location: str,
+        gcs_path: str,
+        job_name: str,
+        parameters: dict,
+        environment: dict,
+    ) -> dict[str, str]:
+        service: Resource = self.get_conn()
+        request = (
+            service.projects()
+            .locations()
+            .templates()
+            .launch(
+                projectId=project_id,
+                location=location,
+                gcsPath=gcs_path,
+                body={
+                    "jobName": job_name,
+                    "parameters": parameters,
+                    "environment": environment,
+                },
+            )
+        )
+        response: dict = request.execute(num_retries=self.num_retries)
+        return response["job"]
+
     @GoogleBaseHook.fallback_to_default_project_id
     def start_flex_template(
         self,
@@ -778,9 +854,9 @@ class DataflowHook(GoogleBaseHook):
         project_id: str,
         on_new_job_id_callback: Callable[[str], None] | None = None,
         on_new_job_callback: Callable[[dict], None] | None = None,
-    ) -> dict:
+    ) -> dict[str, str]:
         """
-        Start flex templates with the Dataflow pipeline.
+        Launch a Dataflow job with a Flex Template and wait for its completion.
 
         :param body: The request body. See:
             https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.locations.flexTemplates/launch#request-body
@@ -791,15 +867,16 @@ class DataflowHook(GoogleBaseHook):
         :param on_new_job_callback: A callback that is called when a Job is detected.
         :return: the Job
         """
-        service = self.get_conn()
+        service: Resource = self.get_conn()
         request = (
             service.projects()
             .locations()
             .flexTemplates()
             .launch(projectId=project_id, body=body, location=location)
         )
-        response = request.execute(num_retries=self.num_retries)
+        response: dict = request.execute(num_retries=self.num_retries)
         job = response["job"]
+        job_id: str = job["id"]
 
         if on_new_job_id_callback:
             warnings.warn(
@@ -807,7 +884,7 @@ class DataflowHook(GoogleBaseHook):
                 AirflowProviderDeprecationWarning,
                 stacklevel=3,
             )
-            on_new_job_id_callback(job.get("id"))
+            on_new_job_id_callback(job_id)
 
         if on_new_job_callback:
             on_new_job_callback(job)
@@ -815,7 +892,7 @@ class DataflowHook(GoogleBaseHook):
         jobs_controller = _DataflowJobsController(
             dataflow=self.get_conn(),
             project_number=project_id,
-            job_id=job.get("id"),
+            job_id=job_id,
             location=location,
             poll_sleep=self.poll_sleep,
             num_retries=self.num_retries,
@@ -826,6 +903,42 @@ class DataflowHook(GoogleBaseHook):
 
         return jobs_controller.get_jobs(refresh=True)[0]
 
+    @GoogleBaseHook.fallback_to_default_project_id
+    def launch_job_with_flex_template(
+        self,
+        body: dict,
+        location: str,
+        project_id: str,
+    ) -> dict[str, str]:
+        """
+        Launch a Dataflow Job with a Flex Template and exit without waiting for the job completion.
+
+        :param body: The request body. See:
+            https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.locations.flexTemplates/launch#request-body
+        :param location: The location of the Dataflow job (for example europe-west1)
+        :param project_id: The ID of the GCP project that owns the job.
+            If set to ``None`` or missing, the default project_id from the GCP connection is used.
+        :return: a Dataflow job response
+        """
+        service: Resource = self.get_conn()
+        request = (
+            service.projects()
+            .locations()
+            .flexTemplates()
+            .launch(projectId=project_id, body=body, location=location)
+        )
+        response: dict = request.execute(num_retries=self.num_retries)
+        return response["job"]
+
+    @staticmethod
+    def extract_job_id(job: dict) -> str:
+        try:
+            return job["id"]
+        except KeyError:
+            raise AirflowException(
+                "While reading job object after template execution error occurred. Job object has no id."
+            )
+
     @_fallback_to_location_from_variables
     @_fallback_to_project_id_from_variables
     @GoogleBaseHook.fallback_to_default_project_id
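
For orientation, here is a minimal sketch of calling the new non-blocking hook method directly; the bucket, template, project and connection values below are placeholders, not part of this commit:

    from airflow.providers.google.cloud.hooks.dataflow import DataflowHook

    hook = DataflowHook(gcp_conn_id="google_cloud_default")
    # launch_job_with_template() submits the Classic Template job and returns the
    # job dict right away; unlike start_template_dataflow() it does not wait for completion.
    job = hook.launch_job_with_template(
        job_name="wordcount-example",                        # placeholder name
        variables={"tempLocation": "gs://my-bucket/temp"},   # placeholder bucket
        parameters={"inputFile": "gs://my-bucket/input.txt", "output": "gs://my-bucket/output"},
        dataflow_template="gs://dataflow-templates/latest/Word_Count",
        project_id="my-project",                             # placeholder project
        location="us-central1",
    )
    # extract_job_id() raises AirflowException if the returned job dict has no "id".
    job_id = DataflowHook.extract_job_id(job)

The job id can then be handed to a trigger, or the job can be polled later with get_job().
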
diff --git a/airflow/providers/google/cloud/operators/dataflow.py b/airflow/providers/google/cloud/operators/dataflow.py
index 4a6f197e14..424cb8d805 100644
--- a/airflow/providers/google/cloud/operators/dataflow.py
+++ b/airflow/providers/google/cloud/operators/dataflow.py
@@ -41,6 +41,7 @@ from airflow.providers.google.cloud.hooks.gcs import GCSHook
 from airflow.providers.google.cloud.links.dataflow import DataflowJobLink
 from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
 from airflow.providers.google.cloud.triggers.dataflow import TemplateJobStartTrigger
+from airflow.providers.google.common.consts import GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
 from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
 from airflow.version import version
 
@@ -460,7 +461,7 @@ class DataflowCreateJavaJobOperator(GoogleCloudBaseOperator):
 
 class DataflowTemplatedJobStartOperator(GoogleCloudBaseOperator):
     """
-    Start a Templated Cloud Dataflow job; the parameters of the operation will be passed to the job.
+    Start a Dataflow job with a classic template; the parameters of the operation will be passed to the job.
 
     .. seealso::
         For more information on how to use this operator, take a look at the guide:
@@ -643,7 +644,7 @@ class DataflowTemplatedJobStartOperator(GoogleCloudBaseOperator):
         self.deferrable = deferrable
         self.expected_terminal_state = expected_terminal_state
 
-        self.job: dict | None = None
+        self.job: dict[str, str] | None = None
 
         self._validate_deferrable_params()
 
@@ -681,29 +682,34 @@ class DataflowTemplatedJobStartOperator(GoogleCloudBaseOperator):
         if not self.location:
             self.location = DEFAULT_DATAFLOW_LOCATION
 
-        self.job = self.hook.start_template_dataflow(
+        if not self.deferrable:
+            self.job = self.hook.start_template_dataflow(
+                job_name=self.job_name,
+                variables=options,
+                parameters=self.parameters,
+                dataflow_template=self.template,
+                on_new_job_callback=set_current_job,
+                project_id=self.project_id,
+                location=self.location,
+                environment=self.environment,
+                append_job_name=self.append_job_name,
+            )
+            job_id = self.hook.extract_job_id(self.job)
+            self.xcom_push(context, key="job_id", value=job_id)
+            return job_id
+
+        self.job = self.hook.launch_job_with_template(
             job_name=self.job_name,
             variables=options,
             parameters=self.parameters,
             dataflow_template=self.template,
-            on_new_job_callback=set_current_job,
             project_id=self.project_id,
+            append_job_name=self.append_job_name,
             location=self.location,
             environment=self.environment,
-            append_job_name=self.append_job_name,
         )
-        job_id = self.job.get("id")
-
-        if job_id is None:
-            raise AirflowException(
-                "While reading job object after template execution error occurred. Job object has no id."
-            )
-
-        if not self.deferrable:
-            return job_id
-
-        context["ti"].xcom_push(key="job_id", value=job_id)
-
+        job_id = self.hook.extract_job_id(self.job)
+        DataflowJobLink.persist(self, context, self.project_id, self.location, job_id)
         self.defer(
             trigger=TemplateJobStartTrigger(
                 project_id=self.project_id,
@@ -714,16 +720,17 @@ class DataflowTemplatedJobStartOperator(GoogleCloudBaseOperator):
                 impersonation_chain=self.impersonation_chain,
                 cancel_timeout=self.cancel_timeout,
             ),
-            method_name="execute_complete",
+            method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME,
         )
 
-    def execute_complete(self, context: Context, event: dict[str, Any]):
+    def execute_complete(self, context: Context, event: dict[str, Any]) -> str:
         """Execute after trigger finishes its work."""
         if event["status"] in ("error", "stopped"):
             self.log.info("status: %s, msg: %s", event["status"], event["message"])
             raise AirflowException(event["message"])
 
         job_id = event["job_id"]
+        self.xcom_push(context, key="job_id", value=job_id)
         self.log.info("Task %s completed with response %s", self.task_id, event["message"])
         return job_id
 
@@ -741,7 +748,7 @@ class DataflowTemplatedJobStartOperator(GoogleCloudBaseOperator):
 
 class DataflowStartFlexTemplateOperator(GoogleCloudBaseOperator):
     """
-    Starts flex templates with the Dataflow pipeline.
+    Starts a Dataflow Job with a Flex Template.
 
     .. seealso::
         For more information on how to use this operator, take a look at the guide:
@@ -803,6 +810,9 @@ class DataflowStartFlexTemplateOperator(GoogleCloudBaseOperator):
     :param expected_terminal_state: The expected final status of the operator on which the corresponding
         Airflow task succeeds. When not specified, it will be determined by the hook.
     :param append_job_name: True if unique suffix has to be appended to job name.
+    :param poll_sleep: The time in seconds to sleep between polling Google
+        Cloud Platform for the dataflow job status while the job is in the
+        JOB_STATE_RUNNING state.
     """
 
     template_fields: Sequence[str] = ("body", "location", "project_id", "gcp_conn_id")
@@ -821,6 +831,7 @@ class DataflowStartFlexTemplateOperator(GoogleCloudBaseOperator):
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         append_job_name: bool = True,
         expected_terminal_state: str | None = None,
+        poll_sleep: int = 10,
         *args,
         **kwargs,
     ) -> None:
@@ -832,11 +843,12 @@ class DataflowStartFlexTemplateOperator(GoogleCloudBaseOperator):
         self.drain_pipeline = drain_pipeline
         self.cancel_timeout = cancel_timeout
         self.wait_until_finished = wait_until_finished
-        self.job: dict | None = None
+        self.job: dict[str, str] | None = None
         self.impersonation_chain = impersonation_chain
         self.deferrable = deferrable
         self.expected_terminal_state = expected_terminal_state
         self.append_job_name = append_job_name
+        self.poll_sleep = poll_sleep
 
         self._validate_deferrable_params()
 
@@ -871,32 +883,35 @@ class DataflowStartFlexTemplateOperator(GoogleCloudBaseOperator):
             self.job = current_job
             DataflowJobLink.persist(self, context, self.project_id, self.location, self.job.get("id"))
 
-        self.job = self.hook.start_flex_template(
+        if not self.deferrable:
+            self.job = self.hook.start_flex_template(
+                body=self.body,
+                location=self.location,
+                project_id=self.project_id,
+                on_new_job_callback=set_current_job,
+            )
+            job_id = self.hook.extract_job_id(self.job)
+            self.xcom_push(context, key="job_id", value=job_id)
+            return self.job
+
+        self.job = self.hook.launch_job_with_flex_template(
             body=self.body,
             location=self.location,
             project_id=self.project_id,
-            on_new_job_callback=set_current_job,
         )
-
-        job_id = self.job.get("id")
-        if job_id is None:
-            raise AirflowException(
-                "While reading job object after template execution error occurred. Job object has no id."
-            )
-
-        if not self.deferrable:
-            return self.job
-
+        job_id = self.hook.extract_job_id(self.job)
+        DataflowJobLink.persist(self, context, self.project_id, self.location, job_id)
         self.defer(
             trigger=TemplateJobStartTrigger(
                 project_id=self.project_id,
                 job_id=job_id,
                 location=self.location,
                 gcp_conn_id=self.gcp_conn_id,
+                poll_sleep=self.poll_sleep,
                 impersonation_chain=self.impersonation_chain,
                 cancel_timeout=self.cancel_timeout,
             ),
-            method_name="execute_complete",
+            method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME,
         )
 
     def _append_uuid_to_job_name(self):
@@ -907,7 +922,7 @@ class DataflowStartFlexTemplateOperator(GoogleCloudBaseOperator):
             job_body["jobName"] = job_name
             self.log.info("Job name was changed to %s", job_name)
 
-    def execute_complete(self, context: Context, event: dict):
+    def execute_complete(self, context: Context, event: dict) -> dict[str, str]:
         """Execute after trigger finishes its work."""
         if event["status"] in ("error", "stopped"):
             self.log.info("status: %s, msg: %s", event["status"], event["message"])
@@ -915,6 +930,7 @@ class DataflowStartFlexTemplateOperator(GoogleCloudBaseOperator):
 
         job_id = event["job_id"]
         self.log.info("Task %s completed with response %s", job_id, event["message"])
+        self.xcom_push(context, key="job_id", value=job_id)
         job = self.hook.get_job(job_id=job_id, project_id=self.project_id, location=self.location)
         return job
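
Both operators above now push the Dataflow job id to XCom under the key "job_id" in the deferrable as well as the non-deferrable path. A downstream task could read it roughly like this (the upstream task_id is a placeholder):

    from airflow.decorators import task

    @task
    def report_dataflow_job(ti=None):
        # "start_template_job" is a placeholder task_id of the upstream Dataflow operator.
        job_id = ti.xcom_pull(task_ids="start_template_job", key="job_id")
        print(f"Dataflow job id: {job_id}")
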
 
diff --git a/airflow/providers/google/cloud/triggers/dataflow.py b/airflow/providers/google/cloud/triggers/dataflow.py
index 32f68a9fd7..577c0bbf60 100644
--- a/airflow/providers/google/cloud/triggers/dataflow.py
+++ b/airflow/providers/google/cloud/triggers/dataflow.py
@@ -138,7 +138,7 @@ class TemplateJobStartTrigger(BaseTrigger):
                     return
                 else:
                     self.log.info("Job is still running...")
-                    self.log.info("Current job status is: %s", status)
+                    self.log.info("Current job status is: %s", status.name)
                     self.log.info("Sleeping for %s seconds.", self.poll_sleep)
                     await asyncio.sleep(self.poll_sleep)
         except Exception as e:
diff --git a/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst b/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst
index d3f1bd6df4..f9302af8c3 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst
@@ -208,7 +208,7 @@ from the staging and execution steps. There are two types of templates for Dataf
 See the `official documentation for Dataflow templates
 <https://cloud.google.com/dataflow/docs/concepts/dataflow-templates>`_ for more information.
 
-Here is an example of running Classic template with
+Here is an example of running a Dataflow job using a Classic Template with
 :class:`~airflow.providers.google.cloud.operators.dataflow.DataflowTemplatedJobStartOperator`:
 
 .. exampleinclude:: /../../tests/system/providers/google/cloud/dataflow/example_dataflow_template.py
@@ -217,10 +217,18 @@ Here is an example of running Classic template with
     :start-after: [START howto_operator_start_template_job]
     :end-before: [END howto_operator_start_template_job]
 
+Also for this action you can use the operator in the deferrable mode:
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/dataflow/example_dataflow_template.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_start_template_job_deferrable]
+    :end-before: [END howto_operator_start_template_job_deferrable]
+
 See the `list of Google-provided templates that can be used with this operator
 <https://cloud.google.com/dataflow/docs/guides/templates/provided-templates>`_.
 
-Here is an example of running Flex template with
+Here is an example of running a Dataflow job using a Flex Template with
 :class:`~airflow.providers.google.cloud.operators.dataflow.DataflowStartFlexTemplateOperator`:
 
 .. exampleinclude:: /../../tests/system/providers/google/cloud/dataflow/example_dataflow_template.py
@@ -229,6 +237,14 @@ Here is an example of running Flex template with
     :start-after: [START howto_operator_start_flex_template_job]
     :end-before: [END howto_operator_start_flex_template_job]
 
+Also for this action you can use the operator in the deferrable mode:
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/dataflow/example_dataflow_template.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_start_flex_template_job_deferrable]
+    :end-before: [END howto_operator_start_flex_template_job_deferrable]
+
 .. _howto/operator:DataflowStartSqlJobOperator:
 
 Dataflow SQL
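
As a side note, the commit also adds a poll_sleep option to DataflowStartFlexTemplateOperator that is forwarded to the trigger. A minimal sketch of the deferrable usage, where the project, bucket and template path are placeholders and the body is only one plausible Flex Template launch request:

    from airflow.providers.google.cloud.operators.dataflow import DataflowStartFlexTemplateOperator

    start_flex_template = DataflowStartFlexTemplateOperator(
        task_id="start_flex_template_deferrable",
        project_id="my-project",  # placeholder project
        body={
            "launchParameter": {
                "jobName": "flex-wordcount",
                "containerSpecGcsPath": "gs://my-bucket/templates/wordcount-flex.json",  # placeholder path
                "parameters": {"output": "gs://my-bucket/output"},
            }
        },
        location="europe-west1",
        deferrable=True,
        poll_sleep=30,  # seconds the trigger sleeps between job status polls
    )
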
diff --git a/tests/providers/google/cloud/hooks/test_dataflow.py b/tests/providers/google/cloud/hooks/test_dataflow.py
index 2458b48e81..1c8f768ea3 100644
--- a/tests/providers/google/cloud/hooks/test_dataflow.py
+++ b/tests/providers/google/cloud/hooks/test_dataflow.py
@@ -1052,6 +1052,34 @@ class TestDataflowTemplateHook:
         )
         mock_uuid.assert_called_once_with()
 
+    @mock.patch(DATAFLOW_STRING.format("uuid.uuid4"), return_value=MOCK_UUID)
+    @mock.patch(DATAFLOW_STRING.format("DataflowHook.get_conn"))
+    def test_launch_job_with_template(self, mock_conn, mock_uuid):
+        launch_method = (
+            mock_conn.return_value.projects.return_value.locations.return_value.templates.return_value.launch
+        )
+        launch_method.return_value.execute.return_value = {"job": {"id": TEST_JOB_ID}}
+        variables = {"zone": "us-central1-f", "tempLocation": "gs://test/temp"}
+        result = self.dataflow_hook.launch_job_with_template(
+            job_name=JOB_NAME,
+            variables=copy.deepcopy(variables),
+            parameters=PARAMETERS,
+            dataflow_template=TEST_TEMPLATE,
+            project_id=TEST_PROJECT,
+        )
+
+        launch_method.assert_called_once_with(
+            body={
+                "jobName": f"test-dataflow-pipeline-{MOCK_UUID_PREFIX}",
+                "parameters": PARAMETERS,
+                "environment": variables,
+            },
+            gcsPath="gs://dataflow-templates/wordcount/template_file",
+            projectId=TEST_PROJECT,
+            location=DEFAULT_DATAFLOW_LOCATION,
+        )
+        assert result == {"id": TEST_JOB_ID}
+
     @mock.patch(DATAFLOW_STRING.format("_DataflowJobsController"))
     @mock.patch(DATAFLOW_STRING.format("DataflowHook.get_conn"))
     def test_start_flex_template(self, mock_conn, mock_controller):
@@ -1088,6 +1116,26 @@ class TestDataflowTemplateHook:
         mock_controller.return_value.get_jobs.assert_called_once_with(refresh=True)
         assert result == {"id": TEST_JOB_ID}
 
+    @mock.patch(DATAFLOW_STRING.format("DataflowHook.get_conn"))
+    def test_launch_job_with_flex_template(self, mock_conn):
+        expected_job = {"id": TEST_JOB_ID}
+
+        mock_locations = mock_conn.return_value.projects.return_value.locations
+        launch_method = mock_locations.return_value.flexTemplates.return_value.launch
+        launch_method.return_value.execute.return_value = {"job": expected_job}
+
+        result = self.dataflow_hook.launch_job_with_flex_template(
+            body={"launchParameter": TEST_FLEX_PARAMETERS},
+            location=TEST_LOCATION,
+            project_id=TEST_PROJECT_ID,
+        )
+        launch_method.assert_called_once_with(
+            projectId="test-project-id",
+            body={"launchParameter": TEST_FLEX_PARAMETERS},
+            location=TEST_LOCATION,
+        )
+        assert result == {"id": TEST_JOB_ID}
+
     @mock.patch(DATAFLOW_STRING.format("_DataflowJobsController"))
     @mock.patch(DATAFLOW_STRING.format("DataflowHook.get_conn"))
     def test_cancel_job(self, mock_get_conn, jobs_controller):
@@ -1177,6 +1225,10 @@ class TestDataflowTemplateHook:
                 on_new_job_callback=mock.MagicMock(),
             )
 
+    def test_extract_job_id_raises_exception(self):
+        with pytest.raises(AirflowException):
+            self.dataflow_hook.extract_job_id({"not_id": True})
+
 
 class TestDataflowJob:
     def setup_method(self):
diff --git a/tests/providers/google/cloud/operators/test_dataflow.py b/tests/providers/google/cloud/operators/test_dataflow.py
index 495287b9af..ebbf471383 100644
--- a/tests/providers/google/cloud/operators/test_dataflow.py
+++ b/tests/providers/google/cloud/operators/test_dataflow.py
@@ -102,6 +102,7 @@ TEST_SQL_JOB = {"id": "test-job-id"}
 GCP_CONN_ID = "test_gcp_conn_id"
 IMPERSONATION_CHAIN = ["impersonate", "this"]
 CANCEL_TIMEOUT = 10 * 420
+DATAFLOW_PATH = "airflow.providers.google.cloud.operators.dataflow"
 
 
 class TestDataflowCreatePythonJobOperator:
@@ -488,11 +489,12 @@ class TestDataflowTemplatedJobStartOperator:
             cancel_timeout=CANCEL_TIMEOUT,
         )
 
-    @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook")
-    def test_exec(self, dataflow_mock, sync_operator):
-        start_template_hook = dataflow_mock.return_value.start_template_dataflow
+    @mock.patch(f"{DATAFLOW_PATH}.DataflowTemplatedJobStartOperator.xcom_push")
+    @mock.patch(f"{DATAFLOW_PATH}.DataflowHook")
+    def test_execute(self, hook_mock, mock_xcom_push, sync_operator):
+        start_template_hook = hook_mock.return_value.start_template_dataflow
         sync_operator.execute(None)
-        assert dataflow_mock.called
+        assert hook_mock.called
         expected_options = {
             "project": "test",
             "stagingLocation": "gs://test/staging",
@@ -512,10 +514,27 @@ class TestDataflowTemplatedJobStartOperator:
             append_job_name=True,
         )
 
-    @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowTemplatedJobStartOperator.defer")
-    @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowTemplatedJobStartOperator.hook")
+    @mock.patch(f"{DATAFLOW_PATH}.DataflowTemplatedJobStartOperator.defer")
+    @mock.patch(f"{DATAFLOW_PATH}.DataflowHook")
     def test_execute_with_deferrable_mode(self, mock_hook, mock_defer_method, deferrable_operator):
         deferrable_operator.execute(mock.MagicMock())
+        expected_variables = {
+            "project": "test",
+            "stagingLocation": "gs://test/staging",
+            "tempLocation": "gs://test/temp",
+            "zone": "us-central1-f",
+            "EXTRA_OPTION": "TEST_A",
+        }
+        mock_hook.return_value.launch_job_with_template.assert_called_once_with(
+            job_name=JOB_NAME,
+            variables=expected_variables,
+            parameters=PARAMETERS,
+            dataflow_template=TEMPLATE,
+            project_id=TEST_PROJECT,
+            append_job_name=True,
+            location=TEST_LOCATION,
+            environment={"maxWorkers": 2},
+        )
         mock_defer_method.assert_called_once()
 
     def test_validation_deferrable_params_raises_error(self):
@@ -540,8 +559,9 @@ class TestDataflowTemplatedJobStartOperator:
             DataflowTemplatedJobStartOperator(**init_kwargs)
 
     @pytest.mark.db_test
-    @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook.start_template_dataflow")
-    def test_start_with_custom_region(self, dataflow_mock):
+    @mock.patch(f"{DATAFLOW_PATH}.DataflowTemplatedJobStartOperator.xcom_push")
+    @mock.patch(f"{DATAFLOW_PATH}.DataflowHook.start_template_dataflow")
+    def test_start_with_custom_region(self, dataflow_mock, mock_xcom_push):
         init_kwargs = {
             "task_id": TASK_ID,
             "template": TEMPLATE,
@@ -560,8 +580,9 @@ class TestDataflowTemplatedJobStartOperator:
         assert kwargs["location"] == DEFAULT_DATAFLOW_LOCATION
 
     @pytest.mark.db_test
-    @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook.start_template_dataflow")
-    def test_start_with_location(self, dataflow_mock):
+    @mock.patch(f"{DATAFLOW_PATH}.DataflowTemplatedJobStartOperator.xcom_push")
+    @mock.patch(f"{DATAFLOW_PATH}.DataflowHook.start_template_dataflow")
+    def test_start_with_location(self, dataflow_mock, mock_xcom_push):
         init_kwargs = {
             "task_id": TASK_ID,
             "template": TEMPLATE,
@@ -601,7 +622,7 @@ class TestDataflowStartFlexTemplateOperator:
             deferrable=True,
         )
 
-    @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook")
+    @mock.patch(f"{DATAFLOW_PATH}.DataflowHook")
     def test_execute(self, mock_dataflow, sync_operator):
         sync_operator.execute(mock.MagicMock())
         mock_dataflow.assert_called_once_with(
@@ -640,16 +661,15 @@ class TestDataflowStartFlexTemplateOperator:
         with pytest.raises(ValueError):
             DataflowStartFlexTemplateOperator(**init_kwargs)
 
-    @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowStartFlexTemplateOperator.defer")
-    @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook")
+    @mock.patch(f"{DATAFLOW_PATH}.DataflowStartFlexTemplateOperator.defer")
+    @mock.patch(f"{DATAFLOW_PATH}.DataflowHook")
     def test_execute_with_deferrable_mode(self, mock_hook, mock_defer_method, deferrable_operator):
         deferrable_operator.execute(mock.MagicMock())
 
-        mock_hook.return_value.start_flex_template.assert_called_once_with(
+        mock_hook.return_value.launch_job_with_flex_template.assert_called_once_with(
             body={"launchParameter": TEST_FLEX_PARAMETERS},
             location=TEST_LOCATION,
             project_id=TEST_PROJECT,
-            on_new_job_callback=mock.ANY,
         )
         mock_defer_method.assert_called_once()
 
diff --git a/tests/system/providers/google/cloud/dataflow/example_dataflow_template.py b/tests/system/providers/google/cloud/dataflow/example_dataflow_template.py
index b6eec97a16..2a3e747eb7 100644
--- a/tests/system/providers/google/cloud/dataflow/example_dataflow_template.py
+++ b/tests/system/providers/google/cloud/dataflow/example_dataflow_template.py
@@ -17,7 +17,8 @@
 # under the License.
 
 """
-Example Airflow DAG for testing Google Dataflow
+Example Airflow DAG for testing Google Dataflow.
+
 :class:`~airflow.providers.google.cloud.operators.dataflow.DataflowTemplatedJobStartOperator`
 operator.
 """
 
@@ -27,6 +28,7 @@ import os
 from datetime import datetime
 from pathlib import Path
 
+from airflow.models.baseoperator import chain
 from airflow.models.dag import DAG
 from airflow.providers.google.cloud.operators.dataflow import (
     DataflowStartFlexTemplateOperator,
@@ -104,6 +106,7 @@ with DAG(
         template="gs://dataflow-templates/latest/Word_Count",
         parameters={"inputFile": f"gs://{BUCKET_NAME}/{CSV_FILE_NAME}", "output": GCS_OUTPUT},
         location=LOCATION,
+        wait_until_finished=True,
     )
     # [END howto_operator_start_template_job]
 
@@ -114,20 +117,43 @@ with DAG(
         body=BODY,
         location=LOCATION,
         append_job_name=False,
+        wait_until_finished=True,
     )
     # [END howto_operator_start_flex_template_job]
 
+    # [START howto_operator_start_template_job_deferrable]
+    start_template_job_deferrable = DataflowTemplatedJobStartOperator(
+        task_id="start_template_job_deferrable",
+        project_id=PROJECT_ID,
+        template="gs://dataflow-templates/latest/Word_Count",
+        parameters={"inputFile": f"gs://{BUCKET_NAME}/{CSV_FILE_NAME}", "output": GCS_OUTPUT},
+        location=LOCATION,
+        deferrable=True,
+    )
+    # [END howto_operator_start_template_job_deferrable]
+
+    # [START howto_operator_start_flex_template_job_deferrable]
+    start_flex_template_job_deferrable = DataflowStartFlexTemplateOperator(
+        task_id="start_flex_template_job_deferrable",
+        project_id=PROJECT_ID,
+        body=BODY,
+        location=LOCATION,
+        append_job_name=False,
+        deferrable=True,
+    )
+    # [END howto_operator_start_flex_template_job_deferrable]
+
     delete_bucket = GCSDeleteBucketOperator(
         task_id="delete_bucket", bucket_name=BUCKET_NAME, trigger_rule=TriggerRule.ALL_DONE
     )
 
-    (
-        create_bucket
-        >> upload_file
-        >> upload_schema
-        >> start_template_job
-        >> start_flex_template_job
-        >> delete_bucket
+    chain(
+        create_bucket,
+        upload_file,
+        upload_schema,
+        [start_template_job, start_flex_template_job],
+        [start_template_job_deferrable, start_flex_template_job_deferrable],
+        delete_bucket,
     )
 
     from tests.system.utils.watcher import watcher

