This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 7d3a402f48 Add DataflowStartYamlJobOperator (#41576)
7d3a402f48 is described below
commit 7d3a402f4882387c9ca31a435ef41441254304bc
Author: max <[email protected]>
AuthorDate: Mon Sep 2 14:27:01 2024 +0000
Add DataflowStartYamlJobOperator (#41576)
* Add DataflowStartYamlJobOperator
* Refactor hook and operator
---------
Co-authored-by: Eugene Galan <[email protected]>
---
airflow/providers/google/cloud/hooks/dataflow.py | 131 +++++++++++----
.../providers/google/cloud/operators/dataflow.py | 182 ++++++++++++++++++++-
.../providers/google/cloud/triggers/dataflow.py | 146 ++++++++++++++++-
.../operators/cloud/dataflow.rst | 32 ++++
tests/always/test_example_dags.py | 1 +
.../google/cloud/operators/test_dataflow.py | 109 +++++++++++-
.../google/cloud/triggers/test_dataflow.py | 141 +++++++++++++++-
.../google/cloud/dataflow/example_dataflow_yaml.py | 172 +++++++++++++++++++
8 files changed, 865 insertions(+), 49 deletions(-)
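
For orientation, here is a minimal, hypothetical sketch of how the new operator introduced by this commit could be wired into a DAG. The parameter names come from the operator added in the diff below; the project, bucket, and variable values are placeholders.

    from datetime import datetime

    from airflow.models.dag import DAG
    from airflow.providers.google.cloud.operators.dataflow import DataflowStartYamlJobOperator

    with DAG(dag_id="example_dataflow_yaml_usage", start_date=datetime(2024, 1, 1), catchup=False) as dag:
        # Launch a Beam YAML pipeline stored in GCS and wait for it to finish.
        start_yaml_job = DataflowStartYamlJobOperator(
            task_id="start_yaml_job",
            job_name="my-dataflow-yaml-job",                    # placeholder job name
            yaml_pipeline_file="gs://my-bucket/pipeline.yaml",  # placeholder GCS path
            project_id="my-project",                            # placeholder project
            region="us-central1",
            jinja_variables={"dataset": "my_dataset"},          # substituted into the YAML file
            deferrable=False,
        )
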
diff --git a/airflow/providers/google/cloud/hooks/dataflow.py b/airflow/providers/google/cloud/hooks/dataflow.py
index 8f7e2e2549..97eaa49b36 100644
--- a/airflow/providers/google/cloud/hooks/dataflow.py
+++ b/airflow/providers/google/cloud/hooks/dataflow.py
@@ -186,9 +186,9 @@ class DataflowJobType:
class _DataflowJobsController(LoggingMixin):
"""
- Interface for communication with Google API.
+ Interface for communication with Google Cloud Dataflow API.
- It's not use Apache Beam, but only Google Dataflow API.
+ Does not use Apache Beam API.
:param dataflow: Discovery resource
:param project_number: The Google Cloud Project ID.
@@ -271,12 +271,12 @@ class _DataflowJobsController(LoggingMixin):
else:
raise ValueError("Missing both dataflow job ID and name.")
- def fetch_job_by_id(self, job_id: str) -> dict:
+ def fetch_job_by_id(self, job_id: str) -> dict[str, str]:
"""
Fetch the job with the specified Job ID.
- :param job_id: Job ID to get.
- :return: the Job
+ :param job_id: ID of the job that needs to be fetched.
+ :return: Dictionary containing the Job's data
"""
return (
self._dataflow.projects()
@@ -444,7 +444,6 @@ class _DataflowJobsController(LoggingMixin):
"Google Cloud Dataflow job's expected terminal state cannot be
"
"JOB_STATE_DRAINED while it is a batch job"
)
-
if current_state == current_expected_state:
if current_expected_state == DataflowJobStatus.JOB_STATE_RUNNING:
return not self._wait_until_finished
@@ -938,6 +937,90 @@ class DataflowHook(GoogleBaseHook):
response: dict = request.execute(num_retries=self.num_retries)
return response["job"]
+ @GoogleBaseHook.fallback_to_default_project_id
+ def launch_beam_yaml_job(
+ self,
+ *,
+ job_name: str,
+ yaml_pipeline_file: str,
+ append_job_name: bool,
+ jinja_variables: dict[str, str] | None,
+ options: dict[str, Any] | None,
+ project_id: str,
+ location: str = DEFAULT_DATAFLOW_LOCATION,
+ ) -> str:
+ """
+ Launch a Dataflow YAML job and run it until completion.
+
+ :param job_name: The unique name to assign to the Cloud Dataflow job.
+ :param yaml_pipeline_file: Path to a file defining the YAML pipeline to run.
+ Must be a local file or a URL beginning with 'gs://'.
+ :param append_job_name: Set to True if a unique suffix has to be appended to the `job_name`.
+ :param jinja_variables: A dictionary of Jinja2 variables to be used in reifying the yaml pipeline file.
+ :param options: Additional gcloud or Beam job parameters.
+ It must be a dictionary with the keys matching the optional flag names in gcloud.
+ The list of supported flags can be found at: `https://cloud.google.com/sdk/gcloud/reference/dataflow/yaml/run`.
+ Note that if a flag does not require a value, then its dictionary value must be either True or None.
+ For example, the `--log-http` flag can be passed as {'log-http': True}.
+ :param project_id: The ID of the GCP project that owns the job.
+ :param location: Region ID of the job's regional endpoint. Defaults to 'us-central1'.
+ :return: Job ID.
+ """
+ gcp_flags = {
+ "yaml-pipeline-file": yaml_pipeline_file,
+ "project": project_id,
+ "format": "value(job.id)",
+ "region": location,
+ }
+
+ if jinja_variables:
+ gcp_flags["jinja-variables"] = json.dumps(jinja_variables)
+
+ if options:
+ gcp_flags.update(options)
+
+ job_name = self.build_dataflow_job_name(job_name, append_job_name)
+ cmd = self._build_gcloud_command(
+ command=["gcloud", "dataflow", "yaml", "run", job_name],
parameters=gcp_flags
+ )
+ job_id = self._create_dataflow_job_with_gcloud(cmd=cmd)
+ return job_id
+
+ def _build_gcloud_command(self, command: list[str], parameters: dict[str, str]) -> list[str]:
+ _parameters = deepcopy(parameters)
+ if self.impersonation_chain:
+ if isinstance(self.impersonation_chain, str):
+ impersonation_account = self.impersonation_chain
+ elif len(self.impersonation_chain) == 1:
+ impersonation_account = self.impersonation_chain[0]
+ else:
+ raise AirflowException(
+ "Chained list of accounts is not supported, please specify
only one service account."
+ )
+ _parameters["impersonate-service-account"] = impersonation_account
+ return [*command, *(beam_options_to_args(_parameters))]
+
+ def _create_dataflow_job_with_gcloud(self, cmd: list[str]) -> str:
+ """Create a Dataflow job with a gcloud command and return the job's
ID."""
+ self.log.info("Executing command: %s", " ".join(shlex.quote(c) for c
in cmd))
+ success_code = 0
+
+ with self.provide_authorized_gcloud():
+ proc = subprocess.run(cmd, capture_output=True)
+
+ if proc.returncode != success_code:
+ stderr_last_20_lines = "\n".join(proc.stderr.decode().strip().splitlines()[-20:])
+ raise AirflowException(
+ f"Process exit with non-zero exit code. Exit code:
{proc.returncode}. Error Details : "
+ f"{stderr_last_20_lines}"
+ )
+
+ job_id = proc.stdout.decode().strip()
+ self.log.info("Created job's ID: %s", job_id)
+
+ return job_id
+
@staticmethod
def extract_job_id(job: dict) -> str:
try:
@@ -1139,33 +1222,15 @@ class DataflowHook(GoogleBaseHook):
:param on_new_job_callback: Callback called when the job is known.
:return: the new job object
"""
- gcp_options = [
- f"--project={project_id}",
- "--format=value(job.id)",
- f"--job-name={job_name}",
- f"--region={location}",
- ]
-
- if self.impersonation_chain:
- if isinstance(self.impersonation_chain, str):
- impersonation_account = self.impersonation_chain
- elif len(self.impersonation_chain) == 1:
- impersonation_account = self.impersonation_chain[0]
- else:
- raise AirflowException(
- "Chained list of accounts is not supported, please specify
only one service account"
- )
- gcp_options.append(f"--impersonate-service-account={impersonation_account}")
-
- cmd = [
- "gcloud",
- "dataflow",
- "sql",
- "query",
- query,
- *gcp_options,
- *(beam_options_to_args(options)),
- ]
+ gcp_options = {
+ "project": project_id,
+ "format": "value(job.id)",
+ "job-name": job_name,
+ "region": location,
+ }
+ cmd = self._build_gcloud_command(
+ command=["gcloud", "dataflow", "sql", "query", query],
parameters={**gcp_options, **options}
+ )
self.log.info("Executing command: %s", " ".join(shlex.quote(c) for c
in cmd))
with self.provide_authorized_gcloud():
proc = subprocess.run(cmd, capture_output=True)
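
As a rough illustration (not code from this commit), the new launch_beam_yaml_job hook method assembles a gcloud invocation along the following lines, applying the flag mapping described in its docstring: plain values become --key=value arguments, while True/None values become bare --key flags. Job name, pipeline path, and project below are placeholders.

    # Approximate reconstruction of the command built by launch_beam_yaml_job.
    gcp_flags = {
        "yaml-pipeline-file": "gs://my-bucket/pipeline.yaml",
        "project": "my-project",
        "format": "value(job.id)",
        "region": "us-central1",
        "log-http": True,  # example of a value-less gcloud flag
    }
    args = [f"--{k}" if v in (True, None) else f"--{k}={v}" for k, v in gcp_flags.items()]
    cmd = ["gcloud", "dataflow", "yaml", "run", "my-job", *args]
    # -> gcloud dataflow yaml run my-job --yaml-pipeline-file=gs://my-bucket/pipeline.yaml
    #    --project=my-project --format=value(job.id) --region=us-central1 --log-http
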
diff --git a/airflow/providers/google/cloud/operators/dataflow.py b/airflow/providers/google/cloud/operators/dataflow.py
index 625356d50e..fd4d0644a7 100644
--- a/airflow/providers/google/cloud/operators/dataflow.py
+++ b/airflow/providers/google/cloud/operators/dataflow.py
@@ -40,7 +40,10 @@ from airflow.providers.google.cloud.hooks.dataflow import (
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.providers.google.cloud.links.dataflow import DataflowJobLink, DataflowPipelineLink
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
-from airflow.providers.google.cloud.triggers.dataflow import TemplateJobStartTrigger
+from airflow.providers.google.cloud.triggers.dataflow import (
+ DataflowStartYamlJobTrigger,
+ TemplateJobStartTrigger,
+)
from airflow.providers.google.common.consts import GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
from airflow.providers.google.common.deprecated import deprecated
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
@@ -946,6 +949,11 @@ class DataflowStartFlexTemplateOperator(GoogleCloudBaseOperator):
)
+@deprecated(
+ planned_removal_date="January 31, 2025",
+ use_instead="DataflowStartYamlJobOperator",
+ category=AirflowProviderDeprecationWarning,
+)
class DataflowStartSqlJobOperator(GoogleCloudBaseOperator):
"""
Starts Dataflow SQL query.
@@ -1051,6 +1059,178 @@ class DataflowStartSqlJobOperator(GoogleCloudBaseOperator):
)
+class DataflowStartYamlJobOperator(GoogleCloudBaseOperator):
+ """
+ Launch a Dataflow YAML job and return the result.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:DataflowStartYamlJobOperator`
+
+ .. warning::
+ This operator requires the ``gcloud`` command (Google Cloud SDK) to be installed on the Airflow worker:
+ `Install the Google Cloud SDK <https://cloud.google.com/sdk/docs/install>`__.
+
+ :param job_name: Required. The unique name to assign to the Cloud Dataflow job.
+ :param yaml_pipeline_file: Required. Path to a file defining the YAML pipeline to run.
+ Must be a local file or a URL beginning with 'gs://'.
+ :param region: Optional. Region ID of the job's regional endpoint. Defaults to 'us-central1'.
+ :param project_id: Required. The ID of the GCP project that owns the job.
+ If set to ``None`` or missing, the default project_id from the GCP connection is used.
+ :param gcp_conn_id: Optional. The connection ID used to connect to GCP.
+ :param append_job_name: Optional. Set to True if a unique suffix has to be appended to the `job_name`.
+ Defaults to True.
+ :param drain_pipeline: Optional. Set to True if you want to stop a streaming pipeline job by draining it
+ instead of canceling when killing the task instance. Note that this does not work for batch pipeline jobs
+ or in the deferrable mode. Defaults to False.
+ For more info see: https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline
+ :param deferrable: Optional. Run operator in the deferrable mode.
+ :param expected_terminal_state: Optional. The expected terminal state of the Dataflow job at which the
+ operator task is set to succeed. Defaults to 'JOB_STATE_DONE' for the batch jobs and 'JOB_STATE_RUNNING'
+ for the streaming jobs.
+ :param poll_sleep: Optional. The time in seconds to sleep between polling Google Cloud Platform for the Dataflow job status.
+ Used both for the sync and deferrable mode.
+ :param cancel_timeout: Optional. How long (in seconds) operator should wait for the pipeline to be
+ successfully canceled when the task is being killed.
+ :param jinja_variables: Optional. A dictionary of Jinja2 variables to be used in reifying the yaml pipeline file.
+ :param options: Optional. Additional gcloud or Beam job parameters.
+ It must be a dictionary with the keys matching the optional flag names in gcloud.
+ The list of supported flags can be found at: `https://cloud.google.com/sdk/gcloud/reference/dataflow/yaml/run`.
+ Note that if a flag does not require a value, then its dictionary value must be either True or None.
+ For example, the `--log-http` flag can be passed as {'log-http': True}.
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
+ :return: Dictionary containing the job's data.
+ """
+
+ template_fields: Sequence[str] = (
+ "job_name",
+ "yaml_pipeline_file",
+ "jinja_variables",
+ "options",
+ "region",
+ "project_id",
+ "gcp_conn_id",
+ )
+ template_fields_renderers = {
+ "jinja_variables": "json",
+ }
+ operator_extra_links = (DataflowJobLink(),)
+
+ def __init__(
+ self,
+ *,
+ job_name: str,
+ yaml_pipeline_file: str,
+ region: str = DEFAULT_DATAFLOW_LOCATION,
+ project_id: str = PROVIDE_PROJECT_ID,
+ gcp_conn_id: str = "google_cloud_default",
+ append_job_name: bool = True,
+ drain_pipeline: bool = False,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
+ poll_sleep: int = 10,
+ cancel_timeout: int | None = 5 * 60,
+ expected_terminal_state: str | None = None,
+ jinja_variables: dict[str, str] | None = None,
+ options: dict[str, Any] | None = None,
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.job_name = job_name
+ self.yaml_pipeline_file = yaml_pipeline_file
+ self.region = region
+ self.project_id = project_id
+ self.gcp_conn_id = gcp_conn_id
+ self.append_job_name = append_job_name
+ self.drain_pipeline = drain_pipeline
+ self.deferrable = deferrable
+ self.poll_sleep = poll_sleep
+ self.cancel_timeout = cancel_timeout
+ self.expected_terminal_state = expected_terminal_state
+ self.options = options
+ self.jinja_variables = jinja_variables
+ self.impersonation_chain = impersonation_chain
+ self.job_id: str | None = None
+
+ def execute(self, context: Context) -> dict[str, Any]:
+ self.job_id = self.hook.launch_beam_yaml_job(
+ job_name=self.job_name,
+ yaml_pipeline_file=self.yaml_pipeline_file,
+ append_job_name=self.append_job_name,
+ options=self.options,
+ jinja_variables=self.jinja_variables,
+ project_id=self.project_id,
+ location=self.region,
+ )
+
+ DataflowJobLink.persist(self, context, self.project_id, self.region, self.job_id)
+
+ if self.deferrable:
+ self.defer(
+ trigger=DataflowStartYamlJobTrigger(
+ job_id=self.job_id,
+ project_id=self.project_id,
+ location=self.region,
+ gcp_conn_id=self.gcp_conn_id,
+ poll_sleep=self.poll_sleep,
+ cancel_timeout=self.cancel_timeout,
+ expected_terminal_state=self.expected_terminal_state,
+ impersonation_chain=self.impersonation_chain,
+ ),
+ method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME,
+ )
+
+ self.hook.wait_for_done(
+ job_name=self.job_name, location=self.region, project_id=self.project_id, job_id=self.job_id
+ )
+ job = self.hook.get_job(job_id=self.job_id, location=self.region, project_id=self.project_id)
+ return job
+
+ def execute_complete(self, context: Context, event: dict) -> dict[str, Any]:
+ """Execute after the trigger returns an event."""
+ if event["status"] in ("error", "stopped"):
+ self.log.info("status: %s, msg: %s", event["status"],
event["message"])
+ raise AirflowException(event["message"])
+ job = event["job"]
+ self.log.info("Job %s completed with response %s", job["id"],
event["message"])
+ self.xcom_push(context, key="job_id", value=job["id"])
+
+ return job
+
+ def on_kill(self):
+ """
+ Cancel the dataflow job if a task instance gets killed.
+
+ This method will not be called if a task instance is killed in a deferred
+ state.
+ """
+ self.log.info("On kill called.")
+ if self.job_id:
+ self.hook.cancel_job(
+ job_id=self.job_id,
+ project_id=self.project_id,
+ location=self.region,
+ )
+
+ @cached_property
+ def hook(self) -> DataflowHook:
+ return DataflowHook(
+ gcp_conn_id=self.gcp_conn_id,
+ poll_sleep=self.poll_sleep,
+ impersonation_chain=self.impersonation_chain,
+ drain_pipeline=self.drain_pipeline,
+ cancel_timeout=self.cancel_timeout,
+ expected_terminal_state=self.expected_terminal_state,
+ )
+
+
# TODO: Remove one day
@deprecated(
planned_removal_date="November 01, 2024",
diff --git a/airflow/providers/google/cloud/triggers/dataflow.py b/airflow/providers/google/cloud/triggers/dataflow.py
index 01f96b98b7..4b994bf6e4 100644
--- a/airflow/providers/google/cloud/triggers/dataflow.py
+++ b/airflow/providers/google/cloud/triggers/dataflow.py
@@ -24,8 +24,10 @@ from typing import TYPE_CHECKING, Any, Sequence
from google.cloud.dataflow_v1beta3 import JobState
from google.cloud.dataflow_v1beta3.types import (
AutoscalingEvent,
+ Job,
JobMessage,
JobMetrics,
+ JobType,
MetricUpdate,
)
@@ -157,7 +159,7 @@ class TemplateJobStartTrigger(BaseTrigger):
class DataflowJobStatusTrigger(BaseTrigger):
"""
- Trigger that checks for metrics associated with a Dataflow job.
+ Trigger that monitors if a Dataflow job has reached any of the expected statuses.
:param job_id: Required. ID of the job.
:param expected_statuses: The expected state(s) of the operation.
@@ -266,6 +268,148 @@ class DataflowJobStatusTrigger(BaseTrigger):
)
+class DataflowStartYamlJobTrigger(BaseTrigger):
+ """
+ Dataflow trigger that checks the state of a Dataflow YAML job.
+
+ :param job_id: Required. ID of the job.
+ :param project_id: Required. The Google Cloud project ID in which the job was started.
+ :param location: The location where job is executed. If set to None then
+ the value of DEFAULT_DATAFLOW_LOCATION will be used.
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+ :param poll_sleep: Optional. The time in seconds to sleep between polling Google Cloud Platform
+ for the Dataflow job.
+ :param cancel_timeout: Optional. How long (in seconds) operator should wait for the pipeline to be
+ successfully cancelled when task is being killed.
+ :param expected_terminal_state: Optional. The expected terminal state of the Dataflow job at which the
+ operator task is set to succeed. Defaults to 'JOB_STATE_DONE' for the batch jobs and
+ 'JOB_STATE_RUNNING' for the streaming jobs.
+ :param impersonation_chain: Optional. Service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
+ """
+
+ def __init__(
+ self,
+ job_id: str,
+ project_id: str | None,
+ location: str = DEFAULT_DATAFLOW_LOCATION,
+ gcp_conn_id: str = "google_cloud_default",
+ poll_sleep: int = 10,
+ cancel_timeout: int | None = 5 * 60,
+ expected_terminal_state: str | None = None,
+ impersonation_chain: str | Sequence[str] | None = None,
+ ):
+ super().__init__()
+ self.project_id = project_id
+ self.job_id = job_id
+ self.location = location
+ self.gcp_conn_id = gcp_conn_id
+ self.poll_sleep = poll_sleep
+ self.cancel_timeout = cancel_timeout
+ self.expected_terminal_state = expected_terminal_state
+ self.impersonation_chain = impersonation_chain
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ """Serialize class arguments and classpath."""
+ return (
+ "airflow.providers.google.cloud.triggers.dataflow.DataflowStartYamlJobTrigger",
+ {
+ "project_id": self.project_id,
+ "job_id": self.job_id,
+ "location": self.location,
+ "gcp_conn_id": self.gcp_conn_id,
+ "poll_sleep": self.poll_sleep,
+ "expected_terminal_state": self.expected_terminal_state,
+ "impersonation_chain": self.impersonation_chain,
+ "cancel_timeout": self.cancel_timeout,
+ },
+ )
+
+ async def run(self):
+ """
+ Fetch job and yield events depending on the job's type and state.
+
+ Yield TriggerEvent if the job reaches a terminal state.
+ Otherwise, waits for the amount of time stored in the self.poll_sleep variable.
+ """
+ hook: AsyncDataflowHook = self._get_async_hook()
+ try:
+ while True:
+ job: Job = await hook.get_job(
+ job_id=self.job_id,
+ project_id=self.project_id,
+ location=self.location,
+ )
+ job_state = job.current_state
+ job_type = job.type_
+ if job_state.name == self.expected_terminal_state:
+ yield TriggerEvent(
+ {
+ "job": Job.to_dict(job),
+ "status": "success",
+ "message": f"Job reached the expected terminal
state: {self.expected_terminal_state}.",
+ }
+ )
+ return
+ elif job_type == JobType.JOB_TYPE_STREAMING and job_state == JobState.JOB_STATE_RUNNING:
+ yield TriggerEvent(
+ {
+ "job": Job.to_dict(job),
+ "status": "success",
+ "message": "Streaming job reached the RUNNING
state.",
+ }
+ )
+ return
+ elif job_type == JobType.JOB_TYPE_BATCH and job_state == JobState.JOB_STATE_DONE:
+ yield TriggerEvent(
+ {
+ "job": Job.to_dict(job),
+ "status": "success",
+ "message": "Batch job completed.",
+ }
+ )
+ return
+ elif job_state == JobState.JOB_STATE_FAILED:
+ yield TriggerEvent(
+ {
+ "job": Job.to_dict(job),
+ "status": "error",
+ "message": "Job failed.",
+ }
+ )
+ return
+ elif job_state == JobState.JOB_STATE_STOPPED:
+ yield TriggerEvent(
+ {
+ "job": Job.to_dict(job),
+ "status": "stopped",
+ "message": "Job was stopped.",
+ }
+ )
+ return
+ else:
+ self.log.info("Current job status is: %s", job_state.name)
+ self.log.info("Sleeping for %s seconds.", self.poll_sleep)
+ await asyncio.sleep(self.poll_sleep)
+ except Exception as e:
+ self.log.exception("Exception occurred while checking for job
completion.")
+ yield TriggerEvent({"job": None, "status": "error", "message":
str(e)})
+
+ def _get_async_hook(self) -> AsyncDataflowHook:
+ return AsyncDataflowHook(
+ gcp_conn_id=self.gcp_conn_id,
+ poll_sleep=self.poll_sleep,
+ impersonation_chain=self.impersonation_chain,
+ cancel_timeout=self.cancel_timeout,
+ )
+
+
class DataflowJobMetricsTrigger(BaseTrigger):
"""
Trigger that checks for metrics associated with a Dataflow job.
diff --git a/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst b/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst
index 71fc3275fe..a9eb98ea9a 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst
@@ -306,6 +306,38 @@ Here is an example of running Dataflow SQL job with
See the `Dataflow SQL reference
<https://cloud.google.com/dataflow/docs/reference/sql>`_.
+.. _howto/operator:DataflowStartYamlJobOperator:
+
+Dataflow YAML
+""""""""""""""
+Beam YAML is a no-code SDK for configuring Apache Beam pipelines by using YAML files.
+You can use Beam YAML to author and run a Beam pipeline without writing any code.
+This API can be used to define both streaming and batch pipelines.
+
+Here is an example of running a Dataflow YAML job with
+:class:`~airflow.providers.google.cloud.operators.dataflow.DataflowStartYamlJobOperator`:
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/dataflow/example_dataflow_yaml.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_dataflow_start_yaml_job]
+ :end-before: [END howto_operator_dataflow_start_yaml_job]
+
+This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/dataflow/example_dataflow_yaml.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_dataflow_start_yaml_job_def]
+ :end-before: [END howto_operator_dataflow_start_yaml_job_def]
+
+.. warning::
+ This operator requires the ``gcloud`` command (Google Cloud SDK) to be installed on the Airflow worker:
+ `Install the Google Cloud SDK <https://cloud.google.com/sdk/docs/install>`__.
+
+See the `Dataflow YAML reference
+<https://cloud.google.com/sdk/gcloud/reference/dataflow/yaml>`_.
+
.. _howto/operator:DataflowStopJobOperator:
Stopping a pipeline
diff --git a/tests/always/test_example_dags.py b/tests/always/test_example_dags.py
index 2dfcfb6e37..9d10ce5cad 100644
--- a/tests/always/test_example_dags.py
+++ b/tests/always/test_example_dags.py
@@ -51,6 +51,7 @@ IGNORE_AIRFLOW_PROVIDER_DEPRECATION_WARNING: tuple[str, ...] = (
# If the deprecation is postponed, the item should be added to this tuple,
# and a corresponding Issue should be created on GitHub.
"tests/system/providers/google/cloud/bigquery/example_bigquery_operations.py",
+ "tests/system/providers/google/cloud/dataflow/example_dataflow_sql.py",
"tests/system/providers/google/cloud/dataproc/example_dataproc_gke.py",
"tests/system/providers/google/cloud/datapipelines/example_datapipeline.py",
"tests/system/providers/google/cloud/gcs/example_gcs_sensor.py",
diff --git a/tests/providers/google/cloud/operators/test_dataflow.py b/tests/providers/google/cloud/operators/test_dataflow.py
index a024ffd7a7..14787dba19 100644
--- a/tests/providers/google/cloud/operators/test_dataflow.py
+++ b/tests/providers/google/cloud/operators/test_dataflow.py
@@ -41,6 +41,7 @@ from airflow.providers.google.cloud.operators.dataflow import (
DataflowRunPipelineOperator,
DataflowStartFlexTemplateOperator,
DataflowStartSqlJobOperator,
+ DataflowStartYamlJobOperator,
DataflowStopJobOperator,
DataflowTemplatedJobStartOperator,
)
@@ -711,16 +712,17 @@ class TestDataflowStartFlexTemplateOperator:
class TestDataflowStartSqlJobOperator:
@mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook")
def test_execute(self, mock_hook):
- start_sql = DataflowStartSqlJobOperator(
- task_id="start_sql_query",
- job_name=TEST_SQL_JOB_NAME,
- query=TEST_SQL_QUERY,
- options=deepcopy(TEST_SQL_OPTIONS),
- location=TEST_LOCATION,
- do_xcom_push=True,
- )
+ with pytest.warns(AirflowProviderDeprecationWarning):
+ start_sql = DataflowStartSqlJobOperator(
+ task_id="start_sql_query",
+ job_name=TEST_SQL_JOB_NAME,
+ query=TEST_SQL_QUERY,
+ options=deepcopy(TEST_SQL_OPTIONS),
+ location=TEST_LOCATION,
+ do_xcom_push=True,
+ )
+ start_sql.execute(mock.MagicMock())
- start_sql.execute(mock.MagicMock())
mock_hook.assert_called_once_with(
gcp_conn_id="google_cloud_default",
drain_pipeline=False,
@@ -741,6 +743,95 @@ class TestDataflowStartSqlJobOperator:
)
+class TestDataflowStartYamlJobOperator:
+ @pytest.fixture
+ def sync_operator(self):
+ return DataflowStartYamlJobOperator(
+ task_id="start_dataflow_yaml_job_sync",
+ job_name="dataflow_yaml_job",
+ yaml_pipeline_file="test_file_path",
+ append_job_name=False,
+ project_id=TEST_PROJECT,
+ region=TEST_LOCATION,
+ gcp_conn_id=GCP_CONN_ID,
+ expected_terminal_state=DataflowJobStatus.JOB_STATE_DONE,
+ )
+
+ @pytest.fixture
+ def deferrable_operator(self):
+ return DataflowStartYamlJobOperator(
+ task_id="start_dataflow_yaml_job_def",
+ job_name="dataflow_yaml_job",
+ yaml_pipeline_file="test_file_path",
+ append_job_name=False,
+ project_id=TEST_PROJECT,
+ region=TEST_LOCATION,
+ gcp_conn_id=GCP_CONN_ID,
+ deferrable=True,
+ expected_terminal_state=DataflowJobStatus.JOB_STATE_RUNNING,
+ )
+
+ @mock.patch(f"{DATAFLOW_PATH}.DataflowHook")
+ def test_execute(self, mock_hook, sync_operator):
+ sync_operator.execute(mock.MagicMock())
+ mock_hook.assert_called_once_with(
+ poll_sleep=sync_operator.poll_sleep,
+ drain_pipeline=False,
+ impersonation_chain=None,
+ cancel_timeout=sync_operator.cancel_timeout,
+ expected_terminal_state=DataflowJobStatus.JOB_STATE_DONE,
+ gcp_conn_id=GCP_CONN_ID,
+ )
+ mock_hook.return_value.launch_beam_yaml_job.assert_called_once_with(
+ job_name=sync_operator.job_name,
+ yaml_pipeline_file=sync_operator.yaml_pipeline_file,
+ append_job_name=False,
+ options=None,
+ jinja_variables=None,
+ project_id=TEST_PROJECT,
+ location=TEST_LOCATION,
+ )
+
+ @mock.patch(f"{DATAFLOW_PATH}.DataflowStartYamlJobOperator.defer")
+ @mock.patch(f"{DATAFLOW_PATH}.DataflowHook")
+ def test_execute_with_deferrable_mode(self, mock_hook, mock_defer_method, deferrable_operator):
+ deferrable_operator.execute(mock.MagicMock())
+ mock_hook.assert_called_once_with(
+ poll_sleep=deferrable_operator.poll_sleep,
+ drain_pipeline=False,
+ impersonation_chain=None,
+ cancel_timeout=deferrable_operator.cancel_timeout,
+ expected_terminal_state=DataflowJobStatus.JOB_STATE_RUNNING,
+ gcp_conn_id=GCP_CONN_ID,
+ )
+ mock_defer_method.assert_called_once()
+
+ @mock.patch(f"{DATAFLOW_PATH}.DataflowStartYamlJobOperator.xcom_push")
+ @mock.patch(f"{DATAFLOW_PATH}.DataflowHook")
+ def test_execute_complete_success(self, mock_hook, mock_xcom_push, deferrable_operator):
+ expected_result = {"id": JOB_ID}
+ actual_result = deferrable_operator.execute_complete(
+ context=None,
+ event={
+ "status": "success",
+ "message": "Batch job completed.",
+ "job": expected_result,
+ },
+ )
+ mock_xcom_push.assert_called_with(None, key="job_id", value=JOB_ID)
+ assert actual_result == expected_result
+
+ def test_execute_complete_error_status_raises_exception(self, deferrable_operator):
+ with pytest.raises(AirflowException, match="Job failed."):
+ deferrable_operator.execute_complete(
+ context=None, event={"status": "error", "message": "Job
failed."}
+ )
+ with pytest.raises(AirflowException, match="Job was stopped."):
+ deferrable_operator.execute_complete(
+ context=None, event={"status": "stopped", "message": "Job was
stopped."}
+ )
+
+
class TestDataflowStopJobOperator:
@mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook")
def test_exec_job_id(self, dataflow_mock):
diff --git a/tests/providers/google/cloud/triggers/test_dataflow.py b/tests/providers/google/cloud/triggers/test_dataflow.py
index 2b9b63afa4..67b2f41009 100644
--- a/tests/providers/google/cloud/triggers/test_dataflow.py
+++ b/tests/providers/google/cloud/triggers/test_dataflow.py
@@ -22,7 +22,7 @@ import logging
from unittest import mock
import pytest
-from google.cloud.dataflow_v1beta3 import JobState
+from google.cloud.dataflow_v1beta3 import Job, JobState, JobType
from airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus
from airflow.providers.google.cloud.triggers.dataflow import (
@@ -30,6 +30,7 @@ from airflow.providers.google.cloud.triggers.dataflow import (
DataflowJobMessagesTrigger,
DataflowJobMetricsTrigger,
DataflowJobStatusTrigger,
+ DataflowStartYamlJobTrigger,
TemplateJobStartTrigger,
)
from airflow.triggers.base import TriggerEvent
@@ -108,6 +109,24 @@ def dataflow_job_status_trigger():
)
+@pytest.fixture
+def dataflow_start_yaml_job_trigger():
+ return DataflowStartYamlJobTrigger(
+ project_id=PROJECT_ID,
+ job_id=JOB_ID,
+ location=LOCATION,
+ gcp_conn_id=GCP_CONN_ID,
+ poll_sleep=POLL_SLEEP,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ cancel_timeout=CANCEL_TIMEOUT,
+ )
+
+
+@pytest.fixture
+def test_dataflow_batch_job():
+ return Job(id=JOB_ID, current_state=JobState.JOB_STATE_DONE, type_=JobType.JOB_TYPE_BATCH)
+
+
class TestTemplateJobStartTrigger:
def test_serialize(self, template_job_start_trigger):
actual_data = template_job_start_trigger.serialize()
@@ -548,13 +567,11 @@ class TestDataflowJobMetricsTrigger:
mock_get_job_metrics,
mock_job_status,
dataflow_job_metrics_trigger,
- caplog,
):
"""Test that DataflowJobMetricsTrigger is still in loop if the job
status is RUNNING."""
dataflow_job_metrics_trigger.fail_on_terminal_state = True
mock_job_status.return_value = JobState.JOB_STATE_RUNNING
mock_get_job_metrics.return_value = []
- caplog.set_level(logging.INFO)
task = asyncio.create_task(dataflow_job_metrics_trigger.run().__anext__())
await asyncio.sleep(0.5)
assert task.done() is False
@@ -703,12 +720,10 @@ class TestDataflowJobStatusTrigger:
self,
mock_job_status,
dataflow_job_status_trigger,
- caplog,
):
"""Test that DataflowJobStatusTrigger is still in loop if the job
status neither terminal nor expected."""
dataflow_job_status_trigger.expected_statuses = {DataflowJobStatus.JOB_STATE_DONE}
mock_job_status.return_value = JobState.JOB_STATE_RUNNING
- caplog.set_level(logging.INFO)
task = asyncio.create_task(dataflow_job_status_trigger.run().__anext__())
await asyncio.sleep(0.5)
assert task.done() is False
@@ -729,3 +744,119 @@ class TestDataflowJobStatusTrigger:
)
actual_event = await dataflow_job_status_trigger.run().asend(None)
assert expected_event == actual_event
+
+
+class TestDataflowStartYamlJobTrigger:
+ def test_serialize(self, dataflow_start_yaml_job_trigger):
+ actual_data = dataflow_start_yaml_job_trigger.serialize()
+ expected_data = (
+ "airflow.providers.google.cloud.triggers.dataflow.DataflowStartYamlJobTrigger",
+ {
+ "project_id": PROJECT_ID,
+ "job_id": JOB_ID,
+ "location": LOCATION,
+ "gcp_conn_id": GCP_CONN_ID,
+ "poll_sleep": POLL_SLEEP,
+ "expected_terminal_state": None,
+ "impersonation_chain": IMPERSONATION_CHAIN,
+ "cancel_timeout": CANCEL_TIMEOUT,
+ },
+ )
+ assert actual_data == expected_data
+
+ @pytest.mark.parametrize(
+ "attr, expected",
+ [
+ ("gcp_conn_id", GCP_CONN_ID),
+ ("poll_sleep", POLL_SLEEP),
+ ("impersonation_chain", IMPERSONATION_CHAIN),
+ ("cancel_timeout", CANCEL_TIMEOUT),
+ ],
+ )
+ def test_get_async_hook(self, dataflow_start_yaml_job_trigger, attr, expected):
+ hook = dataflow_start_yaml_job_trigger._get_async_hook()
+ actual = hook._hook_kwargs.get(attr)
+ assert actual is not None
+ assert actual == expected
+
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job")
+ async def test_run_loop_return_success_event(
+ self, mock_get_job, dataflow_start_yaml_job_trigger, test_dataflow_batch_job
+ ):
+ mock_get_job.return_value = test_dataflow_batch_job
+ expected_event = TriggerEvent(
+ {
+ "job": Job.to_dict(test_dataflow_batch_job),
+ "status": "success",
+ "message": "Batch job completed.",
+ }
+ )
+ actual_event = await dataflow_start_yaml_job_trigger.run().asend(None)
+ assert actual_event == expected_event
+
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job")
+ async def test_run_loop_return_failed_event(
+ self, mock_get_job, dataflow_start_yaml_job_trigger, test_dataflow_batch_job
+ ):
+ test_dataflow_batch_job.current_state = JobState.JOB_STATE_FAILED
+ mock_get_job.return_value = test_dataflow_batch_job
+ expected_event = TriggerEvent(
+ {
+ "job": Job.to_dict(test_dataflow_batch_job),
+ "status": "error",
+ "message": "Job failed.",
+ }
+ )
+ actual_event = await dataflow_start_yaml_job_trigger.run().asend(None)
+ assert actual_event == expected_event
+
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job")
+ async def test_run_loop_return_stopped_event(
+ self, mock_get_job, dataflow_start_yaml_job_trigger, test_dataflow_batch_job
+ ):
+ test_dataflow_batch_job.current_state = JobState.JOB_STATE_STOPPED
+ mock_get_job.return_value = test_dataflow_batch_job
+ expected_event = TriggerEvent(
+ {
+ "job": Job.to_dict(test_dataflow_batch_job),
+ "status": "stopped",
+ "message": "Job was stopped.",
+ }
+ )
+ actual_event = await dataflow_start_yaml_job_trigger.run().asend(None)
+ assert actual_event == expected_event
+
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job")
+ async def test_run_loop_return_expected_state_event(
+ self, mock_get_job, dataflow_start_yaml_job_trigger, test_dataflow_batch_job
+ ):
+ dataflow_start_yaml_job_trigger.expected_terminal_state = DataflowJobStatus.JOB_STATE_RUNNING
+ test_dataflow_batch_job.current_state = JobState.JOB_STATE_RUNNING
+ mock_get_job.return_value = test_dataflow_batch_job
+ expected_event = TriggerEvent(
+ {
+ "job": Job.to_dict(test_dataflow_batch_job),
+ "status": "success",
+ "message": f"Job reached the expected terminal state:
{DataflowJobStatus.JOB_STATE_RUNNING}.",
+ }
+ )
+ actual_event = await dataflow_start_yaml_job_trigger.run().asend(None)
+ assert actual_event == expected_event
+
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job")
+ async def test_run_loop_is_still_running(
+ self, mock_get_job, dataflow_start_yaml_job_trigger, test_dataflow_batch_job
+ ):
+ """Test that DataflowStartYamlJobTrigger is still in loop if the job
status neither terminal nor expected."""
+ dataflow_start_yaml_job_trigger.expected_terminal_state = DataflowJobStatus.JOB_STATE_STOPPED
+ test_dataflow_batch_job.current_state = JobState.JOB_STATE_RUNNING
+ mock_get_job.return_value = test_dataflow_batch_job
+ task = asyncio.create_task(dataflow_start_yaml_job_trigger.run().__anext__())
+ await asyncio.sleep(0.5)
+ assert task.done() is False
+ task.cancel()
diff --git a/tests/system/providers/google/cloud/dataflow/example_dataflow_yaml.py b/tests/system/providers/google/cloud/dataflow/example_dataflow_yaml.py
new file mode 100644
index 0000000000..2bddc1de3f
--- /dev/null
+++ b/tests/system/providers/google/cloud/dataflow/example_dataflow_yaml.py
@@ -0,0 +1,172 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Example Airflow DAG for Google Cloud Dataflow YAML service.
+
+Requirements:
+ This test requires the ``gcloud`` command (Google Cloud SDK) to be installed on the Airflow worker:
+ `Install the Google Cloud SDK <https://cloud.google.com/sdk/docs/install>`__.
+"""
+
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+from airflow.models.dag import DAG
+from airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus
+from airflow.providers.google.cloud.operators.bigquery import (
+ BigQueryCreateEmptyDatasetOperator,
+ BigQueryCreateEmptyTableOperator,
+ BigQueryDeleteDatasetOperator,
+ BigQueryInsertJobOperator,
+)
+from airflow.providers.google.cloud.operators.dataflow import DataflowStartYamlJobOperator
+from airflow.utils.trigger_rule import TriggerRule
+from tests.system.providers.google import DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID
+
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") or
DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
+DAG_ID = "dataflow_yaml"
+REGION = "europe-west2"
+DATAFLOW_YAML_JOB_NAME = f"{DAG_ID}_{ENV_ID}".replace("_", "-")
+BQ_DATASET = f"{DAG_ID}_{ENV_ID}".replace("-", "_")
+BQ_INPUT_TABLE = f"input_{DAG_ID}".replace("-", "_")
+BQ_OUTPUT_TABLE = f"output_{DAG_ID}".replace("-", "_")
+DATAFLOW_YAML_PIPELINE_FILE_URL = (
+    "gs://airflow-system-tests-resources/dataflow/yaml/example_beam_yaml_bq.yaml"
+)
+
+BQ_VARIABLES = {
+ "project": PROJECT_ID,
+ "dataset": BQ_DATASET,
+ "input": BQ_INPUT_TABLE,
+ "output": BQ_OUTPUT_TABLE,
+}
+
+BQ_VARIABLES_DEF = {
+ "project": PROJECT_ID,
+ "dataset": BQ_DATASET,
+ "input": BQ_INPUT_TABLE,
+ "output": f"{BQ_OUTPUT_TABLE}_def",
+}
+
+INSERT_ROWS_QUERY = (
+ f"INSERT {BQ_DATASET}.{BQ_INPUT_TABLE} VALUES "
+ "('John Doe', 900, 'USA'), "
+ "('Alice Storm', 1200, 'Australia'),"
+ "('Bob Max', 1000, 'Australia'),"
+ "('Peter Jackson', 800, 'New Zealand'),"
+ "('Hobby Doyle', 1100, 'USA'),"
+ "('Terrance Phillips', 2222, 'Canada'),"
+ "('Joe Schmoe', 1500, 'Canada'),"
+ "('Dominique Levillaine', 2780, 'France');"
+)
+
+
+with DAG(
+ dag_id=DAG_ID,
+ start_date=datetime(2021, 1, 1),
+ catchup=False,
+ tags=["example", "dataflow", "yaml"],
+) as dag:
+ create_bq_dataset = BigQueryCreateEmptyDatasetOperator(
+ task_id="create_bq_dataset",
+ dataset_id=BQ_DATASET,
+ location=REGION,
+ )
+
+ create_bq_input_table = BigQueryCreateEmptyTableOperator(
+ task_id="create_bq_input_table",
+ dataset_id=BQ_DATASET,
+ table_id=BQ_INPUT_TABLE,
+ schema_fields=[
+ {"name": "emp_name", "type": "STRING", "mode": "REQUIRED"},
+ {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"},
+ {"name": "country", "type": "STRING", "mode": "NULLABLE"},
+ ],
+ )
+
+ insert_data_into_bq_table = BigQueryInsertJobOperator(
+ task_id="insert_data_into_bq_table",
+ configuration={
+ "query": {
+ "query": INSERT_ROWS_QUERY,
+ "useLegacySql": False,
+ "priority": "BATCH",
+ }
+ },
+ location=REGION,
+ )
+
+ # [START howto_operator_dataflow_start_yaml_job]
+ start_dataflow_yaml_job = DataflowStartYamlJobOperator(
+ task_id="start_dataflow_yaml_job",
+ job_name=DATAFLOW_YAML_JOB_NAME,
+ yaml_pipeline_file=DATAFLOW_YAML_PIPELINE_FILE_URL,
+ append_job_name=True,
+ deferrable=False,
+ region=REGION,
+ project_id=PROJECT_ID,
+ jinja_variables=BQ_VARIABLES,
+ )
+ # [END howto_operator_dataflow_start_yaml_job]
+
+ # [START howto_operator_dataflow_start_yaml_job_def]
+ start_dataflow_yaml_job_def = DataflowStartYamlJobOperator(
+ task_id="start_dataflow_yaml_job_def",
+ job_name=DATAFLOW_YAML_JOB_NAME,
+ yaml_pipeline_file=DATAFLOW_YAML_PIPELINE_FILE_URL,
+ append_job_name=True,
+ deferrable=True,
+ region=REGION,
+ project_id=PROJECT_ID,
+ jinja_variables=BQ_VARIABLES_DEF,
+ expected_terminal_state=DataflowJobStatus.JOB_STATE_DONE,
+ )
+ # [END howto_operator_dataflow_start_yaml_job_def]
+
+ delete_bq_dataset = BigQueryDeleteDatasetOperator(
+ task_id="delete_bq_dataset",
+ dataset_id=BQ_DATASET,
+ delete_contents=True,
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+
+ (
+ # TEST SETUP
+ create_bq_dataset
+ >> create_bq_input_table
+ >> insert_data_into_bq_table
+ # TEST BODY
+ >> [start_dataflow_yaml_job, start_dataflow_yaml_job_def]
+ # TEST TEARDOWN
+ >> delete_bq_dataset
+ )
+
+ from tests.system.utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when "tearDown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)