This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 93488d09f9 DataflowStartFlexTemplateOperator. Check for Dataflow job
type each check cycle. (#40584)
93488d09f9 is described below
commit 93488d09f9754d7652824f2fe43743d6f9cd4b08
Author: Oleksandr Tkachov <[email protected]>
AuthorDate: Thu Jul 4 14:07:09 2024 +0300
DataflowStartFlexTemplateOperator. Check for Dataflow job type each check
cycle. (#40584)
---------
Co-authored-by: Oleksandr Tkachov <[email protected]>
---
airflow/providers/google/cloud/hooks/dataflow.py | 50 +++++++--------
.../providers/google/cloud/hooks/test_dataflow.py | 72 ++++++++++++++++++++++
2 files changed, 98 insertions(+), 24 deletions(-)
diff --git a/airflow/providers/google/cloud/hooks/dataflow.py b/airflow/providers/google/cloud/hooks/dataflow.py
index 8fc8ca9d59..7f684cb617 100644
--- a/airflow/providers/google/cloud/hooks/dataflow.py
+++ b/airflow/providers/google/cloud/hooks/dataflow.py
@@ -419,32 +419,34 @@ class _DataflowJobsController(LoggingMixin):
         current_state = job["currentState"]
         is_streaming = job.get("type") == DataflowJobType.JOB_TYPE_STREAMING

-        if self._expected_terminal_state is None:
+        current_expected_state = self._expected_terminal_state
+
+        if current_expected_state is None:
             if is_streaming:
-                self._expected_terminal_state = DataflowJobStatus.JOB_STATE_RUNNING
+                current_expected_state = DataflowJobStatus.JOB_STATE_RUNNING
             else:
-                self._expected_terminal_state = DataflowJobStatus.JOB_STATE_DONE
-        else:
-            terminal_states = DataflowJobStatus.TERMINAL_STATES | {DataflowJobStatus.JOB_STATE_RUNNING}
-            if self._expected_terminal_state not in terminal_states:
-                raise AirflowException(
-                    f"Google Cloud Dataflow job's expected terminal state "
-                    f"'{self._expected_terminal_state}' is invalid."
-                    f" The value should be any of the following: {terminal_states}"
-                )
-            elif is_streaming and self._expected_terminal_state == DataflowJobStatus.JOB_STATE_DONE:
-                raise AirflowException(
-                    "Google Cloud Dataflow job's expected terminal state cannot be "
-                    "JOB_STATE_DONE while it is a streaming job"
-                )
-            elif not is_streaming and self._expected_terminal_state == DataflowJobStatus.JOB_STATE_DRAINED:
-                raise AirflowException(
-                    "Google Cloud Dataflow job's expected terminal state cannot be "
-                    "JOB_STATE_DRAINED while it is a batch job"
-                )
+                current_expected_state = DataflowJobStatus.JOB_STATE_DONE
+
+        terminal_states = DataflowJobStatus.TERMINAL_STATES | {DataflowJobStatus.JOB_STATE_RUNNING}
+        if current_expected_state not in terminal_states:
+            raise AirflowException(
+                f"Google Cloud Dataflow job's expected terminal state "
+                f"'{current_expected_state}' is invalid."
+                f" The value should be any of the following: {terminal_states}"
+            )
+        elif is_streaming and current_expected_state == DataflowJobStatus.JOB_STATE_DONE:
+            raise AirflowException(
+                "Google Cloud Dataflow job's expected terminal state cannot be "
+                "JOB_STATE_DONE while it is a streaming job"
+            )
+        elif not is_streaming and current_expected_state == DataflowJobStatus.JOB_STATE_DRAINED:
+            raise AirflowException(
+                "Google Cloud Dataflow job's expected terminal state cannot be "
+                "JOB_STATE_DRAINED while it is a batch job"
+            )

-        if current_state == self._expected_terminal_state:
-            if self._expected_terminal_state == DataflowJobStatus.JOB_STATE_RUNNING:
+        if current_state == current_expected_state:
+            if current_expected_state == DataflowJobStatus.JOB_STATE_RUNNING:
                 return not self._wait_until_finished
             return True
@@ -454,7 +456,7 @@ class _DataflowJobsController(LoggingMixin):
         self.log.debug("Current job: %s", job)
         raise AirflowException(
             f"Google Cloud Dataflow job {job['name']} is in an unexpected terminal state: {current_state}, "
-            f"expected terminal state: {self._expected_terminal_state}"
+            f"expected terminal state: {current_expected_state}"
         )

     def wait_for_done(self) -> None:
diff --git a/tests/providers/google/cloud/hooks/test_dataflow.py b/tests/providers/google/cloud/hooks/test_dataflow.py
index 8733f2bf9a..3925136288 100644
--- a/tests/providers/google/cloud/hooks/test_dataflow.py
+++ b/tests/providers/google/cloud/hooks/test_dataflow.py
@@ -1498,6 +1498,78 @@ class TestDataflowJob:
         result = dataflow_job._check_dataflow_job_state(job)
         assert result == expected_result
+    @pytest.mark.parametrize(
+        "jobs, wait_until_finished, expected_result",
+        [
+            # STREAMING
+            (
+                [
+                    (None, DataflowJobStatus.JOB_STATE_QUEUED),
+                    (None, DataflowJobStatus.JOB_STATE_PENDING),
+                    (DataflowJobType.JOB_TYPE_STREAMING, DataflowJobStatus.JOB_STATE_RUNNING),
+                ],
+                None,
+                True,
+            ),
+            (
+                [
+                    (None, DataflowJobStatus.JOB_STATE_QUEUED),
+                    (None, DataflowJobStatus.JOB_STATE_PENDING),
+                    (DataflowJobType.JOB_TYPE_STREAMING, DataflowJobStatus.JOB_STATE_RUNNING),
+                ],
+                True,
+                False,
+            ),
+            # BATCH
+            (
+                [
+                    (None, DataflowJobStatus.JOB_STATE_QUEUED),
+                    (None, DataflowJobStatus.JOB_STATE_PENDING),
+                    (DataflowJobType.JOB_TYPE_BATCH, DataflowJobStatus.JOB_STATE_RUNNING),
+                ],
+                False,
+                True,
+            ),
+            (
+                [
+                    (None, DataflowJobStatus.JOB_STATE_QUEUED),
+                    (None, DataflowJobStatus.JOB_STATE_PENDING),
+                    (DataflowJobType.JOB_TYPE_BATCH, DataflowJobStatus.JOB_STATE_RUNNING),
+                ],
+                None,
+                False,
+            ),
+            (
+                [
+                    (None, DataflowJobStatus.JOB_STATE_QUEUED),
+                    (None, DataflowJobStatus.JOB_STATE_PENDING),
+                    (DataflowJobType.JOB_TYPE_BATCH, DataflowJobStatus.JOB_STATE_DONE),
+                ],
+                None,
+                True,
+            ),
+        ],
+    )
+    def test_check_dataflow_job_state_without_job_type_changed_on_terminal_state(
+        self, jobs, wait_until_finished, expected_result
+    ):
+        dataflow_job = _DataflowJobsController(
+            dataflow=self.mock_dataflow,
+            project_number=TEST_PROJECT,
+            name="name-",
+            location=TEST_LOCATION,
+            poll_sleep=0,
+            job_id=None,
+            num_retries=20,
+            multiple_jobs=True,
+            wait_until_finished=wait_until_finished,
+        )
+        result = False
+        for current_job in jobs:
+            job = {"id": "id-2", "name": "name-2", "type": current_job[0], "currentState": current_job[1]}
+            result = dataflow_job._check_dataflow_job_state(job)
+        assert result == expected_result
+
     @pytest.mark.parametrize(
         "job_state, wait_until_finished, expected_result",
         [