This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 050a47add8 Add `expected_terminal_state` parameter to Dataflow
operators (#34217)
050a47add8 is described below
commit 050a47add822cde6d9abcd609df59c98caae13b0
Author: Shahar Epstein <[email protected]>
AuthorDate: Mon Sep 11 13:55:54 2023 +0300
Add `expected_terminal_state` parameter to Dataflow operators (#34217)
---
airflow/providers/google/cloud/hooks/dataflow.py | 60 +++++++++++------
.../providers/google/cloud/operators/dataflow.py | 15 +++++
.../providers/google/cloud/hooks/test_dataflow.py | 76 +++++++++++++++++-----
.../google/cloud/operators/test_dataflow.py | 3 +
4 files changed, 120 insertions(+), 34 deletions(-)
diff --git a/airflow/providers/google/cloud/hooks/dataflow.py
b/airflow/providers/google/cloud/hooks/dataflow.py
index 819a674b82..3b0f3a2a8c 100644
--- a/airflow/providers/google/cloud/hooks/dataflow.py
+++ b/airflow/providers/google/cloud/hooks/dataflow.py
@@ -175,7 +175,7 @@ class _DataflowJobsController(LoggingMixin):
:param num_retries: Maximum number of retries in case of connection
problems.
:param multiple_jobs: If set to true this task will be searched by name
prefix (``name`` parameter),
not by specific job ID, then actions will be performed on all matching
jobs.
- :param drain_pipeline: Optional, set to True if want to stop streaming job
by draining it
+ :param drain_pipeline: Optional, set to True if we want to stop streaming
job by draining it
instead of canceling.
:param cancel_timeout: wait time in seconds for successful job canceling
:param wait_until_finished: If True, wait for the end of pipeline
execution before exiting. If False,
@@ -183,8 +183,8 @@ class _DataflowJobsController(LoggingMixin):
The default behavior depends on the type of pipeline:
- * for the streaming pipeline, wait for jobs to start,
- * for the batch pipeline, wait for the jobs to complete.
+ * for the streaming pipeline, wait for jobs to be in JOB_STATE_RUNNING,
+ * for the batch pipeline, wait for the jobs to be in JOB_STATE_DONE.
"""
def __init__(
@@ -200,6 +200,7 @@ class _DataflowJobsController(LoggingMixin):
drain_pipeline: bool = False,
cancel_timeout: int | None = 5 * 60,
wait_until_finished: bool | None = None,
+ expected_terminal_state: str | None = None,
) -> None:
super().__init__()
@@ -215,6 +216,7 @@ class _DataflowJobsController(LoggingMixin):
self._jobs: list[dict] | None = None
self.drain_pipeline = drain_pipeline
self._wait_until_finished = wait_until_finished
+ self._expected_terminal_state = expected_terminal_state
def is_job_running(self) -> bool:
"""
@@ -391,27 +393,44 @@ class _DataflowJobsController(LoggingMixin):
:return: True if job is done.
:raise: Exception
"""
- if self._wait_until_finished is None:
- wait_for_running = job.get("type") ==
DataflowJobType.JOB_TYPE_STREAMING
+ current_state = job["currentState"]
+ is_streaming = job.get("type") == DataflowJobType.JOB_TYPE_STREAMING
+
+ if self._expected_terminal_state is None:
+ if is_streaming:
+ self._expected_terminal_state =
DataflowJobStatus.JOB_STATE_RUNNING
+ else:
+ self._expected_terminal_state =
DataflowJobStatus.JOB_STATE_DONE
else:
- wait_for_running = not self._wait_until_finished
+ terminal_states = DataflowJobStatus.TERMINAL_STATES |
{DataflowJobStatus.JOB_STATE_RUNNING}
+ if self._expected_terminal_state not in terminal_states:
+ raise Exception(
+ f"Google Cloud Dataflow job's expected terminal state "
+ f"'{self._expected_terminal_state}' is invalid."
+ f" The value should be any of the following:
{terminal_states}"
+ )
+ elif is_streaming and self._expected_terminal_state ==
DataflowJobStatus.JOB_STATE_DONE:
+ raise Exception(
+ "Google Cloud Dataflow job's expected terminal state
cannot be "
+ "JOB_STATE_DONE while it is a streaming job"
+ )
+ elif not is_streaming and self._expected_terminal_state ==
DataflowJobStatus.JOB_STATE_DRAINED:
+ raise Exception(
+ "Google Cloud Dataflow job's expected terminal state
cannot be "
+ "JOB_STATE_DRAINED while it is a batch job"
+ )
- if job["currentState"] == DataflowJobStatus.JOB_STATE_DONE:
+ if not self._wait_until_finished and current_state ==
self._expected_terminal_state:
return True
- elif job["currentState"] == DataflowJobStatus.JOB_STATE_FAILED:
- raise Exception(f"Google Cloud Dataflow job {job['name']} has
failed.")
- elif job["currentState"] == DataflowJobStatus.JOB_STATE_CANCELLED:
- raise Exception(f"Google Cloud Dataflow job {job['name']} was
cancelled.")
- elif job["currentState"] == DataflowJobStatus.JOB_STATE_DRAINED:
- raise Exception(f"Google Cloud Dataflow job {job['name']} was
drained.")
- elif job["currentState"] == DataflowJobStatus.JOB_STATE_UPDATED:
- raise Exception(f"Google Cloud Dataflow job {job['name']} was
updated.")
- elif job["currentState"] == DataflowJobStatus.JOB_STATE_RUNNING and
wait_for_running:
- return True
- elif job["currentState"] in DataflowJobStatus.AWAITING_STATES:
+
+ if current_state in DataflowJobStatus.AWAITING_STATES:
return self._wait_until_finished is False
+
self.log.debug("Current job: %s", str(job))
- raise Exception(f"Google Cloud Dataflow job {job['name']} was unknown
state: {job['currentState']}")
+ raise Exception(
+ f"Google Cloud Dataflow job {job['name']} is in an unexpected
terminal state: {current_state}, "
+ f"expected terminal state: {self._expected_terminal_state}"
+ )
def wait_for_done(self) -> None:
"""Helper method to wait for result of submitted job."""
@@ -514,6 +533,7 @@ class DataflowHook(GoogleBaseHook):
drain_pipeline: bool = False,
cancel_timeout: int | None = 5 * 60,
wait_until_finished: bool | None = None,
+ expected_terminal_state: str | None = None,
**kwargs,
) -> None:
if kwargs.get("delegate_to") is not None:
@@ -527,6 +547,7 @@ class DataflowHook(GoogleBaseHook):
self.wait_until_finished = wait_until_finished
self.job_id: str | None = None
self.beam_hook = BeamHook(BeamRunnerType.DataflowRunner)
+ self.expected_terminal_state = expected_terminal_state
super().__init__(
gcp_conn_id=gcp_conn_id,
impersonation_chain=impersonation_chain,
@@ -691,6 +712,7 @@ class DataflowHook(GoogleBaseHook):
drain_pipeline=self.drain_pipeline,
cancel_timeout=self.cancel_timeout,
wait_until_finished=self.wait_until_finished,
+ expected_terminal_state=self.expected_terminal_state,
)
jobs_controller.wait_for_done()
return response["job"]
diff --git a/airflow/providers/google/cloud/operators/dataflow.py
b/airflow/providers/google/cloud/operators/dataflow.py
index c9daa16966..bda2c1721f 100644
--- a/airflow/providers/google/cloud/operators/dataflow.py
+++ b/airflow/providers/google/cloud/operators/dataflow.py
@@ -288,6 +288,8 @@ class
DataflowCreateJavaJobOperator(GoogleCloudBaseOperator):
If you in your pipeline do not call the wait_for_pipeline method, and
pass wait_until_finish=False
to the operator, the second loop will check once is job not in
terminal state and exit the loop.
+ :param expected_terminal_state: The expected terminal state of the
operator on which the corresponding
+ Airflow task succeeds. When not specified, it will be determined by
the hook.
Note that both
``dataflow_default_options`` and ``options`` will be merged to specify
pipeline
@@ -349,6 +351,7 @@ class
DataflowCreateJavaJobOperator(GoogleCloudBaseOperator):
multiple_jobs: bool = False,
cancel_timeout: int | None = 10 * 60,
wait_until_finished: bool | None = None,
+ expected_terminal_state: str | None = None,
**kwargs,
) -> None:
# TODO: Remove one day
@@ -378,6 +381,7 @@ class
DataflowCreateJavaJobOperator(GoogleCloudBaseOperator):
self.check_if_running = check_if_running
self.cancel_timeout = cancel_timeout
self.wait_until_finished = wait_until_finished
+ self.expected_terminal_state = expected_terminal_state
self.job_id = None
self.beam_hook: BeamHook | None = None
self.dataflow_hook: DataflowHook | None = None
@@ -390,6 +394,7 @@ class
DataflowCreateJavaJobOperator(GoogleCloudBaseOperator):
poll_sleep=self.poll_sleep,
cancel_timeout=self.cancel_timeout,
wait_until_finished=self.wait_until_finished,
+ expected_terminal_state=self.expected_terminal_state,
)
job_name =
self.dataflow_hook.build_dataflow_job_name(job_name=self.job_name)
pipeline_options = copy.deepcopy(self.dataflow_default_options)
@@ -531,6 +536,8 @@ class
DataflowTemplatedJobStartOperator(GoogleCloudBaseOperator):
If you in your pipeline do not call the wait_for_pipeline method, and
pass wait_until_finish=False
to the operator, the second loop will check once is job not in
terminal state and exit the loop.
+ :param expected_terminal_state: The expected terminal state of the
operator on which the corresponding
+ Airflow task succeeds. When not specified, it will be determined by
the hook.
It's a good practice to define dataflow_* parameters in the default_args
of the dag
like the project, zone and staging location.
@@ -614,6 +621,7 @@ class
DataflowTemplatedJobStartOperator(GoogleCloudBaseOperator):
wait_until_finished: bool | None = None,
append_job_name: bool = True,
deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
+ expected_terminal_state: str | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -633,6 +641,7 @@ class
DataflowTemplatedJobStartOperator(GoogleCloudBaseOperator):
self.wait_until_finished = wait_until_finished
self.append_job_name = append_job_name
self.deferrable = deferrable
+ self.expected_terminal_state = expected_terminal_state
self.job: dict | None = None
@@ -657,6 +666,7 @@ class
DataflowTemplatedJobStartOperator(GoogleCloudBaseOperator):
impersonation_chain=self.impersonation_chain,
cancel_timeout=self.cancel_timeout,
wait_until_finished=self.wait_until_finished,
+ expected_terminal_state=self.expected_terminal_state,
)
return hook
@@ -787,6 +797,8 @@ class
DataflowStartFlexTemplateOperator(GoogleCloudBaseOperator):
Service Account Token Creator IAM role to the directly preceding
identity, with first
account from the list granting this role to the originating account
(templated).
:param deferrable: Run operator in the deferrable mode.
+ :param expected_terminal_state: The expected terminal state of the operator
on which the corresponding
+ Airflow task succeeds. When not specified, it will be determined by
the hook.
:param append_job_name: True if unique suffix has to be appended to job
name.
"""
@@ -805,6 +817,7 @@ class
DataflowStartFlexTemplateOperator(GoogleCloudBaseOperator):
impersonation_chain: str | Sequence[str] | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
append_job_name: bool = True,
+ expected_terminal_state: str | None = None,
*args,
**kwargs,
) -> None:
@@ -819,6 +832,7 @@ class
DataflowStartFlexTemplateOperator(GoogleCloudBaseOperator):
self.job: dict | None = None
self.impersonation_chain = impersonation_chain
self.deferrable = deferrable
+ self.expected_terminal_state = expected_terminal_state
self.append_job_name = append_job_name
self._validate_deferrable_params()
@@ -842,6 +856,7 @@ class
DataflowStartFlexTemplateOperator(GoogleCloudBaseOperator):
cancel_timeout=self.cancel_timeout,
wait_until_finished=self.wait_until_finished,
impersonation_chain=self.impersonation_chain,
+ expected_terminal_state=self.expected_terminal_state,
)
return hook
diff --git a/tests/providers/google/cloud/hooks/test_dataflow.py
b/tests/providers/google/cloud/hooks/test_dataflow.py
index 0738f533f5..92e5898458 100644
--- a/tests/providers/google/cloud/hooks/test_dataflow.py
+++ b/tests/providers/google/cloud/hooks/test_dataflow.py
@@ -861,6 +861,7 @@ class TestDataflowTemplateHook:
project_number=TEST_PROJECT,
location=DEFAULT_DATAFLOW_LOCATION,
drain_pipeline=False,
+ expected_terminal_state=None,
cancel_timeout=DEFAULT_CANCEL_TIMEOUT,
wait_until_finished=None,
)
@@ -900,6 +901,7 @@ class TestDataflowTemplateHook:
project_number=TEST_PROJECT,
location=TEST_LOCATION,
drain_pipeline=False,
+ expected_terminal_state=None,
cancel_timeout=DEFAULT_CANCEL_TIMEOUT,
wait_until_finished=None,
)
@@ -943,6 +945,7 @@ class TestDataflowTemplateHook:
drain_pipeline=False,
cancel_timeout=DEFAULT_CANCEL_TIMEOUT,
wait_until_finished=None,
+ expected_terminal_state=None,
)
mock_controller.return_value.wait_for_done.assert_called_once()
@@ -986,6 +989,7 @@ class TestDataflowTemplateHook:
drain_pipeline=False,
cancel_timeout=DEFAULT_CANCEL_TIMEOUT,
wait_until_finished=None,
+ expected_terminal_state=None,
)
mock_uuid.assert_called_once_with()
@@ -1033,6 +1037,7 @@ class TestDataflowTemplateHook:
drain_pipeline=False,
cancel_timeout=DEFAULT_CANCEL_TIMEOUT,
wait_until_finished=None,
+ expected_terminal_state=None,
)
mock_uuid.assert_called_once_with()
@@ -1232,13 +1237,13 @@ class TestDataflowJob:
@pytest.mark.parametrize(
"state, exception_regex",
[
- (DataflowJobStatus.JOB_STATE_FAILED, "Google Cloud Dataflow job
name-2 has failed\\."),
- (DataflowJobStatus.JOB_STATE_CANCELLED, "Google Cloud Dataflow job
name-2 was cancelled\\."),
- (DataflowJobStatus.JOB_STATE_DRAINED, "Google Cloud Dataflow job
name-2 was drained\\."),
- (DataflowJobStatus.JOB_STATE_UPDATED, "Google Cloud Dataflow job
name-2 was updated\\."),
+ (DataflowJobStatus.JOB_STATE_FAILED, "unexpected terminal state:
JOB_STATE_FAILED"),
+ (DataflowJobStatus.JOB_STATE_CANCELLED, "unexpected terminal
state: JOB_STATE_CANCELLED"),
+ (DataflowJobStatus.JOB_STATE_DRAINED, "unexpected terminal state:
JOB_STATE_DRAINED"),
+ (DataflowJobStatus.JOB_STATE_UPDATED, "unexpected terminal state:
JOB_STATE_UPDATED"),
(
DataflowJobStatus.JOB_STATE_UNKNOWN,
- "Google Cloud Dataflow job name-2 was unknown state:
JOB_STATE_UNKNOWN",
+ "JOB_STATE_UNKNOWN",
),
],
)
@@ -1446,52 +1451,52 @@ class TestDataflowJob:
(
DataflowJobType.JOB_TYPE_BATCH,
DataflowJobStatus.JOB_STATE_FAILED,
- "Google Cloud Dataflow job name-2 has failed\\.",
+ "JOB_STATE_FAILED",
),
(
DataflowJobType.JOB_TYPE_STREAMING,
DataflowJobStatus.JOB_STATE_FAILED,
- "Google Cloud Dataflow job name-2 has failed\\.",
+ "JOB_STATE_FAILED",
),
(
DataflowJobType.JOB_TYPE_STREAMING,
DataflowJobStatus.JOB_STATE_UNKNOWN,
- "Google Cloud Dataflow job name-2 was unknown state:
JOB_STATE_UNKNOWN",
+ "JOB_STATE_UNKNOWN",
),
(
DataflowJobType.JOB_TYPE_BATCH,
DataflowJobStatus.JOB_STATE_UNKNOWN,
- "Google Cloud Dataflow job name-2 was unknown state:
JOB_STATE_UNKNOWN",
+ "JOB_STATE_UNKNOWN",
),
(
DataflowJobType.JOB_TYPE_BATCH,
DataflowJobStatus.JOB_STATE_CANCELLED,
- "Google Cloud Dataflow job name-2 was cancelled\\.",
+ "JOB_STATE_CANCELLED",
),
(
DataflowJobType.JOB_TYPE_STREAMING,
DataflowJobStatus.JOB_STATE_CANCELLED,
- "Google Cloud Dataflow job name-2 was cancelled\\.",
+ "JOB_STATE_CANCELLED",
),
(
DataflowJobType.JOB_TYPE_BATCH,
DataflowJobStatus.JOB_STATE_DRAINED,
- "Google Cloud Dataflow job name-2 was drained\\.",
+ "JOB_STATE_DRAINED",
),
(
DataflowJobType.JOB_TYPE_STREAMING,
DataflowJobStatus.JOB_STATE_DRAINED,
- "Google Cloud Dataflow job name-2 was drained\\.",
+ "JOB_STATE_DRAINED",
),
(
DataflowJobType.JOB_TYPE_BATCH,
DataflowJobStatus.JOB_STATE_UPDATED,
- "Google Cloud Dataflow job name-2 was updated\\.",
+ "JOB_STATE_UPDATED",
),
(
DataflowJobType.JOB_TYPE_STREAMING,
DataflowJobStatus.JOB_STATE_UPDATED,
- "Google Cloud Dataflow job name-2 was updated\\.",
+ "JOB_STATE_UPDATED",
),
],
)
@@ -1510,6 +1515,47 @@ class TestDataflowJob:
with pytest.raises(Exception, match=exception_regex):
dataflow_job._check_dataflow_job_state(job)
+ @pytest.mark.parametrize(
+ "job_type, expected_terminal_state, match",
+ [
+ (
+ DataflowJobType.JOB_TYPE_BATCH,
+ "test",
+ "invalid",
+ ),
+ (
+ DataflowJobType.JOB_TYPE_STREAMING,
+ DataflowJobStatus.JOB_STATE_DONE,
+ "cannot be JOB_STATE_DONE while it is a streaming job",
+ ),
+ (
+ DataflowJobType.JOB_TYPE_BATCH,
+ DataflowJobStatus.JOB_STATE_DRAINED,
+ "cannot be JOB_STATE_DRAINED while it is a batch job",
+ ),
+ ],
+ )
+ def test_check_dataflow_job_state__invalid_expected_state(self, job_type,
expected_terminal_state, match):
+ job = {
+ "id": "id-2",
+ "name": "name-2",
+ "type": job_type,
+ "currentState": DataflowJobStatus.JOB_STATE_QUEUED,
+ }
+ dataflow_job = _DataflowJobsController(
+ dataflow=self.mock_dataflow,
+ project_number=TEST_PROJECT,
+ name=UNIQUE_JOB_NAME,
+ location=TEST_LOCATION,
+ poll_sleep=0,
+ job_id=TEST_JOB_ID,
+ num_retries=20,
+ multiple_jobs=False,
+ expected_terminal_state=expected_terminal_state,
+ )
+ with pytest.raises(Exception, match=match):
+ dataflow_job._check_dataflow_job_state(job)
+
def test_dataflow_job_cancel_job(self):
mock_jobs =
self.mock_dataflow.projects.return_value.locations.return_value.jobs
get_method = mock_jobs.return_value.get
diff --git a/tests/providers/google/cloud/operators/test_dataflow.py
b/tests/providers/google/cloud/operators/test_dataflow.py
index 99cc094612..301a69962e 100644
--- a/tests/providers/google/cloud/operators/test_dataflow.py
+++ b/tests/providers/google/cloud/operators/test_dataflow.py
@@ -24,6 +24,7 @@ from unittest import mock
import pytest as pytest
import airflow
+from airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus
from airflow.providers.google.cloud.operators.dataflow import (
CheckJobRunning,
DataflowCreateJavaJobOperator,
@@ -581,6 +582,7 @@ class TestDataflowStartFlexTemplateOperator:
do_xcom_push=True,
project_id=TEST_PROJECT,
location=TEST_LOCATION,
+ expected_terminal_state=DataflowJobStatus.JOB_STATE_DONE,
)
@pytest.fixture
@@ -600,6 +602,7 @@ class TestDataflowStartFlexTemplateOperator:
mock_dataflow.assert_called_once_with(
gcp_conn_id="google_cloud_default",
drain_pipeline=False,
+ expected_terminal_state=DataflowJobStatus.JOB_STATE_DONE,
cancel_timeout=600,
wait_until_finished=None,
impersonation_chain=None,