This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 050a47add8 Add `expected_terminal_state` parameter to Dataflow operators (#34217)
050a47add8 is described below

commit 050a47add822cde6d9abcd609df59c98caae13b0
Author: Shahar Epstein <[email protected]>
AuthorDate: Mon Sep 11 13:55:54 2023 +0300

    Add `expected_terminal_state` parameter to Dataflow operators (#34217)
---
 airflow/providers/google/cloud/hooks/dataflow.py   | 60 +++++++++++------
 .../providers/google/cloud/operators/dataflow.py   | 15 +++++
 .../providers/google/cloud/hooks/test_dataflow.py  | 76 +++++++++++++++++-----
 .../google/cloud/operators/test_dataflow.py        |  3 +
 4 files changed, 120 insertions(+), 34 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/dataflow.py 
b/airflow/providers/google/cloud/hooks/dataflow.py
index 819a674b82..3b0f3a2a8c 100644
--- a/airflow/providers/google/cloud/hooks/dataflow.py
+++ b/airflow/providers/google/cloud/hooks/dataflow.py
@@ -175,7 +175,7 @@ class _DataflowJobsController(LoggingMixin):
     :param num_retries: Maximum number of retries in case of connection 
problems.
     :param multiple_jobs: If set to true this task will be searched by name 
prefix (``name`` parameter),
         not by specific job ID, then actions will be performed on all matching 
jobs.
-    :param drain_pipeline: Optional, set to True if want to stop streaming job 
by draining it
+    :param drain_pipeline: Optional, set to True if we want to stop streaming 
job by draining it
         instead of canceling.
     :param cancel_timeout: wait time in seconds for successful job canceling
     :param wait_until_finished: If True, wait for the end of pipeline 
execution before exiting. If False,
@@ -183,8 +183,8 @@ class _DataflowJobsController(LoggingMixin):
 
         The default behavior depends on the type of pipeline:
 
-        * for the streaming pipeline, wait for jobs to start,
-        * for the batch pipeline, wait for the jobs to complete.
+        * for the streaming pipeline, wait for jobs to be in JOB_STATE_RUNNING,
+        * for the batch pipeline, wait for the jobs to be in JOB_STATE_DONE.
     """
 
     def __init__(
@@ -200,6 +200,7 @@ class _DataflowJobsController(LoggingMixin):
         drain_pipeline: bool = False,
         cancel_timeout: int | None = 5 * 60,
         wait_until_finished: bool | None = None,
+        expected_terminal_state: str | None = None,
     ) -> None:
 
         super().__init__()
@@ -215,6 +216,7 @@ class _DataflowJobsController(LoggingMixin):
         self._jobs: list[dict] | None = None
         self.drain_pipeline = drain_pipeline
         self._wait_until_finished = wait_until_finished
+        self._expected_terminal_state = expected_terminal_state
 
     def is_job_running(self) -> bool:
         """
@@ -391,27 +393,44 @@ class _DataflowJobsController(LoggingMixin):
         :return: True if job is done.
         :raise: Exception
         """
-        if self._wait_until_finished is None:
-            wait_for_running = job.get("type") == 
DataflowJobType.JOB_TYPE_STREAMING
+        current_state = job["currentState"]
+        is_streaming = job.get("type") == DataflowJobType.JOB_TYPE_STREAMING
+
+        if self._expected_terminal_state is None:
+            if is_streaming:
+                self._expected_terminal_state = 
DataflowJobStatus.JOB_STATE_RUNNING
+            else:
+                self._expected_terminal_state = 
DataflowJobStatus.JOB_STATE_DONE
         else:
-            wait_for_running = not self._wait_until_finished
+            terminal_states = DataflowJobStatus.TERMINAL_STATES | 
{DataflowJobStatus.JOB_STATE_RUNNING}
+            if self._expected_terminal_state not in terminal_states:
+                raise Exception(
+                    f"Google Cloud Dataflow job's expected terminal state "
+                    f"'{self._expected_terminal_state}' is invalid."
+                    f" The value should be any of the following: 
{terminal_states}"
+                )
+            elif is_streaming and self._expected_terminal_state == 
DataflowJobStatus.JOB_STATE_DONE:
+                raise Exception(
+                    "Google Cloud Dataflow job's expected terminal state 
cannot be "
+                    "JOB_STATE_DONE while it is a streaming job"
+                )
+            elif not is_streaming and self._expected_terminal_state == 
DataflowJobStatus.JOB_STATE_DRAINED:
+                raise Exception(
+                    "Google Cloud Dataflow job's expected terminal state 
cannot be "
+                    "JOB_STATE_DRAINED while it is a batch job"
+                )
 
-        if job["currentState"] == DataflowJobStatus.JOB_STATE_DONE:
+        if not self._wait_until_finished and current_state == 
self._expected_terminal_state:
             return True
-        elif job["currentState"] == DataflowJobStatus.JOB_STATE_FAILED:
-            raise Exception(f"Google Cloud Dataflow job {job['name']} has 
failed.")
-        elif job["currentState"] == DataflowJobStatus.JOB_STATE_CANCELLED:
-            raise Exception(f"Google Cloud Dataflow job {job['name']} was 
cancelled.")
-        elif job["currentState"] == DataflowJobStatus.JOB_STATE_DRAINED:
-            raise Exception(f"Google Cloud Dataflow job {job['name']} was 
drained.")
-        elif job["currentState"] == DataflowJobStatus.JOB_STATE_UPDATED:
-            raise Exception(f"Google Cloud Dataflow job {job['name']} was 
updated.")
-        elif job["currentState"] == DataflowJobStatus.JOB_STATE_RUNNING and 
wait_for_running:
-            return True
-        elif job["currentState"] in DataflowJobStatus.AWAITING_STATES:
+
+        if current_state in DataflowJobStatus.AWAITING_STATES:
             return self._wait_until_finished is False
+
         self.log.debug("Current job: %s", str(job))
-        raise Exception(f"Google Cloud Dataflow job {job['name']} was unknown 
state: {job['currentState']}")
+        raise Exception(
+            f"Google Cloud Dataflow job {job['name']} is in an unexpected 
terminal state: {current_state}, "
+            f"expected terminal state: {self._expected_terminal_state}"
+        )
 
     def wait_for_done(self) -> None:
         """Helper method to wait for result of submitted job."""
@@ -514,6 +533,7 @@ class DataflowHook(GoogleBaseHook):
         drain_pipeline: bool = False,
         cancel_timeout: int | None = 5 * 60,
         wait_until_finished: bool | None = None,
+        expected_terminal_state: str | None = None,
         **kwargs,
     ) -> None:
         if kwargs.get("delegate_to") is not None:
@@ -527,6 +547,7 @@ class DataflowHook(GoogleBaseHook):
         self.wait_until_finished = wait_until_finished
         self.job_id: str | None = None
         self.beam_hook = BeamHook(BeamRunnerType.DataflowRunner)
+        self.expected_terminal_state = expected_terminal_state
         super().__init__(
             gcp_conn_id=gcp_conn_id,
             impersonation_chain=impersonation_chain,
@@ -691,6 +712,7 @@ class DataflowHook(GoogleBaseHook):
             drain_pipeline=self.drain_pipeline,
             cancel_timeout=self.cancel_timeout,
             wait_until_finished=self.wait_until_finished,
+            expected_terminal_state=self.expected_terminal_state,
         )
         jobs_controller.wait_for_done()
         return response["job"]
diff --git a/airflow/providers/google/cloud/operators/dataflow.py 
b/airflow/providers/google/cloud/operators/dataflow.py
index c9daa16966..bda2c1721f 100644
--- a/airflow/providers/google/cloud/operators/dataflow.py
+++ b/airflow/providers/google/cloud/operators/dataflow.py
@@ -288,6 +288,8 @@ class 
DataflowCreateJavaJobOperator(GoogleCloudBaseOperator):
 
         If you in your pipeline do not call the wait_for_pipeline method, and 
pass wait_until_finish=False
         to the operator, the second loop will check once is job not in 
terminal state and exit the loop.
+    :param expected_terminal_state: The expected terminal state of the 
operator on which the corresponding
+        Airflow task succeeds. When not specified, it will be determined by 
the hook.
 
     Note that both
     ``dataflow_default_options`` and ``options`` will be merged to specify 
pipeline
@@ -349,6 +351,7 @@ class 
DataflowCreateJavaJobOperator(GoogleCloudBaseOperator):
         multiple_jobs: bool = False,
         cancel_timeout: int | None = 10 * 60,
         wait_until_finished: bool | None = None,
+        expected_terminal_state: str | None = None,
         **kwargs,
     ) -> None:
         # TODO: Remove one day
@@ -378,6 +381,7 @@ class 
DataflowCreateJavaJobOperator(GoogleCloudBaseOperator):
         self.check_if_running = check_if_running
         self.cancel_timeout = cancel_timeout
         self.wait_until_finished = wait_until_finished
+        self.expected_terminal_state = expected_terminal_state
         self.job_id = None
         self.beam_hook: BeamHook | None = None
         self.dataflow_hook: DataflowHook | None = None
@@ -390,6 +394,7 @@ class 
DataflowCreateJavaJobOperator(GoogleCloudBaseOperator):
             poll_sleep=self.poll_sleep,
             cancel_timeout=self.cancel_timeout,
             wait_until_finished=self.wait_until_finished,
+            expected_terminal_state=self.expected_terminal_state,
         )
         job_name = 
self.dataflow_hook.build_dataflow_job_name(job_name=self.job_name)
         pipeline_options = copy.deepcopy(self.dataflow_default_options)
@@ -531,6 +536,8 @@ class 
DataflowTemplatedJobStartOperator(GoogleCloudBaseOperator):
 
         If you in your pipeline do not call the wait_for_pipeline method, and 
pass wait_until_finish=False
         to the operator, the second loop will check once is job not in 
terminal state and exit the loop.
+    :param expected_terminal_state: The expected terminal state of the 
operator on which the corresponding
+        Airflow task succeeds. When not specified, it will be determined by 
the hook.
 
     It's a good practice to define dataflow_* parameters in the default_args 
of the dag
     like the project, zone and staging location.
@@ -614,6 +621,7 @@ class 
DataflowTemplatedJobStartOperator(GoogleCloudBaseOperator):
         wait_until_finished: bool | None = None,
         append_job_name: bool = True,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
+        expected_terminal_state: str | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -633,6 +641,7 @@ class 
DataflowTemplatedJobStartOperator(GoogleCloudBaseOperator):
         self.wait_until_finished = wait_until_finished
         self.append_job_name = append_job_name
         self.deferrable = deferrable
+        self.expected_terminal_state = expected_terminal_state
 
         self.job: dict | None = None
 
@@ -657,6 +666,7 @@ class 
DataflowTemplatedJobStartOperator(GoogleCloudBaseOperator):
             impersonation_chain=self.impersonation_chain,
             cancel_timeout=self.cancel_timeout,
             wait_until_finished=self.wait_until_finished,
+            expected_terminal_state=self.expected_terminal_state,
         )
         return hook
 
@@ -787,6 +797,8 @@ class 
DataflowStartFlexTemplateOperator(GoogleCloudBaseOperator):
         Service Account Token Creator IAM role to the directly preceding 
identity, with first
         account from the list granting this role to the originating account 
(templated).
     :param deferrable: Run operator in the deferrable mode.
+    :param expected_terminal_state: The expected final status of the operator 
on which the corresponding
+        Airflow task succeeds. When not specified, it will be determined by 
the hook.
     :param append_job_name: True if unique suffix has to be appended to job 
name.
     """
 
@@ -805,6 +817,7 @@ class 
DataflowStartFlexTemplateOperator(GoogleCloudBaseOperator):
         impersonation_chain: str | Sequence[str] | None = None,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
         append_job_name: bool = True,
+        expected_terminal_state: str | None = None,
         *args,
         **kwargs,
     ) -> None:
@@ -819,6 +832,7 @@ class 
DataflowStartFlexTemplateOperator(GoogleCloudBaseOperator):
         self.job: dict | None = None
         self.impersonation_chain = impersonation_chain
         self.deferrable = deferrable
+        self.expected_terminal_state = expected_terminal_state
         self.append_job_name = append_job_name
 
         self._validate_deferrable_params()
@@ -842,6 +856,7 @@ class 
DataflowStartFlexTemplateOperator(GoogleCloudBaseOperator):
             cancel_timeout=self.cancel_timeout,
             wait_until_finished=self.wait_until_finished,
             impersonation_chain=self.impersonation_chain,
+            expected_terminal_state=self.expected_terminal_state,
         )
         return hook
 
diff --git a/tests/providers/google/cloud/hooks/test_dataflow.py 
b/tests/providers/google/cloud/hooks/test_dataflow.py
index 0738f533f5..92e5898458 100644
--- a/tests/providers/google/cloud/hooks/test_dataflow.py
+++ b/tests/providers/google/cloud/hooks/test_dataflow.py
@@ -861,6 +861,7 @@ class TestDataflowTemplateHook:
             project_number=TEST_PROJECT,
             location=DEFAULT_DATAFLOW_LOCATION,
             drain_pipeline=False,
+            expected_terminal_state=None,
             cancel_timeout=DEFAULT_CANCEL_TIMEOUT,
             wait_until_finished=None,
         )
@@ -900,6 +901,7 @@ class TestDataflowTemplateHook:
             project_number=TEST_PROJECT,
             location=TEST_LOCATION,
             drain_pipeline=False,
+            expected_terminal_state=None,
             cancel_timeout=DEFAULT_CANCEL_TIMEOUT,
             wait_until_finished=None,
         )
@@ -943,6 +945,7 @@ class TestDataflowTemplateHook:
             drain_pipeline=False,
             cancel_timeout=DEFAULT_CANCEL_TIMEOUT,
             wait_until_finished=None,
+            expected_terminal_state=None,
         )
         mock_controller.return_value.wait_for_done.assert_called_once()
 
@@ -986,6 +989,7 @@ class TestDataflowTemplateHook:
             drain_pipeline=False,
             cancel_timeout=DEFAULT_CANCEL_TIMEOUT,
             wait_until_finished=None,
+            expected_terminal_state=None,
         )
         mock_uuid.assert_called_once_with()
 
@@ -1033,6 +1037,7 @@ class TestDataflowTemplateHook:
             drain_pipeline=False,
             cancel_timeout=DEFAULT_CANCEL_TIMEOUT,
             wait_until_finished=None,
+            expected_terminal_state=None,
         )
         mock_uuid.assert_called_once_with()
 
@@ -1232,13 +1237,13 @@ class TestDataflowJob:
     @pytest.mark.parametrize(
         "state, exception_regex",
         [
-            (DataflowJobStatus.JOB_STATE_FAILED, "Google Cloud Dataflow job 
name-2 has failed\\."),
-            (DataflowJobStatus.JOB_STATE_CANCELLED, "Google Cloud Dataflow job 
name-2 was cancelled\\."),
-            (DataflowJobStatus.JOB_STATE_DRAINED, "Google Cloud Dataflow job 
name-2 was drained\\."),
-            (DataflowJobStatus.JOB_STATE_UPDATED, "Google Cloud Dataflow job 
name-2 was updated\\."),
+            (DataflowJobStatus.JOB_STATE_FAILED, "unexpected terminal state: 
JOB_STATE_FAILED"),
+            (DataflowJobStatus.JOB_STATE_CANCELLED, "unexpected terminal 
state: JOB_STATE_CANCELLED"),
+            (DataflowJobStatus.JOB_STATE_DRAINED, "unexpected terminal state: 
JOB_STATE_DRAINED"),
+            (DataflowJobStatus.JOB_STATE_UPDATED, "unexpected terminal state: 
JOB_STATE_UPDATED"),
             (
                 DataflowJobStatus.JOB_STATE_UNKNOWN,
-                "Google Cloud Dataflow job name-2 was unknown state: 
JOB_STATE_UNKNOWN",
+                "JOB_STATE_UNKNOWN",
             ),
         ],
     )
@@ -1446,52 +1451,52 @@ class TestDataflowJob:
             (
                 DataflowJobType.JOB_TYPE_BATCH,
                 DataflowJobStatus.JOB_STATE_FAILED,
-                "Google Cloud Dataflow job name-2 has failed\\.",
+                "JOB_STATE_FAILED",
             ),
             (
                 DataflowJobType.JOB_TYPE_STREAMING,
                 DataflowJobStatus.JOB_STATE_FAILED,
-                "Google Cloud Dataflow job name-2 has failed\\.",
+                "JOB_STATE_FAILED",
             ),
             (
                 DataflowJobType.JOB_TYPE_STREAMING,
                 DataflowJobStatus.JOB_STATE_UNKNOWN,
-                "Google Cloud Dataflow job name-2 was unknown state: 
JOB_STATE_UNKNOWN",
+                "JOB_STATE_UNKNOWN",
             ),
             (
                 DataflowJobType.JOB_TYPE_BATCH,
                 DataflowJobStatus.JOB_STATE_UNKNOWN,
-                "Google Cloud Dataflow job name-2 was unknown state: 
JOB_STATE_UNKNOWN",
+                "JOB_STATE_UNKNOWN",
             ),
             (
                 DataflowJobType.JOB_TYPE_BATCH,
                 DataflowJobStatus.JOB_STATE_CANCELLED,
-                "Google Cloud Dataflow job name-2 was cancelled\\.",
+                "JOB_STATE_CANCELLED",
             ),
             (
                 DataflowJobType.JOB_TYPE_STREAMING,
                 DataflowJobStatus.JOB_STATE_CANCELLED,
-                "Google Cloud Dataflow job name-2 was cancelled\\.",
+                "JOB_STATE_CANCELLED",
             ),
             (
                 DataflowJobType.JOB_TYPE_BATCH,
                 DataflowJobStatus.JOB_STATE_DRAINED,
-                "Google Cloud Dataflow job name-2 was drained\\.",
+                "JOB_STATE_DRAINED",
             ),
             (
                 DataflowJobType.JOB_TYPE_STREAMING,
                 DataflowJobStatus.JOB_STATE_DRAINED,
-                "Google Cloud Dataflow job name-2 was drained\\.",
+                "JOB_STATE_DRAINED",
             ),
             (
                 DataflowJobType.JOB_TYPE_BATCH,
                 DataflowJobStatus.JOB_STATE_UPDATED,
-                "Google Cloud Dataflow job name-2 was updated\\.",
+                "JOB_STATE_UPDATED",
             ),
             (
                 DataflowJobType.JOB_TYPE_STREAMING,
                 DataflowJobStatus.JOB_STATE_UPDATED,
-                "Google Cloud Dataflow job name-2 was updated\\.",
+                "JOB_STATE_UPDATED",
             ),
         ],
     )
@@ -1510,6 +1515,47 @@ class TestDataflowJob:
         with pytest.raises(Exception, match=exception_regex):
             dataflow_job._check_dataflow_job_state(job)
 
+    @pytest.mark.parametrize(
+        "job_type, expected_terminal_state, match",
+        [
+            (
+                DataflowJobType.JOB_TYPE_BATCH,
+                "test",
+                "invalid",
+            ),
+            (
+                DataflowJobType.JOB_TYPE_STREAMING,
+                DataflowJobStatus.JOB_STATE_DONE,
+                "cannot be JOB_STATE_DONE while it is a streaming job",
+            ),
+            (
+                DataflowJobType.JOB_TYPE_BATCH,
+                DataflowJobStatus.JOB_STATE_DRAINED,
+                "cannot be JOB_STATE_DRAINED while it is a batch job",
+            ),
+        ],
+    )
+    def test_check_dataflow_job_state__invalid_expected_state(self, job_type, 
expected_terminal_state, match):
+        job = {
+            "id": "id-2",
+            "name": "name-2",
+            "type": job_type,
+            "currentState": DataflowJobStatus.JOB_STATE_QUEUED,
+        }
+        dataflow_job = _DataflowJobsController(
+            dataflow=self.mock_dataflow,
+            project_number=TEST_PROJECT,
+            name=UNIQUE_JOB_NAME,
+            location=TEST_LOCATION,
+            poll_sleep=0,
+            job_id=TEST_JOB_ID,
+            num_retries=20,
+            multiple_jobs=False,
+            expected_terminal_state=expected_terminal_state,
+        )
+        with pytest.raises(Exception, match=match):
+            dataflow_job._check_dataflow_job_state(job)
+
     def test_dataflow_job_cancel_job(self):
         mock_jobs = 
self.mock_dataflow.projects.return_value.locations.return_value.jobs
         get_method = mock_jobs.return_value.get
diff --git a/tests/providers/google/cloud/operators/test_dataflow.py 
b/tests/providers/google/cloud/operators/test_dataflow.py
index 99cc094612..301a69962e 100644
--- a/tests/providers/google/cloud/operators/test_dataflow.py
+++ b/tests/providers/google/cloud/operators/test_dataflow.py
@@ -24,6 +24,7 @@ from unittest import mock
 import pytest as pytest
 
 import airflow
+from airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus
 from airflow.providers.google.cloud.operators.dataflow import (
     CheckJobRunning,
     DataflowCreateJavaJobOperator,
@@ -581,6 +582,7 @@ class TestDataflowStartFlexTemplateOperator:
             do_xcom_push=True,
             project_id=TEST_PROJECT,
             location=TEST_LOCATION,
+            expected_terminal_state=DataflowJobStatus.JOB_STATE_DONE,
         )
 
     @pytest.fixture
@@ -600,6 +602,7 @@ class TestDataflowStartFlexTemplateOperator:
         mock_dataflow.assert_called_once_with(
             gcp_conn_id="google_cloud_default",
             drain_pipeline=False,
+            expected_terminal_state=DataflowJobStatus.JOB_STATE_DONE,
             cancel_timeout=600,
             wait_until_finished=None,
             impersonation_chain=None,

Reply via email to