This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 93488d09f9 DataflowStartFlexTemplateOperator. Check for Dataflow job type each check cycle. (#40584)
93488d09f9 is described below

commit 93488d09f9754d7652824f2fe43743d6f9cd4b08
Author: Oleksandr Tkachov <[email protected]>
AuthorDate: Thu Jul 4 14:07:09 2024 +0300

    DataflowStartFlexTemplateOperator. Check for Dataflow job type each check cycle. (#40584)
    
    
    
    ---------
    
    Co-authored-by: Oleksandr Tkachov <[email protected]>
---
 airflow/providers/google/cloud/hooks/dataflow.py   | 50 +++++++--------
 .../providers/google/cloud/hooks/test_dataflow.py  | 72 ++++++++++++++++++++++
 2 files changed, 98 insertions(+), 24 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/dataflow.py b/airflow/providers/google/cloud/hooks/dataflow.py
index 8fc8ca9d59..7f684cb617 100644
--- a/airflow/providers/google/cloud/hooks/dataflow.py
+++ b/airflow/providers/google/cloud/hooks/dataflow.py
@@ -419,32 +419,34 @@ class _DataflowJobsController(LoggingMixin):
         current_state = job["currentState"]
         is_streaming = job.get("type") == DataflowJobType.JOB_TYPE_STREAMING
 
-        if self._expected_terminal_state is None:
+        current_expected_state = self._expected_terminal_state
+
+        if current_expected_state is None:
             if is_streaming:
-                self._expected_terminal_state = DataflowJobStatus.JOB_STATE_RUNNING
+                current_expected_state = DataflowJobStatus.JOB_STATE_RUNNING
             else:
-                self._expected_terminal_state = DataflowJobStatus.JOB_STATE_DONE
-        else:
-            terminal_states = DataflowJobStatus.TERMINAL_STATES | {DataflowJobStatus.JOB_STATE_RUNNING}
-            if self._expected_terminal_state not in terminal_states:
-                raise AirflowException(
-                    f"Google Cloud Dataflow job's expected terminal state "
-                    f"'{self._expected_terminal_state}' is invalid."
-                    f" The value should be any of the following: {terminal_states}"
-                )
-            elif is_streaming and self._expected_terminal_state == DataflowJobStatus.JOB_STATE_DONE:
-                raise AirflowException(
-                    "Google Cloud Dataflow job's expected terminal state cannot be "
-                    "JOB_STATE_DONE while it is a streaming job"
-                )
-            elif not is_streaming and self._expected_terminal_state == DataflowJobStatus.JOB_STATE_DRAINED:
-                raise AirflowException(
-                    "Google Cloud Dataflow job's expected terminal state cannot be "
-                    "JOB_STATE_DRAINED while it is a batch job"
-                )
+                current_expected_state = DataflowJobStatus.JOB_STATE_DONE
+
+        terminal_states = DataflowJobStatus.TERMINAL_STATES | {DataflowJobStatus.JOB_STATE_RUNNING}
+        if current_expected_state not in terminal_states:
+            raise AirflowException(
+                f"Google Cloud Dataflow job's expected terminal state "
+                f"'{current_expected_state}' is invalid."
+                f" The value should be any of the following: {terminal_states}"
+            )
+        elif is_streaming and current_expected_state == DataflowJobStatus.JOB_STATE_DONE:
+            raise AirflowException(
+                "Google Cloud Dataflow job's expected terminal state cannot be "
+                "JOB_STATE_DONE while it is a streaming job"
+            )
+        elif not is_streaming and current_expected_state == DataflowJobStatus.JOB_STATE_DRAINED:
+            raise AirflowException(
+                "Google Cloud Dataflow job's expected terminal state cannot be "
+                "JOB_STATE_DRAINED while it is a batch job"
+            )
 
-        if current_state == self._expected_terminal_state:
-            if self._expected_terminal_state == DataflowJobStatus.JOB_STATE_RUNNING:
+        if current_state == current_expected_state:
+            if current_expected_state == DataflowJobStatus.JOB_STATE_RUNNING:
                 return not self._wait_until_finished
             return True
 
@@ -454,7 +456,7 @@ class _DataflowJobsController(LoggingMixin):
         self.log.debug("Current job: %s", job)
         raise AirflowException(
             f"Google Cloud Dataflow job {job['name']} is in an unexpected terminal state: {current_state}, "
-            f"expected terminal state: {self._expected_terminal_state}"
+            f"expected terminal state: {current_expected_state}"
         )
 
     def wait_for_done(self) -> None:
diff --git a/tests/providers/google/cloud/hooks/test_dataflow.py b/tests/providers/google/cloud/hooks/test_dataflow.py
index 8733f2bf9a..3925136288 100644
--- a/tests/providers/google/cloud/hooks/test_dataflow.py
+++ b/tests/providers/google/cloud/hooks/test_dataflow.py
@@ -1498,6 +1498,78 @@ class TestDataflowJob:
         result = dataflow_job._check_dataflow_job_state(job)
         assert result == expected_result
 
+    @pytest.mark.parametrize(
+        "jobs, wait_until_finished, expected_result",
+        [
+            # STREAMING
+            (
+                [
+                    (None, DataflowJobStatus.JOB_STATE_QUEUED),
+                    (None, DataflowJobStatus.JOB_STATE_PENDING),
+                    (DataflowJobType.JOB_TYPE_STREAMING, DataflowJobStatus.JOB_STATE_RUNNING),
+                ],
+                None,
+                True,
+            ),
+            (
+                [
+                    (None, DataflowJobStatus.JOB_STATE_QUEUED),
+                    (None, DataflowJobStatus.JOB_STATE_PENDING),
+                    (DataflowJobType.JOB_TYPE_STREAMING, DataflowJobStatus.JOB_STATE_RUNNING),
+                ],
+                True,
+                False,
+            ),
+            # BATCH
+            (
+                [
+                    (None, DataflowJobStatus.JOB_STATE_QUEUED),
+                    (None, DataflowJobStatus.JOB_STATE_PENDING),
+                    (DataflowJobType.JOB_TYPE_BATCH, DataflowJobStatus.JOB_STATE_RUNNING),
+                ],
+                False,
+                True,
+            ),
+            (
+                [
+                    (None, DataflowJobStatus.JOB_STATE_QUEUED),
+                    (None, DataflowJobStatus.JOB_STATE_PENDING),
+                    (DataflowJobType.JOB_TYPE_BATCH, DataflowJobStatus.JOB_STATE_RUNNING),
+                ],
+                None,
+                False,
+            ),
+            (
+                [
+                    (None, DataflowJobStatus.JOB_STATE_QUEUED),
+                    (None, DataflowJobStatus.JOB_STATE_PENDING),
+                    (DataflowJobType.JOB_TYPE_BATCH, DataflowJobStatus.JOB_STATE_DONE),
+                ],
+                None,
+                True,
+            ),
+        ],
+    )
+    def test_check_dataflow_job_state_without_job_type_changed_on_terminal_state(
+        self, jobs, wait_until_finished, expected_result
+    ):
+        dataflow_job = _DataflowJobsController(
+            dataflow=self.mock_dataflow,
+            project_number=TEST_PROJECT,
+            name="name-",
+            location=TEST_LOCATION,
+            poll_sleep=0,
+            job_id=None,
+            num_retries=20,
+            multiple_jobs=True,
+            wait_until_finished=wait_until_finished,
+        )
+        result = False
+        for current_job in jobs:
+            job = {"id": "id-2", "name": "name-2", "type": current_job[0], "currentState": current_job[1]}
+            result = dataflow_job._check_dataflow_job_state(job)
+        assert result == expected_result
+
     @pytest.mark.parametrize(
         "job_state, wait_until_finished, expected_result",
         [

Reply via email to