This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 0caec9f  Dataflow - add waiting for successful job cancel (#11501)
0caec9f is described below

commit 0caec9fd32bee2b3036b5d7bdcb56bd6a3b9dccf
Author: Tobiasz Kędzierski <[email protected]>
AuthorDate: Fri Nov 6 15:14:14 2020 +0100

    Dataflow - add waiting for successful job cancel (#11501)
    
    Co-authored-by: Kamil Breguła <[email protected]>
---
 airflow/providers/google/cloud/hooks/dataflow.py   | 46 ++++++++++++++++--
 .../providers/google/cloud/operators/dataflow.py   | 26 +++++++++-
 .../providers/google/cloud/hooks/test_dataflow.py  | 56 ++++++++++++++++++++++
 .../google/cloud/operators/test_mlengine_utils.py  |  6 ++-
 4 files changed, 129 insertions(+), 5 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/dataflow.py 
b/airflow/providers/google/cloud/hooks/dataflow.py
index e17001f..7467207 100644
--- a/airflow/providers/google/cloud/hooks/dataflow.py
+++ b/airflow/providers/google/cloud/hooks/dataflow.py
@@ -28,7 +28,7 @@ import uuid
 import warnings
 from copy import deepcopy
 from tempfile import TemporaryDirectory
-from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, 
Union, cast
+from typing import Any, Callable, Dict, List, Optional, Sequence, Set, 
TypeVar, Union, cast
 
 from googleapiclient.discovery import build
 
@@ -36,6 +36,7 @@ from airflow.exceptions import AirflowException
 from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.python_virtualenv import prepare_virtualenv
+from airflow.utils.timeout import timeout
 
 # This is the default location
 # https://cloud.google.com/dataflow/pipelines/specifying-exec-params
@@ -147,9 +148,10 @@ class _DataflowJobsController(LoggingMixin):
         not by specific job ID, then actions will be performed on all matching 
jobs.
     :param drain_pipeline: Optional, set to True if want to stop streaming job 
by draining it
         instead of canceling.
+    :param cancel_timeout: wait time in seconds for successful job canceling
     """
 
-    def __init__(
+    def __init__(  # pylint: disable=too-many-arguments
         self,
         dataflow: Any,
         project_number: str,
@@ -160,6 +162,7 @@ class _DataflowJobsController(LoggingMixin):
         num_retries: int = 0,
         multiple_jobs: bool = False,
         drain_pipeline: bool = False,
+        cancel_timeout: Optional[int] = 5 * 60,
     ) -> None:
 
         super().__init__()
@@ -171,8 +174,9 @@ class _DataflowJobsController(LoggingMixin):
         self._job_id = job_id
         self._num_retries = num_retries
         self._poll_sleep = poll_sleep
-        self.drain_pipeline = drain_pipeline
+        self._cancel_timeout = cancel_timeout
         self._jobs: Optional[List[dict]] = None
+        self.drain_pipeline = drain_pipeline
 
     def is_job_running(self) -> bool:
         """
@@ -317,6 +321,29 @@ class _DataflowJobsController(LoggingMixin):
 
         return self._jobs
 
+    def _wait_for_states(self, expected_states: Set[str]):
+        """Waiting for the jobs to reach a certain state."""
+        if not self._jobs:
+            raise ValueError("The _jobs should be set")
+        while True:
+            self._refresh_jobs()
+            job_states = {job['currentState'] for job in self._jobs}
+            if not job_states.difference(expected_states):
+                return
+            unexpected_failed_end_states = expected_states - 
DataflowJobStatus.FAILED_END_STATES
+            if unexpected_failed_end_states.intersection(job_states):
+                unexpected_failed_jobs = {
+                    job for job in self._jobs if job['currentState'] in 
unexpected_failed_end_states
+                }
+                raise AirflowException(
+                    "Jobs failed: "
+                    + ", ".join(
+                        f"ID: {job['id']} name: {job['name']} state: 
{job['currentState']}"
+                        for job in unexpected_failed_jobs
+                    )
+                )
+            time.sleep(self._poll_sleep)
+
     def cancel(self) -> None:
         """Cancels or drains current job"""
         jobs = self.get_jobs()
@@ -342,6 +369,12 @@ class _DataflowJobsController(LoggingMixin):
                     )
                 )
             batch.execute()
+            if self._cancel_timeout and isinstance(self._cancel_timeout, int):
+                timeout_error_message = "Canceling jobs failed due to timeout 
({}s): {}".format(
+                    self._cancel_timeout, ", ".join(job_ids)
+                )
+                with timeout(seconds=self._cancel_timeout, 
error_message=timeout_error_message):
+                    
self._wait_for_states({DataflowJobStatus.JOB_STATE_CANCELLED})
         else:
             self.log.info("No jobs to cancel")
 
@@ -453,9 +486,11 @@ class DataflowHook(GoogleBaseHook):
         poll_sleep: int = 10,
         impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
         drain_pipeline: bool = False,
+        cancel_timeout: Optional[int] = 5 * 60,
     ) -> None:
         self.poll_sleep = poll_sleep
         self.drain_pipeline = drain_pipeline
+        self.cancel_timeout = cancel_timeout
         super().__init__(
             gcp_conn_id=gcp_conn_id,
             delegate_to=delegate_to,
@@ -496,6 +531,7 @@ class DataflowHook(GoogleBaseHook):
             num_retries=self.num_retries,
             multiple_jobs=multiple_jobs,
             drain_pipeline=self.drain_pipeline,
+            cancel_timeout=self.cancel_timeout,
         )
         job_controller.wait_for_done()
 
@@ -669,6 +705,7 @@ class DataflowHook(GoogleBaseHook):
             poll_sleep=self.poll_sleep,
             num_retries=self.num_retries,
             drain_pipeline=self.drain_pipeline,
+            cancel_timeout=self.cancel_timeout,
         )
         jobs_controller.wait_for_done()
         return response["job"]
@@ -714,6 +751,7 @@ class DataflowHook(GoogleBaseHook):
             location=location,
             poll_sleep=self.poll_sleep,
             num_retries=self.num_retries,
+            cancel_timeout=self.cancel_timeout,
         )
         jobs_controller.wait_for_done()
 
@@ -898,6 +936,7 @@ class DataflowHook(GoogleBaseHook):
             poll_sleep=self.poll_sleep,
             drain_pipeline=self.drain_pipeline,
             num_retries=self.num_retries,
+            cancel_timeout=self.cancel_timeout,
         )
         return jobs_controller.is_job_running()
 
@@ -933,6 +972,7 @@ class DataflowHook(GoogleBaseHook):
             poll_sleep=self.poll_sleep,
             drain_pipeline=self.drain_pipeline,
             num_retries=self.num_retries,
+            cancel_timeout=self.cancel_timeout,
         )
         jobs_controller.cancel()
 
diff --git a/airflow/providers/google/cloud/operators/dataflow.py 
b/airflow/providers/google/cloud/operators/dataflow.py
index d2844a3..faca570 100644
--- a/airflow/providers/google/cloud/operators/dataflow.py
+++ b/airflow/providers/google/cloud/operators/dataflow.py
@@ -132,6 +132,9 @@ class DataflowCreateJavaJobOperator(BaseOperator):
     :type check_if_running: CheckJobRunning(IgnoreJob = do not check if 
running, FinishIfRunning=
        if job is running finish with nothing, WaitForRun= wait until the job 
finishes and then run the job)
         ``jar``, ``options``, and ``job_name`` are templated so you can use 
variables in them.
+    :param cancel_timeout: How long (in seconds) operator should wait for the 
pipeline to be
+        successfully cancelled when task is being killed.
+    :type cancel_timeout: Optional[int]
 
     Note that both
     ``dataflow_default_options`` and ``options`` will be merged to specify 
pipeline
@@ -193,6 +196,7 @@ class DataflowCreateJavaJobOperator(BaseOperator):
         job_class: Optional[str] = None,
         check_if_running: CheckJobRunning = CheckJobRunning.WaitForRun,
         multiple_jobs: Optional[bool] = None,
+        cancel_timeout: Optional[int] = 10 * 60,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -214,6 +218,7 @@ class DataflowCreateJavaJobOperator(BaseOperator):
         self.poll_sleep = poll_sleep
         self.job_class = job_class
         self.check_if_running = check_if_running
+        self.cancel_timeout = cancel_timeout
         self.job_id = None
         self.hook = None
 
@@ -222,6 +227,7 @@ class DataflowCreateJavaJobOperator(BaseOperator):
             gcp_conn_id=self.gcp_conn_id,
             delegate_to=self.delegate_to,
             poll_sleep=self.poll_sleep,
+            cancel_timeout=self.cancel_timeout,
         )
         dataflow_options = copy.copy(self.dataflow_default_options)
         dataflow_options.update(self.options)
@@ -324,6 +330,9 @@ class DataflowTemplatedJobStartOperator(BaseOperator):
             `https://cloud.google.com/dataflow/pipelines/specifying-exec-params
             
<https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment>`__
     :type environment: Optional[dict]
+    :param cancel_timeout: How long (in seconds) operator should wait for the 
pipeline to be
+        successfully cancelled when task is being killed.
+    :type cancel_timeout: Optional[int]
 
     It's a good practice to define dataflow_* parameters in the default_args 
of the dag
     like the project, zone and staging location.
@@ -401,6 +410,7 @@ class DataflowTemplatedJobStartOperator(BaseOperator):
         poll_sleep: int = 10,
         impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
         environment: Optional[Dict] = None,
+        cancel_timeout: Optional[int] = 10 * 60,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -418,6 +428,7 @@ class DataflowTemplatedJobStartOperator(BaseOperator):
         self.hook: Optional[DataflowHook] = None
         self.impersonation_chain = impersonation_chain
         self.environment = environment
+        self.cancel_timeout = cancel_timeout
 
     def execute(self, context) -> dict:
         self.hook = DataflowHook(
@@ -425,6 +436,7 @@ class DataflowTemplatedJobStartOperator(BaseOperator):
             delegate_to=self.delegate_to,
             poll_sleep=self.poll_sleep,
             impersonation_chain=self.impersonation_chain,
+            cancel_timeout=self.cancel_timeout,
         )
 
         def set_current_job_id(job_id):
@@ -473,6 +485,9 @@ class DataflowStartFlexTemplateOperator(BaseOperator):
        instead of canceling during killing of the task instance. See:
         https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline
     :type drain_pipeline: bool
+    :param cancel_timeout: How long (in seconds) operator should wait for the 
pipeline to be
+        successfully cancelled when task is being killed.
+    :type cancel_timeout: Optional[int]
     """
 
     template_fields = ["body", "location", "project_id", "gcp_conn_id"]
@@ -486,6 +501,7 @@ class DataflowStartFlexTemplateOperator(BaseOperator):
         gcp_conn_id: str = "google_cloud_default",
         delegate_to: Optional[str] = None,
         drain_pipeline: bool = False,
+        cancel_timeout: Optional[int] = 10 * 60,
         *args,
         **kwargs,
     ) -> None:
@@ -495,15 +511,17 @@ class DataflowStartFlexTemplateOperator(BaseOperator):
         self.project_id = project_id
         self.gcp_conn_id = gcp_conn_id
         self.delegate_to = delegate_to
+        self.drain_pipeline = drain_pipeline
+        self.cancel_timeout = cancel_timeout
         self.job_id = None
         self.hook: Optional[DataflowHook] = None
-        self.drain_pipeline = drain_pipeline
 
     def execute(self, context):
         self.hook = DataflowHook(
             gcp_conn_id=self.gcp_conn_id,
             delegate_to=self.delegate_to,
             drain_pipeline=self.drain_pipeline,
+            cancel_timeout=self.cancel_timeout,
         )
 
         def set_current_job_id(job_id):
@@ -692,6 +710,9 @@ class DataflowCreatePythonJobOperator(BaseOperator):
        instead of canceling during killing of the task instance. See:
         https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline
     :type drain_pipeline: bool
+    :param cancel_timeout: How long (in seconds) operator should wait for the 
pipeline to be
+        successfully cancelled when task is being killed.
+    :type cancel_timeout: Optional[int]
     """
 
     template_fields = ["options", "dataflow_default_options", "job_name", 
"py_file"]
@@ -714,6 +735,7 @@ class DataflowCreatePythonJobOperator(BaseOperator):
         delegate_to: Optional[str] = None,
         poll_sleep: int = 10,
         drain_pipeline: bool = False,
+        cancel_timeout: Optional[int] = 10 * 60,
         **kwargs,
     ) -> None:
 
@@ -736,6 +758,7 @@ class DataflowCreatePythonJobOperator(BaseOperator):
         self.delegate_to = delegate_to
         self.poll_sleep = poll_sleep
         self.drain_pipeline = drain_pipeline
+        self.cancel_timeout = cancel_timeout
         self.job_id = None
         self.hook = None
 
@@ -754,6 +777,7 @@ class DataflowCreatePythonJobOperator(BaseOperator):
                 delegate_to=self.delegate_to,
                 poll_sleep=self.poll_sleep,
                 drain_pipeline=self.drain_pipeline,
+                cancel_timeout=self.cancel_timeout,
             )
             dataflow_options = self.dataflow_default_options.copy()
             dataflow_options.update(self.options)
diff --git a/tests/providers/google/cloud/hooks/test_dataflow.py 
b/tests/providers/google/cloud/hooks/test_dataflow.py
index 6dbe1f3..3a84cf3 100644
--- a/tests/providers/google/cloud/hooks/test_dataflow.py
+++ b/tests/providers/google/cloud/hooks/test_dataflow.py
@@ -91,6 +91,7 @@ TEST_FLEX_PARAMETERS = {
     },
 }
 TEST_PROJECT_ID = 'test-project-id'
+
 TEST_SQL_JOB_NAME = 'test-sql-job-name'
 TEST_DATASET = 'test-dataset'
 TEST_SQL_OPTIONS = {
@@ -109,6 +110,8 @@ GROUP BY sales_region;
 """
 TEST_SQL_JOB_ID = 'test-job-id'
 
+DEFAULT_CANCEL_TIMEOUT = 5 * 60
+
 
 class TestFallbackToVariables(unittest.TestCase):
     def test_support_project_id_parameter(self):
@@ -675,6 +678,7 @@ class TestDataflowTemplateHook(unittest.TestCase):
             project_number=TEST_PROJECT,
             location=DEFAULT_DATAFLOW_LOCATION,
             drain_pipeline=False,
+            cancel_timeout=DEFAULT_CANCEL_TIMEOUT,
         )
         mock_controller.return_value.wait_for_done.assert_called_once()
 
@@ -712,6 +716,7 @@ class TestDataflowTemplateHook(unittest.TestCase):
             project_number=TEST_PROJECT,
             location=TEST_LOCATION,
             drain_pipeline=False,
+            cancel_timeout=DEFAULT_CANCEL_TIMEOUT,
         )
         mock_controller.return_value.wait_for_done.assert_called_once()
 
@@ -751,6 +756,7 @@ class TestDataflowTemplateHook(unittest.TestCase):
             project_number=TEST_PROJECT,
             location=TEST_LOCATION,
             drain_pipeline=False,
+            cancel_timeout=DEFAULT_CANCEL_TIMEOUT,
         )
         mock_controller.return_value.wait_for_done.assert_called_once()
 
@@ -794,6 +800,7 @@ class TestDataflowTemplateHook(unittest.TestCase):
             poll_sleep=10,
             project_number=TEST_PROJECT,
             drain_pipeline=False,
+            cancel_timeout=DEFAULT_CANCEL_TIMEOUT,
         )
         mock_uuid.assert_called_once_with()
 
@@ -841,6 +848,7 @@ class TestDataflowTemplateHook(unittest.TestCase):
             poll_sleep=10,
             project_number=TEST_PROJECT,
             drain_pipeline=False,
+            cancel_timeout=DEFAULT_CANCEL_TIMEOUT,
         )
         mock_uuid.assert_called_once_with()
 
@@ -872,6 +880,7 @@ class TestDataflowTemplateHook(unittest.TestCase):
             location=TEST_LOCATION,
             poll_sleep=self.dataflow_hook.poll_sleep,
             num_retries=self.dataflow_hook.num_retries,
+            cancel_timeout=DEFAULT_CANCEL_TIMEOUT,
         )
         
mock_controller.return_value.get_jobs.wait_for_done.assrt_called_once_with()
         mock_controller.return_value.get_jobs.assrt_called_once_with()
@@ -893,6 +902,7 @@ class TestDataflowTemplateHook(unittest.TestCase):
             project_number=TEST_PROJECT,
             num_retries=5,
             drain_pipeline=False,
+            cancel_timeout=DEFAULT_CANCEL_TIMEOUT,
         )
         jobs_controller.cancel()
 
@@ -1290,6 +1300,51 @@ class TestDataflowJob(unittest.TestCase):
         )
         mock_batch.add.assert_called_once_with(mock_update.return_value)
 
+    @mock.patch("airflow.providers.google.cloud.hooks.dataflow.timeout")
+    @mock.patch("time.sleep")
+    def test_dataflow_job_cancel_job_cancel_timeout(self, mock_sleep, 
mock_timeout):
+        mock_jobs = 
self.mock_dataflow.projects.return_value.locations.return_value.jobs
+        get_method = mock_jobs.return_value.get
+        get_method.return_value.execute.side_effect = [
+            {"id": TEST_JOB_ID, "name": JOB_NAME, "currentState": 
DataflowJobStatus.JOB_STATE_CANCELLING},
+            {"id": TEST_JOB_ID, "name": JOB_NAME, "currentState": 
DataflowJobStatus.JOB_STATE_CANCELLING},
+            {"id": TEST_JOB_ID, "name": JOB_NAME, "currentState": 
DataflowJobStatus.JOB_STATE_CANCELLING},
+            {"id": TEST_JOB_ID, "name": JOB_NAME, "currentState": 
DataflowJobStatus.JOB_STATE_CANCELLING},
+            {"id": TEST_JOB_ID, "name": JOB_NAME, "currentState": 
DataflowJobStatus.JOB_STATE_CANCELLED},
+        ]
+
+        mock_jobs.return_value.list_next.return_value = None
+        dataflow_job = _DataflowJobsController(
+            dataflow=self.mock_dataflow,
+            project_number=TEST_PROJECT,
+            name=UNIQUE_JOB_NAME,
+            location=TEST_LOCATION,
+            poll_sleep=4,
+            job_id=TEST_JOB_ID,
+            num_retries=20,
+            multiple_jobs=False,
+            cancel_timeout=10,
+        )
+        dataflow_job.cancel()
+
+        get_method.assert_called_with(jobId=TEST_JOB_ID, 
location=TEST_LOCATION, projectId=TEST_PROJECT)
+        get_method.return_value.execute.assert_called_with(num_retries=20)
+
+        self.mock_dataflow.new_batch_http_request.assert_called_once_with()
+        mock_batch = self.mock_dataflow.new_batch_http_request.return_value
+        mock_update = mock_jobs.return_value.update
+        mock_update.assert_called_once_with(
+            body={'requestedState': 'JOB_STATE_CANCELLED'},
+            jobId='test-job-id',
+            location=TEST_LOCATION,
+            projectId='test-project',
+        )
+        mock_batch.add.assert_called_once_with(mock_update.return_value)
+        mock_sleep.assert_has_calls([mock.call(4), mock.call(4), mock.call(4)])
+        mock_timeout.assert_called_once_with(
+            seconds=10, error_message='Canceling jobs failed due to timeout 
(10s): test-job-id'
+        )
+
     @parameterized.expand(
         [
             (False, "JOB_TYPE_BATCH", "JOB_STATE_CANCELLED"),
@@ -1324,6 +1379,7 @@ class TestDataflowJob(unittest.TestCase):
             num_retries=20,
             multiple_jobs=False,
             drain_pipeline=drain_pipeline,
+            cancel_timeout=None,
         )
         dataflow_job.cancel()
 
diff --git a/tests/providers/google/cloud/operators/test_mlengine_utils.py 
b/tests/providers/google/cloud/operators/test_mlengine_utils.py
index 6c13315..9a3a142 100644
--- a/tests/providers/google/cloud/operators/test_mlengine_utils.py
+++ b/tests/providers/google/cloud/operators/test_mlengine_utils.py
@@ -109,7 +109,11 @@ class TestCreateEvaluateOps(unittest.TestCase):
             hook_instance.start_python_dataflow.return_value = None
             summary.execute(None)
             mock_dataflow_hook.assert_called_once_with(
-                gcp_conn_id='google_cloud_default', delegate_to=None, 
poll_sleep=10, drain_pipeline=False
+                gcp_conn_id='google_cloud_default',
+                delegate_to=None,
+                poll_sleep=10,
+                drain_pipeline=False,
+                cancel_timeout=600,
             )
             hook_instance.start_python_dataflow.assert_called_once_with(
                 job_name='{{task.task_id}}',

Reply via email to