This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 094d6bf01b Add deferrable mode to dataflow operators (#27776)
094d6bf01b is described below
commit 094d6bf01b9d8b1a5d358dc10fd561cf3a04c51b
Author: George <[email protected]>
AuthorDate: Mon Jan 30 18:55:21 2023 +0100
Add deferrable mode to dataflow operators (#27776)
* Add deferrable mode to DataflowTemplatedJobStartOperator and
DataflowStartFlexTemplateOperator operators
* Change project_id param to be optional, add fixes for tests and docs build
* Add comment about upper-bound for google-cloud-dataflow-client lib and
change warning message
---------
Co-authored-by: Heorhi Parkhomenka <[email protected]>
---
.../google/cloud/example_dags/example_dataflow.py | 3 +
.../example_dags/example_dataflow_flex_template.py | 1 +
airflow/providers/google/cloud/hooks/dataflow.py | 166 +++++++++++++++----
airflow/providers/google/cloud/links/dataflow.py | 2 +-
.../providers/google/cloud/operators/dataflow.py | 175 +++++++++++++++++----
.../providers/google/cloud/triggers/dataflow.py | 150 ++++++++++++++++++
airflow/providers/google/provider.yaml | 3 +
generated/provider_dependencies.json | 1 +
.../providers/google/cloud/hooks/test_dataflow.py | 55 ++++++-
.../google/cloud/operators/test_dataflow.py | 135 +++++++++++++---
.../google/cloud/triggers/test_dataflow.py | 168 ++++++++++++++++++++
.../cloud/dataflow/example_dataflow_template.py | 2 +
12 files changed, 769 insertions(+), 92 deletions(-)
diff --git a/airflow/providers/google/cloud/example_dags/example_dataflow.py
b/airflow/providers/google/cloud/example_dags/example_dataflow.py
index f2a0860fc0..2ab4c04f5b 100644
--- a/airflow/providers/google/cloud/example_dags/example_dataflow.py
+++ b/airflow/providers/google/cloud/example_dags/example_dataflow.py
@@ -52,6 +52,7 @@ GCS_STAGING = os.environ.get("GCP_DATAFLOW_GCS_STAGING",
"gs://INVALID BUCKET NA
GCS_OUTPUT = os.environ.get("GCP_DATAFLOW_GCS_OUTPUT", "gs://INVALID BUCKET
NAME/output")
GCS_JAR = os.environ.get("GCP_DATAFLOW_JAR", "gs://INVALID BUCKET
NAME/word-count-beam-bundled-0.1.jar")
GCS_PYTHON = os.environ.get("GCP_DATAFLOW_PYTHON", "gs://INVALID BUCKET
NAME/wordcount_debugging.py")
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
GCS_JAR_PARTS = urlsplit(GCS_JAR)
GCS_JAR_BUCKET_NAME = GCS_JAR_PARTS.netloc
@@ -257,6 +258,7 @@ with models.DAG(
# [START howto_operator_start_template_job]
start_template_job = DataflowTemplatedJobStartOperator(
task_id="start-template-job",
+ project_id=PROJECT_ID,
template="gs://dataflow-templates/latest/Word_Count",
parameters={"inputFile":
"gs://dataflow-samples/shakespeare/kinglear.txt", "output": GCS_OUTPUT},
location="europe-west3",
@@ -279,6 +281,7 @@ with models.DAG(
# [END howto_operator_stop_dataflow_job]
start_template_job = DataflowTemplatedJobStartOperator(
task_id="start-template-job",
+ project_id=PROJECT_ID,
template="gs://dataflow-templates/latest/Word_Count",
parameters={"inputFile":
"gs://dataflow-samples/shakespeare/kinglear.txt", "output": GCS_OUTPUT},
location="europe-west3",
diff --git
a/airflow/providers/google/cloud/example_dags/example_dataflow_flex_template.py
b/airflow/providers/google/cloud/example_dags/example_dataflow_flex_template.py
index 86de7014c5..8af748e8a7 100644
---
a/airflow/providers/google/cloud/example_dags/example_dataflow_flex_template.py
+++
b/airflow/providers/google/cloud/example_dags/example_dataflow_flex_template.py
@@ -52,6 +52,7 @@ with models.DAG(
# [START howto_operator_start_template_job]
start_flex_template = DataflowStartFlexTemplateOperator(
task_id="start_flex_template_streaming_beam_sql",
+ project_id=GCP_PROJECT_ID,
body={
"launchParameter": {
"containerSpecGcsPath": GCS_FLEX_TEMPLATE_TEMPLATE_PATH,
diff --git a/airflow/providers/google/cloud/hooks/dataflow.py
b/airflow/providers/google/cloud/hooks/dataflow.py
index b034dbd5f9..3f7fd4b54e 100644
--- a/airflow/providers/google/cloud/hooks/dataflow.py
+++ b/airflow/providers/google/cloud/hooks/dataflow.py
@@ -29,11 +29,16 @@ import warnings
from copy import deepcopy
from typing import Any, Callable, Generator, Sequence, TypeVar, cast
+from google.cloud.dataflow_v1beta3 import GetJobRequest, Job, JobState,
JobsV1Beta3AsyncClient, JobView
from googleapiclient.discovery import build
from airflow.exceptions import AirflowException
from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType,
beam_options_to_args
-from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+from airflow.providers.google.common.hooks.base_google import (
+ PROVIDE_PROJECT_ID,
+ GoogleBaseAsyncHook,
+ GoogleBaseHook,
+)
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.timeout import timeout
@@ -645,36 +650,10 @@ class DataflowHook(GoogleBaseHook):
"""
name = self.build_dataflow_job_name(job_name, append_job_name)
- environment = environment or {}
- # available keys for runtime environment are listed here:
- #
https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment
- environment_keys = [
- "numWorkers",
- "maxWorkers",
- "zone",
- "serviceAccountEmail",
- "tempLocation",
- "bypassTempDirValidation",
- "machineType",
- "additionalExperiments",
- "network",
- "subnetwork",
- "additionalUserLabels",
- "kmsKeyName",
- "ipConfiguration",
- "workerRegion",
- "workerZone",
- ]
-
- for key in variables:
- if key in environment_keys:
- if key in environment:
- self.log.warning(
- "'%s' parameter in 'variables' will override of "
- "the same one passed in 'environment'!",
- key,
- )
- environment.update({key: variables[key]})
+ environment = self._update_environment(
+ variables=variables,
+ environment=environment,
+ )
service = self.get_conn()
@@ -723,6 +702,40 @@ class DataflowHook(GoogleBaseHook):
jobs_controller.wait_for_done()
return response["job"]
+ def _update_environment(self, variables: dict, environment: dict | None =
None) -> dict:
+ environment = environment or {}
+ # available keys for runtime environment are listed here:
+ #
https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment
+ environment_keys = {
+ "numWorkers",
+ "maxWorkers",
+ "zone",
+ "serviceAccountEmail",
+ "tempLocation",
+ "bypassTempDirValidation",
+ "machineType",
+ "additionalExperiments",
+ "network",
+ "subnetwork",
+ "additionalUserLabels",
+ "kmsKeyName",
+ "ipConfiguration",
+ "workerRegion",
+ "workerZone",
+ }
+
+ def _check_one(key, val):
+ if key in environment:
+ self.log.warning(
+ "%r parameter in 'variables' will override the same one
passed in 'environment'!",
+ key,
+ )
+ return key, val
+
+ environment.update(_check_one(key, val) for key, val in
variables.items() if key in environment_keys)
+
+ return environment
+
@GoogleBaseHook.fallback_to_default_project_id
def start_flex_template(
self,
@@ -731,9 +744,9 @@ class DataflowHook(GoogleBaseHook):
project_id: str,
on_new_job_id_callback: Callable[[str], None] | None = None,
on_new_job_callback: Callable[[dict], None] | None = None,
- ):
+ ) -> dict:
"""
- Starts flex templates with the Dataflow pipeline.
+ Starts flex templates with the Dataflow pipeline.
:param body: The request body. See:
https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.locations.flexTemplates/launch#request-body
@@ -1041,7 +1054,7 @@ class DataflowHook(GoogleBaseHook):
def get_job(
self,
job_id: str,
- project_id: str,
+ project_id: str = PROVIDE_PROJECT_ID,
location: str = DEFAULT_DATAFLOW_LOCATION,
) -> dict:
"""
@@ -1169,3 +1182,88 @@ class DataflowHook(GoogleBaseHook):
wait_until_finished=self.wait_until_finished,
)
job_controller.wait_for_done()
+
+
+class AsyncDataflowHook(GoogleBaseAsyncHook):
+ """Async hook class for dataflow service."""
+
+ sync_hook_class = DataflowHook
+
+ async def initialize_client(self, client_class):
+ """
+ Initialize object of the given class.
+
+ Method is used to initialize an asynchronous client. Because the
Dataflow service uses many
+ client classes, it was decided to initialize them all in the same way,
with credentials obtained
+ from a method of the GoogleBaseHook class.
+ :param client_class: Class of the Google cloud SDK
+ """
+ credentials = (await self.get_sync_hook()).get_credentials()
+ return client_class(
+ credentials=credentials,
+ )
+
+ async def get_project_id(self) -> str:
+ project_id = (await self.get_sync_hook()).project_id
+ return project_id
+
+ async def get_job(
+ self,
+ job_id: str,
+ project_id: str = PROVIDE_PROJECT_ID,
+ job_view: int = JobView.JOB_VIEW_SUMMARY,
+ location: str = DEFAULT_DATAFLOW_LOCATION,
+ ) -> Job:
+ """
+ Gets the job with the specified Job ID.
+
+ :param job_id: Job ID to get.
+ :param project_id: the Google Cloud project ID in which to start a job.
+ If set to None or missing, the default project_id from the Google
Cloud connection is used.
+ :param job_view: Optional. JobView object which determines
representation of the returned data
+ :param location: Optional. The location of the Dataflow job (for
example europe-west1). See:
+ https://cloud.google.com/dataflow/docs/concepts/regional-endpoints
+ """
+ project_id = project_id or (await self.get_project_id())
+ client = await self.initialize_client(JobsV1Beta3AsyncClient)
+
+ request = GetJobRequest(
+ dict(
+ project_id=project_id,
+ job_id=job_id,
+ view=job_view,
+ location=location,
+ )
+ )
+
+ job = await client.get_job(
+ request=request,
+ )
+
+ return job
+
+ async def get_job_status(
+ self,
+ job_id: str,
+ project_id: str = PROVIDE_PROJECT_ID,
+ job_view: int = JobView.JOB_VIEW_SUMMARY,
+ location: str = DEFAULT_DATAFLOW_LOCATION,
+ ) -> JobState:
+ """
+ Gets the job status with the specified Job ID.
+
+ :param job_id: Job ID to get.
+ :param project_id: the Google Cloud project ID in which to start a job.
+ If set to None or missing, the default project_id from the Google
Cloud connection is used.
+ :param job_view: Optional. JobView object which determines
representation of the returned data
+ :param location: Optional. The location of the Dataflow job (for
example europe-west1). See:
+ https://cloud.google.com/dataflow/docs/concepts/regional-endpoints
+ """
+ job = await self.get_job(
+ project_id=project_id,
+ job_id=job_id,
+ job_view=job_view,
+ location=location,
+ )
+ state = job.current_state
+ return state
diff --git a/airflow/providers/google/cloud/links/dataflow.py
b/airflow/providers/google/cloud/links/dataflow.py
index 1f0f6f87e8..b62e29041d 100644
--- a/airflow/providers/google/cloud/links/dataflow.py
+++ b/airflow/providers/google/cloud/links/dataflow.py
@@ -48,5 +48,5 @@ class DataflowJobLink(BaseGoogleLink):
operator_instance.xcom_push(
context,
key=DataflowJobLink.key,
- value={"project_id": project_id, "location": region, "job_id":
job_id},
+ value={"project_id": project_id, "region": region, "job_id":
job_id},
)
diff --git a/airflow/providers/google/cloud/operators/dataflow.py
b/airflow/providers/google/cloud/operators/dataflow.py
index bdff866dd3..66a5da85f3 100644
--- a/airflow/providers/google/cloud/operators/dataflow.py
+++ b/airflow/providers/google/cloud/operators/dataflow.py
@@ -20,11 +20,14 @@ from __future__ import annotations
import copy
import re
+import uuid
import warnings
from contextlib import ExitStack
from enum import Enum
from typing import TYPE_CHECKING, Any, Sequence
+from airflow import AirflowException
+from airflow.compat.functools import cached_property
from airflow.models import BaseOperator
from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType
from airflow.providers.google.cloud.hooks.dataflow import (
@@ -34,6 +37,7 @@ from airflow.providers.google.cloud.hooks.dataflow import (
)
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.providers.google.cloud.links.dataflow import DataflowJobLink
+from airflow.providers.google.cloud.triggers.dataflow import
TemplateJobStartTrigger
from airflow.version import version
if TYPE_CHECKING:
@@ -55,8 +59,8 @@ class CheckJobRunning(Enum):
class DataflowConfiguration:
"""Dataflow configuration that can be passed to
-
:py:class:`~airflow.providers.apache.beam.operators.beam.BeamRunJavaPipelineOperator`
and
-
:py:class:`~airflow.providers.apache.beam.operators.beam.BeamRunPythonPipelineOperator`.
+
:class:`~airflow.providers.apache.beam.operators.beam.BeamRunJavaPipelineOperator`
and
+
:class:`~airflow.providers.apache.beam.operators.beam.BeamRunPythonPipelineOperator`.
:param job_name: The 'jobName' to use when executing the Dataflow job
(templated). This ends up being set in the pipeline options, so any
entry
@@ -66,9 +70,8 @@ class DataflowConfiguration:
If set to None or missing, the default project_id from the Google
Cloud connection is used.
:param location: Job location.
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
- :param delegate_to: The account to impersonate using domain-wide
delegation of authority,
- if any. For this to work, the service account making the request must
have
- domain-wide delegation enabled.
+ :param delegate_to: The account to impersonate using domain-wide
delegation of authority, if any.
+ For this to work, the service account making the request must have
domain-wide delegation enabled.
:param poll_sleep: The time in seconds to sleep between polling Google
Cloud Platform for the dataflow job status while the job is in the
JOB_STATE_RUNNING state.
@@ -82,7 +85,6 @@ class DataflowConfiguration:
account from the list granting this role to the originating account
(templated).
.. warning::
-
This option requires Apache Beam 2.39.0 or newer.
:param drain_pipeline: Optional, set to True if want to stop streaming job
by draining it
@@ -101,7 +103,6 @@ class DataflowConfiguration:
* for the batch pipeline, wait for the jobs to complete.
.. warning::
-
You cannot call ``PipelineResult.wait_until_finish`` method in
your pipeline code for the operator
to work properly. i. e. you must use asynchronous execution.
Otherwise, your pipeline will
always wait until finished. For more information, look at:
@@ -109,10 +110,8 @@ class DataflowConfiguration:
<https://cloud.google.com/dataflow/docs/guides/specifying-exec-params#python_10>`__
The process of starting the Dataflow job in Airflow consists of two
steps:
-
* running a subprocess and reading the stderr/stderr log for the job
id.
- * loop waiting for the end of the job ID from the previous step.
- This loop checks the status of the job.
+ * loop waiting for the end of the job ID from the previous step by
checking its status.
Step two is started just after step one has finished, so if you have
wait_until_finished in your
pipeline code, step two will not start until the process stops. When
this process stops,
@@ -124,13 +123,10 @@ class DataflowConfiguration:
If you in your pipeline do not call the wait_for_pipeline method, and
pass wait_until_finish=False
to the operator, the second loop will check once is job not in
terminal state and exit the loop.
:param multiple_jobs: If pipeline creates multiple jobs then monitor all
jobs. Supported only by
-
:py:class:`~airflow.providers.apache.beam.operators.beam.BeamRunJavaPipelineOperator`
+
:class:`~airflow.providers.apache.beam.operators.beam.BeamRunJavaPipelineOperator`.
:param check_if_running: Before running job, validate that a previous run
is not in process.
- IgnoreJob = do not check if running.
- FinishIfRunning = if job is running finish with nothing.
- WaitForRun = wait until job finished and the run job.
Supported only by:
-
:py:class:`~airflow.providers.apache.beam.operators.beam.BeamRunJavaPipelineOperator`
+
:class:`~airflow.providers.apache.beam.operators.beam.BeamRunJavaPipelineOperator`.
:param service_account: Run the job as a specific service account, instead
of the default GCE robot.
"""
@@ -230,6 +226,7 @@ class DataflowCreateJavaJobOperator(BaseOperator):
with key ``'jobName'`` in ``options`` will be overwritten.
:param dataflow_default_options: Map of default job options.
:param options: Map of job specific options.The key must be a dictionary.
+
The value can contain different types:
* If the value is None, the single option - ``--key`` (without value)
will be added.
@@ -241,6 +238,7 @@ class DataflowCreateJavaJobOperator(BaseOperator):
* Other value types will be replaced with the Python textual
representation.
When defining labels (``labels`` option), you can also provide a
dictionary.
+
:param project_id: Optional, the Google Cloud project ID in which to start
a job.
If set to None or missing, the default project_id from the Google
Cloud connection is used.
:param location: Job location.
@@ -583,7 +581,7 @@ class DataflowTemplatedJobStartOperator(BaseOperator):
)
``template``, ``dataflow_default_options``, ``parameters``, and
``job_name`` are
- templated so you can use variables in them.
+ templated, so you can use variables in them.
Note that ``dataflow_default_options`` is expected to save high-level
options
for project information, which apply to all dataflow operators in the DAG.
@@ -594,6 +592,8 @@ class DataflowTemplatedJobStartOperator(BaseOperator):
https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment
For more detail on job template execution have a look at the
reference:
https://cloud.google.com/dataflow/docs/templates/executing-templates
+
+ :param deferrable: Run operator in the deferrable mode.
"""
template_fields: Sequence[str] = (
@@ -615,11 +615,11 @@ class DataflowTemplatedJobStartOperator(BaseOperator):
self,
*,
template: str,
+ project_id: str | None = None,
job_name: str = "{{task.task_id}}",
options: dict[str, Any] | None = None,
dataflow_default_options: dict[str, Any] | None = None,
parameters: dict[str, str] | None = None,
- project_id: str | None = None,
location: str = DEFAULT_DATAFLOW_LOCATION,
gcp_conn_id: str = "google_cloud_default",
delegate_to: str | None = None,
@@ -629,9 +629,11 @@ class DataflowTemplatedJobStartOperator(BaseOperator):
cancel_timeout: int | None = 10 * 60,
wait_until_finished: bool | None = None,
append_job_name: bool = True,
+ deferrable: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
+
self.template = template
self.job_name = job_name
self.options = options or {}
@@ -646,16 +648,31 @@ class DataflowTemplatedJobStartOperator(BaseOperator):
)
self.delegate_to = delegate_to
self.poll_sleep = poll_sleep
- self.job = None
- self.hook: DataflowHook | None = None
self.impersonation_chain = impersonation_chain
self.environment = environment
self.cancel_timeout = cancel_timeout
self.wait_until_finished = wait_until_finished
self.append_job_name = append_job_name
+ self.deferrable = deferrable
- def execute(self, context: Context) -> dict:
- self.hook = DataflowHook(
+ self.job: dict | None = None
+
+ self._validate_deferrable_params()
+
+ def _validate_deferrable_params(self):
+ if self.deferrable and self.wait_until_finished:
+ raise ValueError(
+ "Conflict between deferrable and wait_until_finished
parameters "
+ "because it makes operator as blocking when it requires to be
deferred. "
+ "It should be True as deferrable parameter or True as
wait_until_finished."
+ )
+
+ if self.deferrable and self.wait_until_finished is None:
+ self.wait_until_finished = False
+
+ @cached_property
+ def hook(self) -> DataflowHook:
+ hook = DataflowHook(
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
poll_sleep=self.poll_sleep,
@@ -663,14 +680,17 @@ class DataflowTemplatedJobStartOperator(BaseOperator):
cancel_timeout=self.cancel_timeout,
wait_until_finished=self.wait_until_finished,
)
+ return hook
+ def execute(self, context: Context):
def set_current_job(current_job):
self.job = current_job
DataflowJobLink.persist(self, context, self.project_id,
self.location, self.job.get("id"))
options = self.dataflow_default_options
options.update(self.options)
- job = self.hook.start_template_dataflow(
+
+ self.job = self.hook.start_template_dataflow(
job_name=self.job_name,
variables=options,
parameters=self.parameters,
@@ -681,13 +701,48 @@ class DataflowTemplatedJobStartOperator(BaseOperator):
environment=self.environment,
append_job_name=self.append_job_name,
)
+ job_id = self.job.get("id")
- return job
+ if job_id is None:
+ raise AirflowException(
+ "While reading job object after template execution error
occurred. Job object has no id."
+ )
+
+ if not self.deferrable:
+ return job_id
+
+ context["ti"].xcom_push(key="job_id", value=job_id)
+
+ self.defer(
+ trigger=TemplateJobStartTrigger(
+ project_id=self.project_id,
+ job_id=job_id,
+ location=self.location,
+ gcp_conn_id=self.gcp_conn_id,
+ delegate_to=self.delegate_to,
+ poll_sleep=self.poll_sleep,
+ impersonation_chain=self.impersonation_chain,
+ cancel_timeout=self.cancel_timeout,
+ ),
+ method_name="execute_complete",
+ )
+
+ def execute_complete(self, context: Context, event: dict[str, Any]):
+ """Method which executes after trigger finishes its work."""
+ if event["status"] == "error" or event["status"] == "stopped":
+ self.log.info("status: %s, msg: %s", event["status"],
event["message"])
+ raise AirflowException(event["message"])
+
+ job_id = event["job_id"]
+ self.log.info("Task %s completed with response %s", self.task_id,
event["message"])
+ return job_id
def on_kill(self) -> None:
self.log.info("On kill.")
- if self.job:
+ if self.job is not None:
+ self.log.info("Cancelling job %s", self.job_name)
self.hook.cancel_job(
+ job_name=self.job_name,
job_id=self.job.get("id"),
project_id=self.job.get("projectId"),
location=self.job.get("location"),
@@ -706,7 +761,6 @@ class DataflowStartFlexTemplateOperator(BaseOperator):
https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.locations.flexTemplates/launch#request-body
:param location: The location of the Dataflow job (for example
europe-west1)
:param project_id: The ID of the GCP project that owns the job.
- If set to ``None`` or missing, the default project_id from the GCP
connection is used.
:param gcp_conn_id: The connection ID to use connecting to Google Cloud
Platform.
:param delegate_to: The account to impersonate, if any.
@@ -758,6 +812,7 @@ class DataflowStartFlexTemplateOperator(BaseOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding
identity, with first
account from the list granting this role to the originating account
(templated).
+ :param deferrable: Run operator in the deferrable mode.
"""
template_fields: Sequence[str] = ("body", "location", "project_id",
"gcp_conn_id")
@@ -774,6 +829,7 @@ class DataflowStartFlexTemplateOperator(BaseOperator):
cancel_timeout: int | None = 10 * 60,
wait_until_finished: bool | None = None,
impersonation_chain: str | Sequence[str] | None = None,
+ deferrable: bool = False,
*args,
**kwargs,
) -> None:
@@ -790,12 +846,26 @@ class DataflowStartFlexTemplateOperator(BaseOperator):
self.drain_pipeline = drain_pipeline
self.cancel_timeout = cancel_timeout
self.wait_until_finished = wait_until_finished
- self.job = None
- self.hook: DataflowHook | None = None
+ self.job: dict | None = None
self.impersonation_chain = impersonation_chain
+ self.deferrable = deferrable
- def execute(self, context: Context):
- self.hook = DataflowHook(
+ self._validate_deferrable_params()
+
+ def _validate_deferrable_params(self):
+ if self.deferrable and self.wait_until_finished:
+ raise ValueError(
+ "Conflict between deferrable and wait_until_finished
parameters "
+ "because it makes operator as blocking when it requires to be
deferred. "
+ "It should be True as deferrable parameter or True as
wait_until_finished."
+ )
+
+ if self.deferrable and self.wait_until_finished is None:
+ self.wait_until_finished = False
+
+ @cached_property
+ def hook(self) -> DataflowHook:
+ hook = DataflowHook(
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
drain_pipeline=self.drain_pipeline,
@@ -803,23 +873,66 @@ class DataflowStartFlexTemplateOperator(BaseOperator):
wait_until_finished=self.wait_until_finished,
impersonation_chain=self.impersonation_chain,
)
+ return hook
+
+ def execute(self, context: Context):
+ self._append_uuid_to_job_name()
def set_current_job(current_job):
self.job = current_job
DataflowJobLink.persist(self, context, self.project_id,
self.location, self.job.get("id"))
- job = self.hook.start_flex_template(
+ self.job = self.hook.start_flex_template(
body=self.body,
location=self.location,
project_id=self.project_id,
on_new_job_callback=set_current_job,
)
+ job_id = self.job.get("id")
+ if job_id is None:
+ raise AirflowException(
+ "While reading job object after template execution error
occurred. Job object has no id."
+ )
+
+ if not self.deferrable:
+ return self.job
+
+ self.defer(
+ trigger=TemplateJobStartTrigger(
+ project_id=self.project_id,
+ job_id=job_id,
+ location=self.location,
+ gcp_conn_id=self.gcp_conn_id,
+ delegate_to=self.delegate_to,
+ impersonation_chain=self.impersonation_chain,
+ cancel_timeout=self.cancel_timeout,
+ ),
+ method_name="execute_complete",
+ )
+
+ def _append_uuid_to_job_name(self):
+ job_body = self.body.get("launch_parameter") or
self.body.get("launchParameter")
+ job_name = job_body.get("jobName")
+ if job_name:
+ job_name += f"-{str(uuid.uuid4())[:8]}"
+ job_body["jobName"] = job_name
+ self.log.info("Job name was changed to %s", job_name)
+
+ def execute_complete(self, context: Context, event: dict):
+ """Method which executes after trigger finishes its work."""
+ if event["status"] == "error" or event["status"] == "stopped":
+ self.log.info("status: %s, msg: %s", event["status"],
event["message"])
+ raise AirflowException(event["message"])
+
+ job_id = event["job_id"]
+ self.log.info("Task %s completed with response %s", job_id,
event["message"])
+ job = self.hook.get_job(job_id=job_id, project_id=self.project_id,
location=self.location)
return job
def on_kill(self) -> None:
self.log.info("On kill.")
- if self.job:
+ if self.job is not None:
self.hook.cancel_job(
job_id=self.job.get("id"),
project_id=self.job.get("projectId"),
diff --git a/airflow/providers/google/cloud/triggers/dataflow.py
b/airflow/providers/google/cloud/triggers/dataflow.py
new file mode 100644
index 0000000000..5167323789
--- /dev/null
+++ b/airflow/providers/google/cloud/triggers/dataflow.py
@@ -0,0 +1,150 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import asyncio
+from typing import Any, Sequence
+
+from google.cloud.dataflow_v1beta3 import JobState
+
+from airflow.providers.google.cloud.hooks.dataflow import AsyncDataflowHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+DEFAULT_DATAFLOW_LOCATION = "us-central1"
+
+
+class TemplateJobStartTrigger(BaseTrigger):
+ """Dataflow trigger to check if templated job has been finished.
+
+ :param project_id: Required. the Google Cloud project ID in which the job
was started.
+ :param job_id: Required. ID of the job.
+ :param location: Optional. the location where job is executed. If set to
None then
+ the value of DEFAULT_DATAFLOW_LOCATION will be used
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+ :param delegate_to: The account to impersonate using domain-wide
delegation of authority,
+ if any. For this to work, the service account making the request must
have
+ domain-wide delegation enabled.
+ :param impersonation_chain: Optional. Service account to impersonate using
short-term
+ credentials, or chained list of accounts required to get the
access_token
+ of the last account in the list, which will be impersonated in the
request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding
identity, with first
+ account from the list granting this role to the originating account
(templated).
+ :param cancel_timeout: Optional. How long (in seconds) operator should
wait for the pipeline to be
+ successfully cancelled when task is being killed.
+ """
+
+ def __init__(
+ self,
+ job_id: str,
+ project_id: str | None,
+ location: str = DEFAULT_DATAFLOW_LOCATION,
+ gcp_conn_id: str = "google_cloud_default",
+ delegate_to: str | None = None,
+ poll_sleep: int = 10,
+ impersonation_chain: str | Sequence[str] | None = None,
+ cancel_timeout: int | None = 5 * 60,
+ ):
+ super().__init__()
+
+ self.project_id = project_id
+ self.job_id = job_id
+ self.location = location
+ self.gcp_conn_id = gcp_conn_id
+ self.delegate_to = delegate_to
+ self.poll_sleep = poll_sleep
+ self.impersonation_chain = impersonation_chain
+ self.cancel_timeout = cancel_timeout
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ """Serializes class arguments and classpath."""
+ return (
+
"airflow.providers.google.cloud.triggers.dataflow.TemplateJobStartTrigger",
+ {
+ "project_id": self.project_id,
+ "job_id": self.job_id,
+ "location": self.location,
+ "gcp_conn_id": self.gcp_conn_id,
+ "delegate_to": self.delegate_to,
+ "poll_sleep": self.poll_sleep,
+ "impersonation_chain": self.impersonation_chain,
+ "cancel_timeout": self.cancel_timeout,
+ },
+ )
+
+ async def run(self):
+ """
+ Main loop of the class, in which it fetches the job status and yields
a certain Event.
+
+ If the job has status success then it yields TriggerEvent with success
status, if job has
+ status failed - with error status. In any other case Trigger will wait
for specified
+ amount of time stored in self.poll_sleep variable.
+ """
+ hook = self._get_async_hook()
+ while True:
+ try:
+ status = await hook.get_job_status(
+ project_id=self.project_id,
+ job_id=self.job_id,
+ location=self.location,
+ )
+ if status == JobState.JOB_STATE_DONE:
+ yield TriggerEvent(
+ {
+ "job_id": self.job_id,
+ "status": "success",
+ "message": "Job completed",
+ }
+ )
+ return
+ elif status == JobState.JOB_STATE_FAILED:
+ yield TriggerEvent(
+ {
+ "status": "error",
+ "message": f"Dataflow job with id {self.job_id}
has failed its execution",
+ }
+ )
+ return
+ elif status == JobState.JOB_STATE_STOPPED:
+ yield TriggerEvent(
+ {
+ "status": "stopped",
+ "message": f"Dataflow job with id {self.job_id}
was stopped",
+ }
+ )
+ return
+ else:
+ self.log.info("Job is still running...")
+ self.log.info("Current job status is: %s", status)
+ self.log.info("Sleeping for %s seconds.", self.poll_sleep)
+ await asyncio.sleep(self.poll_sleep)
+ except Exception as e:
+ self.log.exception("Exception occurred while checking for job
completion.")
+ yield TriggerEvent({"status": "error", "message": str(e)})
+ return
+
+ def _get_async_hook(self) -> AsyncDataflowHook:
+ return AsyncDataflowHook(
+ gcp_conn_id=self.gcp_conn_id,
+ delegate_to=self.delegate_to,
+ poll_sleep=self.poll_sleep,
+ impersonation_chain=self.impersonation_chain,
+ cancel_timeout=self.cancel_timeout,
+ )
diff --git a/airflow/providers/google/provider.yaml
b/airflow/providers/google/provider.yaml
index 3269c4f597..faf5bcaadc 100644
--- a/airflow/providers/google/provider.yaml
+++ b/airflow/providers/google/provider.yaml
@@ -85,6 +85,9 @@ dependencies:
- google-cloud-build>=3.0.0
- google-cloud-compute>=0.1.0,<2.0.0
- google-cloud-container>=2.2.0,<3.0.0
+  # Version 0.5.5 of google-cloud-dataflow-client requires higher versions of
+  # the protobuf and proto-plus libraries, which can break other dependencies
in the current package.
+ - google-cloud-dataflow-client>=0.5.2,<0.5.5
- google-cloud-dataform>=0.2.0
- google-cloud-datacatalog>=3.0.0
- google-cloud-dataplex>=0.1.0
diff --git a/generated/provider_dependencies.json
b/generated/provider_dependencies.json
index d23e804886..dd67a7c61e 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -339,6 +339,7 @@
"google-cloud-compute>=0.1.0,<2.0.0",
"google-cloud-container>=2.2.0,<3.0.0",
"google-cloud-datacatalog>=3.0.0",
+ "google-cloud-dataflow-client>=0.5.2,<0.5.5",
"google-cloud-dataform>=0.2.0",
"google-cloud-dataplex>=0.1.0",
"google-cloud-dataproc-metastore>=1.2.0,<2.0.0",
diff --git a/tests/providers/google/cloud/hooks/test_dataflow.py
b/tests/providers/google/cloud/hooks/test_dataflow.py
index df9991fe57..1e48eb6a56 100644
--- a/tests/providers/google/cloud/hooks/test_dataflow.py
+++ b/tests/providers/google/cloud/hooks/test_dataflow.py
@@ -20,17 +20,20 @@ from __future__ import annotations
import copy
import re
import shlex
+import sys
+from asyncio import Future
from typing import Any
-from unittest import mock
from unittest.mock import MagicMock
from uuid import UUID
import pytest
+from google.cloud.dataflow_v1beta3 import GetJobRequest, JobView
from airflow.exceptions import AirflowException
from airflow.providers.apache.beam.hooks.beam import BeamCommandRunner,
BeamHook
from airflow.providers.google.cloud.hooks.dataflow import (
DEFAULT_DATAFLOW_LOCATION,
+ AsyncDataflowHook,
DataflowHook,
DataflowJobStatus,
DataflowJobType,
@@ -39,6 +42,12 @@ from airflow.providers.google.cloud.hooks.dataflow import (
process_line_and_extract_dataflow_job_id_callback,
)
+if sys.version_info < (3, 8):
+ from asynctest import mock
+else:
+ from unittest import mock
+
+
DEFAULT_RUNNER = "DirectRunner"
BEAM_STRING = "airflow.providers.apache.beam.hooks.beam.{}"
@@ -52,6 +61,7 @@ PARAMETERS = {
"inputFile": "gs://dataflow-samples/shakespeare/kinglear.txt",
"output": "gs://test/output/my_output",
}
+TEST_ENVIRONMENT = {}
PY_FILE = "apache_beam.examples.wordcount"
JAR_FILE = "unitest.jar"
JOB_CLASS = "com.example.UnitTest"
@@ -1882,3 +1892,46 @@ class TestDataflow:
mock_logging.info.assert_called_once_with("Running command: %s", "test
cmd")
with pytest.raises(Exception):
dataflow.wait_for_done()
+
+
[email protected]()
+def make_mock_awaitable():
+ def func(mock_obj, return_value):
+ f = Future()
+ f.set_result(return_value)
+ mock_obj.return_value = f
+
+ return func
+
+
+class TestAsyncHook:
+ @pytest.fixture
+ def hook(self):
+ return AsyncDataflowHook(
+ gcp_conn_id=TEST_PROJECT_ID,
+ )
+
+ @pytest.mark.asyncio
+
@mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.initialize_client")
+ async def test_get_job(self, initialize_client_mock, hook,
make_mock_awaitable):
+ client = initialize_client_mock.return_value
+ make_mock_awaitable(client.get_job, None)
+
+ await hook.get_job(
+ project_id=TEST_PROJECT_ID,
+ job_id=TEST_JOB_ID,
+ location=TEST_LOCATION,
+ )
+ request = GetJobRequest(
+ dict(
+ project_id=TEST_PROJECT_ID,
+ job_id=TEST_JOB_ID,
+ location=TEST_LOCATION,
+ view=JobView.JOB_VIEW_SUMMARY,
+ )
+ )
+
+ initialize_client_mock.assert_called_once()
+ client.get_job.assert_called_once_with(
+ request=request,
+ )
diff --git a/tests/providers/google/cloud/operators/test_dataflow.py
b/tests/providers/google/cloud/operators/test_dataflow.py
index 40e37c9487..c9f5364f66 100644
--- a/tests/providers/google/cloud/operators/test_dataflow.py
+++ b/tests/providers/google/cloud/operators/test_dataflow.py
@@ -22,6 +22,8 @@ import unittest
from copy import deepcopy
from unittest import mock
+import pytest as pytest
+
import airflow
from airflow.providers.google.cloud.operators.dataflow import (
CheckJobRunning,
@@ -93,6 +95,10 @@ FROM
GROUP BY sales_region;
"""
TEST_SQL_JOB = {"id": "test-job-id"}
+GCP_CONN_ID = "test_gcp_conn_id"
+DELEGATE_TO = "delegating_to_something"
+IMPERSONATION_CHAIN = ["impersonate", "this"]
+CANCEL_TIMEOUT = 10 * 420
class TestDataflowPythonOperator(unittest.TestCase):
@@ -445,9 +451,11 @@ class TestDataflowJavaOperatorWithLocal(unittest.TestCase):
)
-class TestDataflowTemplateOperator(unittest.TestCase):
- def setUp(self):
- self.dataflow = DataflowTemplatedJobStartOperator(
+class TestDataflowTemplateOperator:
+ @pytest.fixture
+ def sync_operator(self):
+ return DataflowTemplatedJobStartOperator(
+ project_id=TEST_PROJECT,
task_id=TASK_ID,
template=TEMPLATE,
job_name=JOB_NAME,
@@ -459,14 +467,30 @@ class TestDataflowTemplateOperator(unittest.TestCase):
environment={"maxWorkers": 2},
)
-
@mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook")
- def test_exec(self, dataflow_mock):
- """Test DataflowHook is created and the right args are passed to
- start_template_workflow.
+ @pytest.fixture
+ def deferrable_operator(self):
+ return DataflowTemplatedJobStartOperator(
+ project_id=TEST_PROJECT,
+ task_id=TASK_ID,
+ template=TEMPLATE,
+ job_name=JOB_NAME,
+ parameters=PARAMETERS,
+ options=DEFAULT_OPTIONS_TEMPLATE,
+ dataflow_default_options={"EXTRA_OPTION": "TEST_A"},
+ poll_sleep=POLL_SLEEP,
+ location=TEST_LOCATION,
+ environment={"maxWorkers": 2},
+ deferrable=True,
+ gcp_conn_id=GCP_CONN_ID,
+ delegate_to=DELEGATE_TO,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ cancel_timeout=CANCEL_TIMEOUT,
+ )
- """
+
@mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook")
+ def test_exec(self, dataflow_mock, sync_operator):
start_template_hook =
dataflow_mock.return_value.start_template_dataflow
- self.dataflow.execute(None)
+ sync_operator.execute(None)
assert dataflow_mock.called
expected_options = {
"project": "test",
@@ -481,24 +505,66 @@ class TestDataflowTemplateOperator(unittest.TestCase):
parameters=PARAMETERS,
dataflow_template=TEMPLATE,
on_new_job_callback=mock.ANY,
- project_id=None,
+ project_id=TEST_PROJECT,
location=TEST_LOCATION,
environment={"maxWorkers": 2},
append_job_name=True,
)
+
@mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowTemplatedJobStartOperator.defer")
+
@mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowTemplatedJobStartOperator.hook")
+ def test_execute_with_deferrable_mode(self, mock_hook, mock_defer_method,
deferrable_operator):
+ deferrable_operator.execute(mock.MagicMock())
+ mock_defer_method.assert_called_once()
+
+ def test_validation_deferrable_params_raises_error(self):
+ init_kwargs = {
+ "project_id": TEST_PROJECT,
+ "task_id": TASK_ID,
+ "template": TEMPLATE,
+ "job_name": JOB_NAME,
+ "parameters": PARAMETERS,
+ "options": DEFAULT_OPTIONS_TEMPLATE,
+ "dataflow_default_options": {"EXTRA_OPTION": "TEST_A"},
+ "poll_sleep": POLL_SLEEP,
+ "location": TEST_LOCATION,
+ "environment": {"maxWorkers": 2},
+ "wait_until_finished": True,
+ "deferrable": True,
+ "gcp_conn_id": GCP_CONN_ID,
+ "delegate_to": DELEGATE_TO,
+ "impersonation_chain": IMPERSONATION_CHAIN,
+ "cancel_timeout": CANCEL_TIMEOUT,
+ }
+ with pytest.raises(ValueError):
+ DataflowTemplatedJobStartOperator(**init_kwargs)
-class TestDataflowStartFlexTemplateOperator(unittest.TestCase):
-
@mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook")
- def test_execute(self, mock_dataflow):
- start_flex_template = DataflowStartFlexTemplateOperator(
+
+class TestDataflowStartFlexTemplateOperator:
+ @pytest.fixture
+ def sync_operator(self):
+ return DataflowStartFlexTemplateOperator(
task_id="start_flex_template_streaming_beam_sql",
body={"launchParameter": TEST_FLEX_PARAMETERS},
do_xcom_push=True,
project_id=TEST_PROJECT,
location=TEST_LOCATION,
)
- start_flex_template.execute(mock.MagicMock())
+
+ @pytest.fixture
+ def deferrable_operator(self):
+ return DataflowStartFlexTemplateOperator(
+ task_id="start_flex_template_streaming_beam_sql",
+ body={"launchParameter": TEST_FLEX_PARAMETERS},
+ do_xcom_push=True,
+ project_id=TEST_PROJECT,
+ location=TEST_LOCATION,
+ deferrable=True,
+ )
+
+
@mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook")
+ def test_execute(self, mock_dataflow, sync_operator):
+ sync_operator.execute(mock.MagicMock())
mock_dataflow.assert_called_once_with(
gcp_conn_id="google_cloud_default",
delegate_to=None,
@@ -514,20 +580,39 @@ class
TestDataflowStartFlexTemplateOperator(unittest.TestCase):
on_new_job_callback=mock.ANY,
)
- def test_on_kill(self):
- start_flex_template = DataflowStartFlexTemplateOperator(
- task_id="start_flex_template_streaming_beam_sql",
+ def test_on_kill(self, sync_operator):
+ sync_operator.hook = mock.MagicMock()
+ sync_operator.job = {"id": JOB_ID, "projectId": TEST_PROJECT,
"location": TEST_LOCATION}
+ sync_operator.on_kill()
+ sync_operator.hook.cancel_job.assert_called_once_with(
+ job_id="test-dataflow-pipeline-id", project_id=TEST_PROJECT,
location=TEST_LOCATION
+ )
+
+ def test_validation_deferrable_params_raises_error(self):
+ init_kwargs = {
+ "task_id": "start_flex_template_streaming_beam_sql",
+ "body": {"launchParameter": TEST_FLEX_PARAMETERS},
+ "do_xcom_push": True,
+ "location": TEST_LOCATION,
+ "project_id": TEST_PROJECT,
+ "wait_until_finished": True,
+ "deferrable": True,
+ }
+ with pytest.raises(ValueError):
+ DataflowStartFlexTemplateOperator(**init_kwargs)
+
+
@mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowStartFlexTemplateOperator.defer")
+
@mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook")
+ def test_execute_with_deferrable_mode(self, mock_hook, mock_defer_method,
deferrable_operator):
+ deferrable_operator.execute(mock.MagicMock())
+
+ mock_hook.return_value.start_flex_template.assert_called_once_with(
body={"launchParameter": TEST_FLEX_PARAMETERS},
- do_xcom_push=True,
location=TEST_LOCATION,
project_id=TEST_PROJECT,
+ on_new_job_callback=mock.ANY,
)
- start_flex_template.hook = mock.MagicMock()
- start_flex_template.job = {"id": JOB_ID, "projectId": TEST_PROJECT,
"location": TEST_LOCATION}
- start_flex_template.on_kill()
- start_flex_template.hook.cancel_job.assert_called_once_with(
- job_id="test-dataflow-pipeline-id", project_id=TEST_PROJECT,
location=TEST_LOCATION
- )
+ mock_defer_method.assert_called_once()
class TestDataflowSqlOperator(unittest.TestCase):
diff --git a/tests/providers/google/cloud/triggers/test_dataflow.py
b/tests/providers/google/cloud/triggers/test_dataflow.py
new file mode 100644
index 0000000000..2f055ecfe0
--- /dev/null
+++ b/tests/providers/google/cloud/triggers/test_dataflow.py
@@ -0,0 +1,168 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import asyncio
+import logging
+import sys
+from asyncio import Future
+
+import pytest
+from google.cloud.dataflow_v1beta3 import JobState
+
+from airflow.providers.google.cloud.triggers.dataflow import
TemplateJobStartTrigger
+from airflow.triggers.base import TriggerEvent
+
+if sys.version_info < (3, 8):
+ from asynctest import mock
+else:
+ from unittest import mock
+
+PROJECT_ID = "test-project-id"
+JOB_ID = "test_job_id_2012-12-23-10:00"
+LOCATION = "us-central1"
+GCP_CONN_ID = "test_gcp_conn_id"
+DELEGATE_TO = "delegating_to_something"
+POLL_SLEEP = 20
+IMPERSONATION_CHAIN = ["impersonate", "this"]
+CANCEL_TIMEOUT = 10 * 420
+
+
[email protected]
+def trigger():
+ return TemplateJobStartTrigger(
+ project_id=PROJECT_ID,
+ job_id=JOB_ID,
+ location=LOCATION,
+ gcp_conn_id=GCP_CONN_ID,
+ delegate_to=DELEGATE_TO,
+ poll_sleep=POLL_SLEEP,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ cancel_timeout=CANCEL_TIMEOUT,
+ )
+
+
[email protected]()
+def make_mock_awaitable():
+ def func(mock_obj, return_value):
+ if sys.version_info < (3, 8):
+ f = Future()
+ f.set_result(return_value)
+ mock_obj.return_value = f
+ else:
+ mock_obj.return_value = return_value
+ return mock_obj
+
+ return func
+
+
+def test_serialize(trigger):
+ actual_data = trigger.serialize()
+ expected_data = (
+
"airflow.providers.google.cloud.triggers.dataflow.TemplateJobStartTrigger",
+ {
+ "project_id": PROJECT_ID,
+ "job_id": JOB_ID,
+ "location": LOCATION,
+ "gcp_conn_id": GCP_CONN_ID,
+ "delegate_to": DELEGATE_TO,
+ "poll_sleep": POLL_SLEEP,
+ "impersonation_chain": IMPERSONATION_CHAIN,
+ "cancel_timeout": CANCEL_TIMEOUT,
+ },
+ )
+ assert actual_data == expected_data
+
+
[email protected](
+ "attr, expected",
+ [
+ ("gcp_conn_id", GCP_CONN_ID),
+ ("delegate_to", DELEGATE_TO),
+ ("poll_sleep", POLL_SLEEP),
+ ("impersonation_chain", IMPERSONATION_CHAIN),
+ ("cancel_timeout", CANCEL_TIMEOUT),
+ ],
+)
+def test_get_async_hook(trigger, attr, expected):
+ hook = trigger._get_async_hook()
+ actual = hook._hook_kwargs.get(attr)
+ assert actual is not None
+ assert actual == expected
+
+
[email protected]
[email protected]("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job_status")
+async def test_run_loop_return_success_event(mock_job_status, trigger,
make_mock_awaitable):
+ make_mock_awaitable(mock_job_status, JobState.JOB_STATE_DONE)
+
+ expected_event = TriggerEvent(
+ {
+ "job_id": JOB_ID,
+ "status": "success",
+ "message": "Job completed",
+ }
+ )
+ actual_event = await (trigger.run()).asend(None)
+
+ assert actual_event == expected_event
+
+
[email protected]
[email protected]("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job_status")
+async def test_run_loop_return_failed_event(mock_job_status, trigger,
make_mock_awaitable):
+ make_mock_awaitable(mock_job_status, JobState.JOB_STATE_FAILED)
+
+ expected_event = TriggerEvent(
+ {
+ "status": "error",
+ "message": f"Dataflow job with id {JOB_ID} has failed its
execution",
+ }
+ )
+ actual_event = await (trigger.run()).asend(None)
+
+ assert actual_event == expected_event
+
+
[email protected]
[email protected]("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job_status")
+async def test_run_loop_return_stopped_event(mock_job_status, trigger,
make_mock_awaitable):
+ make_mock_awaitable(mock_job_status, JobState.JOB_STATE_STOPPED)
+ expected_event = TriggerEvent(
+ {
+ "status": "stopped",
+ "message": f"Dataflow job with id {JOB_ID} was stopped",
+ }
+ )
+ actual_event = await (trigger.run()).asend(None)
+
+ assert actual_event == expected_event
+
+
[email protected]
[email protected]("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job_status")
+async def test_run_loop_is_still_running(mock_job_status, trigger, caplog,
make_mock_awaitable):
+ make_mock_awaitable(mock_job_status, JobState.JOB_STATE_RUNNING)
+ caplog.set_level(logging.INFO)
+
+ task = asyncio.create_task(trigger.run().__anext__())
+ await asyncio.sleep(0.5)
+
+ assert not task.done()
+ assert f"Current job status is: {JobState.JOB_STATE_RUNNING}"
+ assert f"Sleeping for {POLL_SLEEP} seconds."
diff --git
a/tests/system/providers/google/cloud/dataflow/example_dataflow_template.py
b/tests/system/providers/google/cloud/dataflow/example_dataflow_template.py
index 9e9bb0b654..1247e1df41 100644
--- a/tests/system/providers/google/cloud/dataflow/example_dataflow_template.py
+++ b/tests/system/providers/google/cloud/dataflow/example_dataflow_template.py
@@ -33,6 +33,7 @@ from airflow.providers.google.cloud.transfers.local_to_gcs
import LocalFilesyste
from airflow.utils.trigger_rule import TriggerRule
ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT")
DAG_ID = "dataflow_template"
BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}"
@@ -71,6 +72,7 @@ with models.DAG(
# [START howto_operator_start_template_job]
start_template_job = DataflowTemplatedJobStartOperator(
task_id="start_template_job",
+ project_id=PROJECT_ID,
template="gs://dataflow-templates/latest/Word_Count",
parameters={"inputFile": f"gs://{BUCKET_NAME}/{FILE_NAME}", "output":
GCS_OUTPUT},
location=LOCATION,