This is an automated email from the ASF dual-hosted git repository.
weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 28a240a18f Fix deferrable mode for DataflowTemplatedJobStartOperator and DataflowStartFlexTemplateOperator (#39018)
28a240a18f is described below
commit 28a240a18f7e5958e69732f61d639e1d8f39152f
Author: Eugene <[email protected]>
AuthorDate: Mon Apr 29 02:40:20 2024 +0000
Fix deferrable mode for DataflowTemplatedJobStartOperator and DataflowStartFlexTemplateOperator (#39018)
---
airflow/providers/google/cloud/hooks/dataflow.py | 177 +++++++++++++++++----
.../providers/google/cloud/operators/dataflow.py | 86 ++++++----
.../providers/google/cloud/triggers/dataflow.py | 2 +-
.../operators/cloud/dataflow.rst | 20 ++-
.../providers/google/cloud/hooks/test_dataflow.py | 52 ++++++
.../google/cloud/operators/test_dataflow.py | 50 ++++--
.../cloud/dataflow/example_dataflow_template.py | 42 ++++-
7 files changed, 336 insertions(+), 93 deletions(-)
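In short, the change lets both operators run with deferrable=True: the task submits the template launch request, defers to TemplateJobStartTrigger, and finishes in execute_complete once the trigger reports a terminal job state, instead of blocking a worker while polling. A minimal usage sketch (project, bucket, and body values are illustrative placeholders; the real system test is further down in this diff):

    from datetime import datetime

    from airflow.models.dag import DAG
    from airflow.providers.google.cloud.operators.dataflow import (
        DataflowStartFlexTemplateOperator,
        DataflowTemplatedJobStartOperator,
    )

    with DAG("dataflow_template_deferrable_demo", start_date=datetime(2024, 1, 1), schedule=None):
        # Classic Template: submit the launch request, then defer to the trigger
        # instead of polling the job status from the worker.
        start_template_job_deferrable = DataflowTemplatedJobStartOperator(
            task_id="start_template_job_deferrable",
            project_id="my-project",  # placeholder
            template="gs://dataflow-templates/latest/Word_Count",
            parameters={"inputFile": "gs://my-bucket/input.txt", "output": "gs://my-bucket/output"},
            location="europe-west3",
            deferrable=True,
        )

        # Flex Template: same pattern; the body follows the
        # projects.locations.flexTemplates.launch request schema.
        start_flex_template_job_deferrable = DataflowStartFlexTemplateOperator(
            task_id="start_flex_template_job_deferrable",
            project_id="my-project",  # placeholder
            body={
                "launchParameter": {
                    "containerSpecGcsPath": "gs://my-bucket/templates/spec.json",  # placeholder
                    "jobName": "wordcount-flex",
                }
            },
            location="europe-west3",
            append_job_name=False,
            deferrable=True,
        )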
diff --git a/airflow/providers/google/cloud/hooks/dataflow.py b/airflow/providers/google/cloud/hooks/dataflow.py
index a9bf802b14..59eee63501 100644
--- a/airflow/providers/google/cloud/hooks/dataflow.py
+++ b/airflow/providers/google/cloud/hooks/dataflow.py
@@ -41,9 +41,13 @@ from google.cloud.dataflow_v1beta3 import (
MessagesV1Beta3AsyncClient,
MetricsV1Beta3AsyncClient,
)
-from google.cloud.dataflow_v1beta3.types import GetJobMetricsRequest, JobMessageImportance, JobMetrics
+from google.cloud.dataflow_v1beta3.types import (
+ GetJobMetricsRequest,
+ JobMessageImportance,
+ JobMetrics,
+)
from google.cloud.dataflow_v1beta3.types.jobs import ListJobsRequest
-from googleapiclient.discovery import build
+from googleapiclient.discovery import Resource, build
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType, beam_options_to_args
@@ -573,7 +577,7 @@ class DataflowHook(GoogleBaseHook):
impersonation_chain=impersonation_chain,
)
- def get_conn(self) -> build:
+ def get_conn(self) -> Resource:
"""Return a Google Cloud Dataflow service object."""
http_authorized = self._authorize()
return build("dataflow", "v1b3", http=http_authorized, cache_discovery=False)
@@ -653,9 +657,9 @@ class DataflowHook(GoogleBaseHook):
on_new_job_callback: Callable[[dict], None] | None = None,
location: str = DEFAULT_DATAFLOW_LOCATION,
environment: dict | None = None,
- ) -> dict:
+ ) -> dict[str, str]:
"""
- Start Dataflow template job.
+    Launch a Dataflow job with a Classic Template and wait for its completion.
:param job_name: The name of the job.
:param variables: Map of job runtime environment options.
@@ -688,26 +692,14 @@ class DataflowHook(GoogleBaseHook):
environment=environment,
)
- service = self.get_conn()
-
- request = (
- service.projects()
- .locations()
- .templates()
- .launch(
- projectId=project_id,
- location=location,
- gcsPath=dataflow_template,
- body={
- "jobName": name,
- "parameters": parameters,
- "environment": environment,
- },
- )
+ job: dict[str, str] = self.send_launch_template_request(
+ project_id=project_id,
+ location=location,
+ gcs_path=dataflow_template,
+ job_name=name,
+ parameters=parameters,
+ environment=environment,
)
- response = request.execute(num_retries=self.num_retries)
-
- job = response["job"]
if on_new_job_id_callback:
warnings.warn(
@@ -715,7 +707,7 @@ class DataflowHook(GoogleBaseHook):
AirflowProviderDeprecationWarning,
stacklevel=3,
)
- on_new_job_id_callback(job.get("id"))
+ on_new_job_id_callback(job["id"])
if on_new_job_callback:
on_new_job_callback(job)
@@ -734,7 +726,62 @@ class DataflowHook(GoogleBaseHook):
expected_terminal_state=self.expected_terminal_state,
)
jobs_controller.wait_for_done()
- return response["job"]
+ return job
+
+ @_fallback_to_location_from_variables
+ @_fallback_to_project_id_from_variables
+ @GoogleBaseHook.fallback_to_default_project_id
+ def launch_job_with_template(
+ self,
+ *,
+ job_name: str,
+ variables: dict,
+ parameters: dict,
+ dataflow_template: str,
+ project_id: str,
+ append_job_name: bool = True,
+ location: str = DEFAULT_DATAFLOW_LOCATION,
+ environment: dict | None = None,
+ ) -> dict[str, str]:
+ """
+        Launch a Dataflow job with a Classic Template and exit without waiting for its completion.
+
+ :param job_name: The name of the job.
+ :param variables: Map of job runtime environment options.
+ It will update environment argument if passed.
+
+ .. seealso::
+            For more information on possible configurations, look at the API documentation
+            `https://cloud.google.com/dataflow/pipelines/specifying-exec-params
+            <https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment>`__
+
+        :param parameters: Parameters for the template
+        :param dataflow_template: GCS path to the template.
+        :param project_id: Optional, the Google Cloud project ID in which to start a job.
+            If set to None or missing, the default project_id from the Google Cloud connection is used.
+        :param append_job_name: True if unique suffix has to be appended to job name.
+ :param location: Job location.
+
+ .. seealso::
+            For more information on possible configurations, look at the API documentation
+            `https://cloud.google.com/dataflow/pipelines/specifying-exec-params
+            <https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment>`__
+ :return: the Dataflow job response
+ """
+ name = self.build_dataflow_job_name(job_name, append_job_name)
+ environment = self._update_environment(
+ variables=variables,
+ environment=environment,
+ )
+ job: dict[str, str] = self.send_launch_template_request(
+ project_id=project_id,
+ location=location,
+ gcs_path=dataflow_template,
+ job_name=name,
+ parameters=parameters,
+ environment=environment,
+ )
+ return job
def _update_environment(self, variables: dict, environment: dict | None = None) -> dict:
environment = environment or {}
@@ -770,6 +817,35 @@ class DataflowHook(GoogleBaseHook):
return environment
+ def send_launch_template_request(
+ self,
+ *,
+ project_id: str,
+ location: str,
+ gcs_path: str,
+ job_name: str,
+ parameters: dict,
+ environment: dict,
+ ) -> dict[str, str]:
+ service: Resource = self.get_conn()
+ request = (
+ service.projects()
+ .locations()
+ .templates()
+ .launch(
+ projectId=project_id,
+ location=location,
+ gcsPath=gcs_path,
+ body={
+ "jobName": job_name,
+ "parameters": parameters,
+ "environment": environment,
+ },
+ )
+ )
+ response: dict = request.execute(num_retries=self.num_retries)
+ return response["job"]
+
@GoogleBaseHook.fallback_to_default_project_id
def start_flex_template(
self,
@@ -778,9 +854,9 @@ class DataflowHook(GoogleBaseHook):
project_id: str,
on_new_job_id_callback: Callable[[str], None] | None = None,
on_new_job_callback: Callable[[dict], None] | None = None,
- ) -> dict:
+ ) -> dict[str, str]:
"""
- Start flex templates with the Dataflow pipeline.
+ Launch a Dataflow job with a Flex Template and wait for its completion.
:param body: The request body. See:
https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.locations.flexTemplates/launch#request-body
@@ -791,15 +867,16 @@ class DataflowHook(GoogleBaseHook):
:param on_new_job_callback: A callback that is called when a Job is detected.
:return: the Job
"""
- service = self.get_conn()
+ service: Resource = self.get_conn()
request = (
service.projects()
.locations()
.flexTemplates()
.launch(projectId=project_id, body=body, location=location)
)
- response = request.execute(num_retries=self.num_retries)
+ response: dict = request.execute(num_retries=self.num_retries)
job = response["job"]
+ job_id: str = job["id"]
if on_new_job_id_callback:
warnings.warn(
@@ -807,7 +884,7 @@ class DataflowHook(GoogleBaseHook):
AirflowProviderDeprecationWarning,
stacklevel=3,
)
- on_new_job_id_callback(job.get("id"))
+ on_new_job_id_callback(job_id)
if on_new_job_callback:
on_new_job_callback(job)
@@ -815,7 +892,7 @@ class DataflowHook(GoogleBaseHook):
jobs_controller = _DataflowJobsController(
dataflow=self.get_conn(),
project_number=project_id,
- job_id=job.get("id"),
+ job_id=job_id,
location=location,
poll_sleep=self.poll_sleep,
num_retries=self.num_retries,
@@ -826,6 +903,42 @@ class DataflowHook(GoogleBaseHook):
return jobs_controller.get_jobs(refresh=True)[0]
+ @GoogleBaseHook.fallback_to_default_project_id
+ def launch_job_with_flex_template(
+ self,
+ body: dict,
+ location: str,
+ project_id: str,
+ ) -> dict[str, str]:
+ """
+        Launch a Dataflow Job with a Flex Template and exit without waiting for the job completion.
+
+        :param body: The request body. See:
+            https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.locations.flexTemplates/launch#request-body
+        :param location: The location of the Dataflow job (for example europe-west1)
+        :param project_id: The ID of the GCP project that owns the job.
+            If set to ``None`` or missing, the default project_id from the GCP connection is used.
+        :return: a Dataflow job response
+ service: Resource = self.get_conn()
+ request = (
+ service.projects()
+ .locations()
+ .flexTemplates()
+ .launch(projectId=project_id, body=body, location=location)
+ )
+ response: dict = request.execute(num_retries=self.num_retries)
+ return response["job"]
+
+ @staticmethod
+ def extract_job_id(job: dict) -> str:
+ try:
+ return job["id"]
+ except KeyError:
+ raise AirflowException(
+                "While reading job object after template execution error occurred. Job object has no id."
+ )
+
@_fallback_to_location_from_variables
@_fallback_to_project_id_from_variables
@GoogleBaseHook.fallback_to_default_project_id
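At the hook level the refactor separates the blocking and non-blocking paths: start_template_dataflow and start_flex_template still launch the job and wait on _DataflowJobsController, while the new launch_job_with_template and launch_job_with_flex_template only send the launch request and return the job dict, leaving the waiting to the trigger. A rough sketch of driving the new methods directly, with placeholder connection id, project, and GCS paths:

    from airflow.providers.google.cloud.hooks.dataflow import DataflowHook

    hook = DataflowHook(gcp_conn_id="google_cloud_default")  # placeholder connection id

    # Submit the Classic Template launch request and return immediately;
    # no polling loop runs on this code path.
    job = hook.launch_job_with_template(
        job_name="wordcount",
        variables={"tempLocation": "gs://my-bucket/temp"},  # placeholder runtime options
        parameters={"inputFile": "gs://my-bucket/input.txt", "output": "gs://my-bucket/output"},
        dataflow_template="gs://dataflow-templates/latest/Word_Count",
        project_id="my-project",
        location="us-central1",
    )

    # extract_job_id raises AirflowException instead of silently returning None
    # when the launch response carries no "id" field.
    job_id = hook.extract_job_id(job)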
diff --git a/airflow/providers/google/cloud/operators/dataflow.py b/airflow/providers/google/cloud/operators/dataflow.py
index 4a6f197e14..424cb8d805 100644
--- a/airflow/providers/google/cloud/operators/dataflow.py
+++ b/airflow/providers/google/cloud/operators/dataflow.py
@@ -41,6 +41,7 @@ from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.providers.google.cloud.links.dataflow import DataflowJobLink
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
from airflow.providers.google.cloud.triggers.dataflow import TemplateJobStartTrigger
+from airflow.providers.google.common.consts import GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
from airflow.version import version
@@ -460,7 +461,7 @@ class DataflowCreateJavaJobOperator(GoogleCloudBaseOperator):
class DataflowTemplatedJobStartOperator(GoogleCloudBaseOperator):
"""
-    Start a Templated Cloud Dataflow job; the parameters of the operation will be passed to the job.
+    Start a Dataflow job with a classic template; the parameters of the operation will be passed to the job.
.. seealso::
For more information on how to use this operator, take a look at the guide:
@@ -643,7 +644,7 @@ class DataflowTemplatedJobStartOperator(GoogleCloudBaseOperator):
self.deferrable = deferrable
self.expected_terminal_state = expected_terminal_state
- self.job: dict | None = None
+ self.job: dict[str, str] | None = None
self._validate_deferrable_params()
@@ -681,29 +682,34 @@ class DataflowTemplatedJobStartOperator(GoogleCloudBaseOperator):
if not self.location:
self.location = DEFAULT_DATAFLOW_LOCATION
- self.job = self.hook.start_template_dataflow(
+ if not self.deferrable:
+ self.job = self.hook.start_template_dataflow(
+ job_name=self.job_name,
+ variables=options,
+ parameters=self.parameters,
+ dataflow_template=self.template,
+ on_new_job_callback=set_current_job,
+ project_id=self.project_id,
+ location=self.location,
+ environment=self.environment,
+ append_job_name=self.append_job_name,
+ )
+ job_id = self.hook.extract_job_id(self.job)
+ self.xcom_push(context, key="job_id", value=job_id)
+ return job_id
+
+ self.job = self.hook.launch_job_with_template(
job_name=self.job_name,
variables=options,
parameters=self.parameters,
dataflow_template=self.template,
- on_new_job_callback=set_current_job,
project_id=self.project_id,
+ append_job_name=self.append_job_name,
location=self.location,
environment=self.environment,
- append_job_name=self.append_job_name,
)
- job_id = self.job.get("id")
-
- if job_id is None:
- raise AirflowException(
-                "While reading job object after template execution error occurred. Job object has no id."
- )
-
- if not self.deferrable:
- return job_id
-
- context["ti"].xcom_push(key="job_id", value=job_id)
-
+ job_id = self.hook.extract_job_id(self.job)
+        DataflowJobLink.persist(self, context, self.project_id, self.location, job_id)
self.defer(
trigger=TemplateJobStartTrigger(
project_id=self.project_id,
@@ -714,16 +720,17 @@ class DataflowTemplatedJobStartOperator(GoogleCloudBaseOperator):
impersonation_chain=self.impersonation_chain,
cancel_timeout=self.cancel_timeout,
),
- method_name="execute_complete",
+ method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME,
)
- def execute_complete(self, context: Context, event: dict[str, Any]):
+ def execute_complete(self, context: Context, event: dict[str, Any]) -> str:
"""Execute after trigger finishes its work."""
if event["status"] in ("error", "stopped"):
self.log.info("status: %s, msg: %s", event["status"],
event["message"])
raise AirflowException(event["message"])
job_id = event["job_id"]
+ self.xcom_push(context, key="job_id", value=job_id)
self.log.info("Task %s completed with response %s", self.task_id,
event["message"])
return job_id
@@ -741,7 +748,7 @@ class DataflowTemplatedJobStartOperator(GoogleCloudBaseOperator):
class DataflowStartFlexTemplateOperator(GoogleCloudBaseOperator):
"""
- Starts flex templates with the Dataflow pipeline.
+ Starts a Dataflow Job with a Flex Template.
.. seealso::
For more information on how to use this operator, take a look at the guide:
@@ -803,6 +810,9 @@ class DataflowStartFlexTemplateOperator(GoogleCloudBaseOperator):
    :param expected_terminal_state: The expected final status of the operator on which the corresponding
        Airflow task succeeds. When not specified, it will be determined by the hook.
    :param append_job_name: True if unique suffix has to be appended to job name.
+ :param poll_sleep: The time in seconds to sleep between polling Google
+ Cloud Platform for the dataflow job status while the job is in the
+ JOB_STATE_RUNNING state.
"""
template_fields: Sequence[str] = ("body", "location", "project_id", "gcp_conn_id")
@@ -821,6 +831,7 @@ class DataflowStartFlexTemplateOperator(GoogleCloudBaseOperator):
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
append_job_name: bool = True,
expected_terminal_state: str | None = None,
+ poll_sleep: int = 10,
*args,
**kwargs,
) -> None:
@@ -832,11 +843,12 @@ class DataflowStartFlexTemplateOperator(GoogleCloudBaseOperator):
self.drain_pipeline = drain_pipeline
self.cancel_timeout = cancel_timeout
self.wait_until_finished = wait_until_finished
- self.job: dict | None = None
+ self.job: dict[str, str] | None = None
self.impersonation_chain = impersonation_chain
self.deferrable = deferrable
self.expected_terminal_state = expected_terminal_state
self.append_job_name = append_job_name
+ self.poll_sleep = poll_sleep
self._validate_deferrable_params()
@@ -871,32 +883,35 @@ class DataflowStartFlexTemplateOperator(GoogleCloudBaseOperator):
self.job = current_job
DataflowJobLink.persist(self, context, self.project_id, self.location, self.job.get("id"))
- self.job = self.hook.start_flex_template(
+ if not self.deferrable:
+ self.job = self.hook.start_flex_template(
+ body=self.body,
+ location=self.location,
+ project_id=self.project_id,
+ on_new_job_callback=set_current_job,
+ )
+ job_id = self.hook.extract_job_id(self.job)
+ self.xcom_push(context, key="job_id", value=job_id)
+ return self.job
+
+ self.job = self.hook.launch_job_with_flex_template(
body=self.body,
location=self.location,
project_id=self.project_id,
- on_new_job_callback=set_current_job,
)
-
- job_id = self.job.get("id")
- if job_id is None:
- raise AirflowException(
-                "While reading job object after template execution error occurred. Job object has no id."
- )
-
- if not self.deferrable:
- return self.job
-
+ job_id = self.hook.extract_job_id(self.job)
+        DataflowJobLink.persist(self, context, self.project_id, self.location, job_id)
self.defer(
trigger=TemplateJobStartTrigger(
project_id=self.project_id,
job_id=job_id,
location=self.location,
gcp_conn_id=self.gcp_conn_id,
+ poll_sleep=self.poll_sleep,
impersonation_chain=self.impersonation_chain,
cancel_timeout=self.cancel_timeout,
),
- method_name="execute_complete",
+ method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME,
)
def _append_uuid_to_job_name(self):
@@ -907,7 +922,7 @@ class DataflowStartFlexTemplateOperator(GoogleCloudBaseOperator):
job_body["jobName"] = job_name
self.log.info("Job name was changed to %s", job_name)
- def execute_complete(self, context: Context, event: dict):
+    def execute_complete(self, context: Context, event: dict) -> dict[str, str]:
"""Execute after trigger finishes its work."""
if event["status"] in ("error", "stopped"):
self.log.info("status: %s, msg: %s", event["status"],
event["message"])
@@ -915,6 +930,7 @@ class DataflowStartFlexTemplateOperator(GoogleCloudBaseOperator):
job_id = event["job_id"]
self.log.info("Task %s completed with response %s", job_id,
event["message"])
+ self.xcom_push(context, key="job_id", value=job_id)
job = self.hook.get_job(job_id=job_id, project_id=self.project_id, location=self.location)
return job
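On the operator side, both execution paths now push the job id to XCom under the key "job_id": right after launch when deferrable is False, and in execute_complete after the trigger fires when it is True. Downstream tasks can therefore resolve the id the same way in either mode; a minimal sketch with an illustrative upstream task id:

    from airflow.operators.python import PythonOperator

    def report_job_id(ti):
        # Pull the value pushed by DataflowTemplatedJobStartOperator /
        # DataflowStartFlexTemplateOperator regardless of deferrable mode.
        job_id = ti.xcom_pull(task_ids="start_template_job_deferrable", key="job_id")
        print(f"Dataflow job id: {job_id}")

    report = PythonOperator(task_id="report_job_id", python_callable=report_job_id)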
diff --git a/airflow/providers/google/cloud/triggers/dataflow.py b/airflow/providers/google/cloud/triggers/dataflow.py
index 32f68a9fd7..577c0bbf60 100644
--- a/airflow/providers/google/cloud/triggers/dataflow.py
+++ b/airflow/providers/google/cloud/triggers/dataflow.py
@@ -138,7 +138,7 @@ class TemplateJobStartTrigger(BaseTrigger):
return
else:
self.log.info("Job is still running...")
- self.log.info("Current job status is: %s", status)
+ self.log.info("Current job status is: %s", status.name)
self.log.info("Sleeping for %s seconds.", self.poll_sleep)
await asyncio.sleep(self.poll_sleep)
except Exception as e:
diff --git a/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst b/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst
index d3f1bd6df4..f9302af8c3 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst
@@ -208,7 +208,7 @@ from the staging and execution steps. There are two types of templates for Dataf
See the `official documentation for Dataflow templates
<https://cloud.google.com/dataflow/docs/concepts/dataflow-templates>`_ for
more information.
-Here is an example of running Classic template with
+Here is an example of running a Dataflow job using a Classic Template with
:class:`~airflow.providers.google.cloud.operators.dataflow.DataflowTemplatedJobStartOperator`:
.. exampleinclude:: /../../tests/system/providers/google/cloud/dataflow/example_dataflow_template.py
@@ -217,10 +217,18 @@ Here is an example of running Classic template with
:start-after: [START howto_operator_start_template_job]
:end-before: [END howto_operator_start_template_job]
+You can also use the operator in deferrable mode for this action:
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/dataflow/example_dataflow_template.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_start_template_job_deferrable]
+ :end-before: [END howto_operator_start_template_job_deferrable]
+
See the `list of Google-provided templates that can be used with this operator
<https://cloud.google.com/dataflow/docs/guides/templates/provided-templates>`_.
-Here is an example of running Flex template with
+Here is an example of running a Dataflow job using a Flex Template with
:class:`~airflow.providers.google.cloud.operators.dataflow.DataflowStartFlexTemplateOperator`:
.. exampleinclude:: /../../tests/system/providers/google/cloud/dataflow/example_dataflow_template.py
@@ -229,6 +237,14 @@ Here is an example of running Flex template with
:start-after: [START howto_operator_start_flex_template_job]
:end-before: [END howto_operator_start_flex_template_job]
+You can also use the operator in deferrable mode for this action:
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/dataflow/example_dataflow_template.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_start_flex_template_job_deferrable]
+ :end-before: [END howto_operator_start_flex_template_job_deferrable]
+
.. _howto/operator:DataflowStartSqlJobOperator:
Dataflow SQL
diff --git a/tests/providers/google/cloud/hooks/test_dataflow.py b/tests/providers/google/cloud/hooks/test_dataflow.py
index 2458b48e81..1c8f768ea3 100644
--- a/tests/providers/google/cloud/hooks/test_dataflow.py
+++ b/tests/providers/google/cloud/hooks/test_dataflow.py
@@ -1052,6 +1052,34 @@ class TestDataflowTemplateHook:
)
mock_uuid.assert_called_once_with()
+ @mock.patch(DATAFLOW_STRING.format("uuid.uuid4"), return_value=MOCK_UUID)
+ @mock.patch(DATAFLOW_STRING.format("DataflowHook.get_conn"))
+ def test_launch_job_with_template(self, mock_conn, mock_uuid):
+ launch_method = (
+            mock_conn.return_value.projects.return_value.locations.return_value.templates.return_value.launch
+ )
+        launch_method.return_value.execute.return_value = {"job": {"id": TEST_JOB_ID}}
+ variables = {"zone": "us-central1-f", "tempLocation": "gs://test/temp"}
+ result = self.dataflow_hook.launch_job_with_template(
+ job_name=JOB_NAME,
+ variables=copy.deepcopy(variables),
+ parameters=PARAMETERS,
+ dataflow_template=TEST_TEMPLATE,
+ project_id=TEST_PROJECT,
+ )
+
+ launch_method.assert_called_once_with(
+ body={
+ "jobName": f"test-dataflow-pipeline-{MOCK_UUID_PREFIX}",
+ "parameters": PARAMETERS,
+ "environment": variables,
+ },
+ gcsPath="gs://dataflow-templates/wordcount/template_file",
+ projectId=TEST_PROJECT,
+ location=DEFAULT_DATAFLOW_LOCATION,
+ )
+ assert result == {"id": TEST_JOB_ID}
+
@mock.patch(DATAFLOW_STRING.format("_DataflowJobsController"))
@mock.patch(DATAFLOW_STRING.format("DataflowHook.get_conn"))
def test_start_flex_template(self, mock_conn, mock_controller):
@@ -1088,6 +1116,26 @@ class TestDataflowTemplateHook:
mock_controller.return_value.get_jobs.assert_called_once_with(refresh=True)
assert result == {"id": TEST_JOB_ID}
+ @mock.patch(DATAFLOW_STRING.format("DataflowHook.get_conn"))
+ def test_launch_job_with_flex_template(self, mock_conn):
+ expected_job = {"id": TEST_JOB_ID}
+
+ mock_locations = mock_conn.return_value.projects.return_value.locations
+        launch_method = mock_locations.return_value.flexTemplates.return_value.launch
+ launch_method.return_value.execute.return_value = {"job": expected_job}
+
+ result = self.dataflow_hook.launch_job_with_flex_template(
+ body={"launchParameter": TEST_FLEX_PARAMETERS},
+ location=TEST_LOCATION,
+ project_id=TEST_PROJECT_ID,
+ )
+ launch_method.assert_called_once_with(
+ projectId="test-project-id",
+ body={"launchParameter": TEST_FLEX_PARAMETERS},
+ location=TEST_LOCATION,
+ )
+ assert result == {"id": TEST_JOB_ID}
+
@mock.patch(DATAFLOW_STRING.format("_DataflowJobsController"))
@mock.patch(DATAFLOW_STRING.format("DataflowHook.get_conn"))
def test_cancel_job(self, mock_get_conn, jobs_controller):
@@ -1177,6 +1225,10 @@ class TestDataflowTemplateHook:
on_new_job_callback=mock.MagicMock(),
)
+ def test_extract_job_id_raises_exception(self):
+ with pytest.raises(AirflowException):
+ self.dataflow_hook.extract_job_id({"not_id": True})
+
class TestDataflowJob:
def setup_method(self):
diff --git a/tests/providers/google/cloud/operators/test_dataflow.py b/tests/providers/google/cloud/operators/test_dataflow.py
index 495287b9af..ebbf471383 100644
--- a/tests/providers/google/cloud/operators/test_dataflow.py
+++ b/tests/providers/google/cloud/operators/test_dataflow.py
@@ -102,6 +102,7 @@ TEST_SQL_JOB = {"id": "test-job-id"}
GCP_CONN_ID = "test_gcp_conn_id"
IMPERSONATION_CHAIN = ["impersonate", "this"]
CANCEL_TIMEOUT = 10 * 420
+DATAFLOW_PATH = "airflow.providers.google.cloud.operators.dataflow"
class TestDataflowCreatePythonJobOperator:
@@ -488,11 +489,12 @@ class TestDataflowTemplatedJobStartOperator:
cancel_timeout=CANCEL_TIMEOUT,
)
-    @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook")
- def test_exec(self, dataflow_mock, sync_operator):
-        start_template_hook = dataflow_mock.return_value.start_template_dataflow
+ @mock.patch(f"{DATAFLOW_PATH}.DataflowTemplatedJobStartOperator.xcom_push")
+ @mock.patch(f"{DATAFLOW_PATH}.DataflowHook")
+ def test_execute(self, hook_mock, mock_xcom_push, sync_operator):
+ start_template_hook = hook_mock.return_value.start_template_dataflow
sync_operator.execute(None)
- assert dataflow_mock.called
+ assert hook_mock.called
expected_options = {
"project": "test",
"stagingLocation": "gs://test/staging",
@@ -512,10 +514,27 @@ class TestDataflowTemplatedJobStartOperator:
append_job_name=True,
)
-    @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowTemplatedJobStartOperator.defer")
-    @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowTemplatedJobStartOperator.hook")
+ @mock.patch(f"{DATAFLOW_PATH}.DataflowTemplatedJobStartOperator.defer")
+ @mock.patch(f"{DATAFLOW_PATH}.DataflowHook")
def test_execute_with_deferrable_mode(self, mock_hook, mock_defer_method, deferrable_operator):
deferrable_operator.execute(mock.MagicMock())
+ expected_variables = {
+ "project": "test",
+ "stagingLocation": "gs://test/staging",
+ "tempLocation": "gs://test/temp",
+ "zone": "us-central1-f",
+ "EXTRA_OPTION": "TEST_A",
+ }
+        mock_hook.return_value.launch_job_with_template.assert_called_once_with(
+ job_name=JOB_NAME,
+ variables=expected_variables,
+ parameters=PARAMETERS,
+ dataflow_template=TEMPLATE,
+ project_id=TEST_PROJECT,
+ append_job_name=True,
+ location=TEST_LOCATION,
+ environment={"maxWorkers": 2},
+ )
mock_defer_method.assert_called_once()
def test_validation_deferrable_params_raises_error(self):
@@ -540,8 +559,9 @@ class TestDataflowTemplatedJobStartOperator:
DataflowTemplatedJobStartOperator(**init_kwargs)
@pytest.mark.db_test
-    @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook.start_template_dataflow")
- def test_start_with_custom_region(self, dataflow_mock):
+ @mock.patch(f"{DATAFLOW_PATH}.DataflowTemplatedJobStartOperator.xcom_push")
+ @mock.patch(f"{DATAFLOW_PATH}.DataflowHook.start_template_dataflow")
+ def test_start_with_custom_region(self, dataflow_mock, mock_xcom_push):
init_kwargs = {
"task_id": TASK_ID,
"template": TEMPLATE,
@@ -560,8 +580,9 @@ class TestDataflowTemplatedJobStartOperator:
assert kwargs["location"] == DEFAULT_DATAFLOW_LOCATION
@pytest.mark.db_test
-    @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook.start_template_dataflow")
- def test_start_with_location(self, dataflow_mock):
+ @mock.patch(f"{DATAFLOW_PATH}.DataflowTemplatedJobStartOperator.xcom_push")
+ @mock.patch(f"{DATAFLOW_PATH}.DataflowHook.start_template_dataflow")
+ def test_start_with_location(self, dataflow_mock, mock_xcom_push):
init_kwargs = {
"task_id": TASK_ID,
"template": TEMPLATE,
@@ -601,7 +622,7 @@ class TestDataflowStartFlexTemplateOperator:
deferrable=True,
)
-    @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook")
+ @mock.patch(f"{DATAFLOW_PATH}.DataflowHook")
def test_execute(self, mock_dataflow, sync_operator):
sync_operator.execute(mock.MagicMock())
mock_dataflow.assert_called_once_with(
@@ -640,16 +661,15 @@ class TestDataflowStartFlexTemplateOperator:
with pytest.raises(ValueError):
DataflowStartFlexTemplateOperator(**init_kwargs)
-    @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowStartFlexTemplateOperator.defer")
-    @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook")
+ @mock.patch(f"{DATAFLOW_PATH}.DataflowStartFlexTemplateOperator.defer")
+ @mock.patch(f"{DATAFLOW_PATH}.DataflowHook")
def test_execute_with_deferrable_mode(self, mock_hook, mock_defer_method, deferrable_operator):
deferrable_operator.execute(mock.MagicMock())
- mock_hook.return_value.start_flex_template.assert_called_once_with(
+        mock_hook.return_value.launch_job_with_flex_template.assert_called_once_with(
body={"launchParameter": TEST_FLEX_PARAMETERS},
location=TEST_LOCATION,
project_id=TEST_PROJECT,
- on_new_job_callback=mock.ANY,
)
mock_defer_method.assert_called_once()
diff --git a/tests/system/providers/google/cloud/dataflow/example_dataflow_template.py b/tests/system/providers/google/cloud/dataflow/example_dataflow_template.py
index b6eec97a16..2a3e747eb7 100644
--- a/tests/system/providers/google/cloud/dataflow/example_dataflow_template.py
+++ b/tests/system/providers/google/cloud/dataflow/example_dataflow_template.py
@@ -17,7 +17,8 @@
# under the License.
"""
-Example Airflow DAG for testing Google Dataflow
+Example Airflow DAG for testing Google Dataflow.
+
:class:`~airflow.providers.google.cloud.operators.dataflow.DataflowTemplatedJobStartOperator` operator.
"""
@@ -27,6 +28,7 @@ import os
from datetime import datetime
from pathlib import Path
+from airflow.models.baseoperator import chain
from airflow.models.dag import DAG
from airflow.providers.google.cloud.operators.dataflow import (
DataflowStartFlexTemplateOperator,
@@ -104,6 +106,7 @@ with DAG(
template="gs://dataflow-templates/latest/Word_Count",
parameters={"inputFile": f"gs://{BUCKET_NAME}/{CSV_FILE_NAME}",
"output": GCS_OUTPUT},
location=LOCATION,
+ wait_until_finished=True,
)
# [END howto_operator_start_template_job]
@@ -114,20 +117,43 @@ with DAG(
body=BODY,
location=LOCATION,
append_job_name=False,
+ wait_until_finished=True,
)
# [END howto_operator_start_flex_template_job]
+ # [START howto_operator_start_template_job_deferrable]
+ start_template_job_deferrable = DataflowTemplatedJobStartOperator(
+ task_id="start_template_job_deferrable",
+ project_id=PROJECT_ID,
+ template="gs://dataflow-templates/latest/Word_Count",
+        parameters={"inputFile": f"gs://{BUCKET_NAME}/{CSV_FILE_NAME}", "output": GCS_OUTPUT},
+ location=LOCATION,
+ deferrable=True,
+ )
+ # [END howto_operator_start_template_job_deferrable]
+
+ # [START howto_operator_start_flex_template_job_deferrable]
+ start_flex_template_job_deferrable = DataflowStartFlexTemplateOperator(
+ task_id="start_flex_template_job_deferrable",
+ project_id=PROJECT_ID,
+ body=BODY,
+ location=LOCATION,
+ append_job_name=False,
+ deferrable=True,
+ )
+ # [END howto_operator_start_flex_template_job_deferrable]
+
delete_bucket = GCSDeleteBucketOperator(
task_id="delete_bucket", bucket_name=BUCKET_NAME,
trigger_rule=TriggerRule.ALL_DONE
)
- (
- create_bucket
- >> upload_file
- >> upload_schema
- >> start_template_job
- >> start_flex_template_job
- >> delete_bucket
+ chain(
+ create_bucket,
+ upload_file,
+ upload_schema,
+ [start_template_job, start_flex_template_job],
+ [start_template_job_deferrable, start_flex_template_job_deferrable],
+ delete_bucket,
)
from tests.system.utils.watcher import watcher