This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new 0caec9f Dataflow - add waiting for successful job cancel (#11501)
0caec9f is described below
commit 0caec9fd32bee2b3036b5d7bdcb56bd6a3b9dccf
Author: Tobiasz Kędzierski <[email protected]>
AuthorDate: Fri Nov 6 15:14:14 2020 +0100
Dataflow - add waiting for successful job cancel (#11501)
Co-authored-by: Kamil Breguła <[email protected]>
---
airflow/providers/google/cloud/hooks/dataflow.py | 46 ++++++++++++++++--
.../providers/google/cloud/operators/dataflow.py | 26 +++++++++-
.../providers/google/cloud/hooks/test_dataflow.py | 56 ++++++++++++++++++++++
.../google/cloud/operators/test_mlengine_utils.py | 6 ++-
4 files changed, 129 insertions(+), 5 deletions(-)
diff --git a/airflow/providers/google/cloud/hooks/dataflow.py
b/airflow/providers/google/cloud/hooks/dataflow.py
index e17001f..7467207 100644
--- a/airflow/providers/google/cloud/hooks/dataflow.py
+++ b/airflow/providers/google/cloud/hooks/dataflow.py
@@ -28,7 +28,7 @@ import uuid
import warnings
from copy import deepcopy
from tempfile import TemporaryDirectory
-from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar,
Union, cast
+from typing import Any, Callable, Dict, List, Optional, Sequence, Set,
TypeVar, Union, cast
from googleapiclient.discovery import build
@@ -36,6 +36,7 @@ from airflow.exceptions import AirflowException
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.python_virtualenv import prepare_virtualenv
+from airflow.utils.timeout import timeout
# This is the default location
# https://cloud.google.com/dataflow/pipelines/specifying-exec-params
@@ -147,9 +148,10 @@ class _DataflowJobsController(LoggingMixin):
not by specific job ID, then actions will be performed on all matching
jobs.
:param drain_pipeline: Optional, set to True if want to stop streaming job
by draining it
instead of canceling.
+ :param cancel_timeout: wait time in seconds for successful job canceling
"""
- def __init__(
+ def __init__( # pylint: disable=too-many-arguments
self,
dataflow: Any,
project_number: str,
@@ -160,6 +162,7 @@ class _DataflowJobsController(LoggingMixin):
num_retries: int = 0,
multiple_jobs: bool = False,
drain_pipeline: bool = False,
+ cancel_timeout: Optional[int] = 5 * 60,
) -> None:
super().__init__()
@@ -171,8 +174,9 @@ class _DataflowJobsController(LoggingMixin):
self._job_id = job_id
self._num_retries = num_retries
self._poll_sleep = poll_sleep
- self.drain_pipeline = drain_pipeline
+ self._cancel_timeout = cancel_timeout
self._jobs: Optional[List[dict]] = None
+ self.drain_pipeline = drain_pipeline
def is_job_running(self) -> bool:
"""
@@ -317,6 +321,29 @@ class _DataflowJobsController(LoggingMixin):
return self._jobs
+ def _wait_for_states(self, expected_states: Set[str]):
+ """Waiting for the jobs to reach a certain state."""
+ if not self._jobs:
+ raise ValueError("The _jobs should be set")
+ while True:
+ self._refresh_jobs()
+ job_states = {job['currentState'] for job in self._jobs}
+ if not job_states.difference(expected_states):
+ return
+ unexpected_failed_end_states = expected_states -
DataflowJobStatus.FAILED_END_STATES
+ if unexpected_failed_end_states.intersection(job_states):
+ unexpected_failed_jobs = {
+ job for job in self._jobs if job['currentState'] in
unexpected_failed_end_states
+ }
+ raise AirflowException(
+ "Jobs failed: "
+ + ", ".join(
+ f"ID: {job['id']} name: {job['name']} state:
{job['currentState']}"
+ for job in unexpected_failed_jobs
+ )
+ )
+ time.sleep(self._poll_sleep)
+
def cancel(self) -> None:
"""Cancels or drains current job"""
jobs = self.get_jobs()
@@ -342,6 +369,12 @@ class _DataflowJobsController(LoggingMixin):
)
)
batch.execute()
+ if self._cancel_timeout and isinstance(self._cancel_timeout, int):
+ timeout_error_message = "Canceling jobs failed due to timeout
({}s): {}".format(
+ self._cancel_timeout, ", ".join(job_ids)
+ )
+ with timeout(seconds=self._cancel_timeout,
error_message=timeout_error_message):
+
self._wait_for_states({DataflowJobStatus.JOB_STATE_CANCELLED})
else:
self.log.info("No jobs to cancel")
@@ -453,9 +486,11 @@ class DataflowHook(GoogleBaseHook):
poll_sleep: int = 10,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
drain_pipeline: bool = False,
+ cancel_timeout: Optional[int] = 5 * 60,
) -> None:
self.poll_sleep = poll_sleep
self.drain_pipeline = drain_pipeline
+ self.cancel_timeout = cancel_timeout
super().__init__(
gcp_conn_id=gcp_conn_id,
delegate_to=delegate_to,
@@ -496,6 +531,7 @@ class DataflowHook(GoogleBaseHook):
num_retries=self.num_retries,
multiple_jobs=multiple_jobs,
drain_pipeline=self.drain_pipeline,
+ cancel_timeout=self.cancel_timeout,
)
job_controller.wait_for_done()
@@ -669,6 +705,7 @@ class DataflowHook(GoogleBaseHook):
poll_sleep=self.poll_sleep,
num_retries=self.num_retries,
drain_pipeline=self.drain_pipeline,
+ cancel_timeout=self.cancel_timeout,
)
jobs_controller.wait_for_done()
return response["job"]
@@ -714,6 +751,7 @@ class DataflowHook(GoogleBaseHook):
location=location,
poll_sleep=self.poll_sleep,
num_retries=self.num_retries,
+ cancel_timeout=self.cancel_timeout,
)
jobs_controller.wait_for_done()
@@ -898,6 +936,7 @@ class DataflowHook(GoogleBaseHook):
poll_sleep=self.poll_sleep,
drain_pipeline=self.drain_pipeline,
num_retries=self.num_retries,
+ cancel_timeout=self.cancel_timeout,
)
return jobs_controller.is_job_running()
@@ -933,6 +972,7 @@ class DataflowHook(GoogleBaseHook):
poll_sleep=self.poll_sleep,
drain_pipeline=self.drain_pipeline,
num_retries=self.num_retries,
+ cancel_timeout=self.cancel_timeout,
)
jobs_controller.cancel()
diff --git a/airflow/providers/google/cloud/operators/dataflow.py
b/airflow/providers/google/cloud/operators/dataflow.py
index d2844a3..faca570 100644
--- a/airflow/providers/google/cloud/operators/dataflow.py
+++ b/airflow/providers/google/cloud/operators/dataflow.py
@@ -132,6 +132,9 @@ class DataflowCreateJavaJobOperator(BaseOperator):
:type check_if_running: CheckJobRunning(IgnoreJob = do not check if
running, FinishIfRunning=
if job is running finish with nothing, WaitForRun= wait until job
finished and the run job)
``jar``, ``options``, and ``job_name`` are templated so you can use
variables in them.
+ :param cancel_timeout: How long (in seconds) operator should wait for the
pipeline to be
+ successfully cancelled when task is being killed.
+ :type cancel_timeout: Optional[int]
Note that both
``dataflow_default_options`` and ``options`` will be merged to specify
pipeline
@@ -193,6 +196,7 @@ class DataflowCreateJavaJobOperator(BaseOperator):
job_class: Optional[str] = None,
check_if_running: CheckJobRunning = CheckJobRunning.WaitForRun,
multiple_jobs: Optional[bool] = None,
+ cancel_timeout: Optional[int] = 10 * 60,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -214,6 +218,7 @@ class DataflowCreateJavaJobOperator(BaseOperator):
self.poll_sleep = poll_sleep
self.job_class = job_class
self.check_if_running = check_if_running
+ self.cancel_timeout = cancel_timeout
self.job_id = None
self.hook = None
@@ -222,6 +227,7 @@ class DataflowCreateJavaJobOperator(BaseOperator):
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
poll_sleep=self.poll_sleep,
+ cancel_timeout=self.cancel_timeout,
)
dataflow_options = copy.copy(self.dataflow_default_options)
dataflow_options.update(self.options)
@@ -324,6 +330,9 @@ class DataflowTemplatedJobStartOperator(BaseOperator):
`https://cloud.google.com/dataflow/pipelines/specifying-exec-params
<https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment>`__
:type environment: Optional[dict]
+ :param cancel_timeout: How long (in seconds) operator should wait for the
pipeline to be
+ successfully cancelled when task is being killed.
+ :type cancel_timeout: Optional[int]
It's a good practice to define dataflow_* parameters in the default_args
of the dag
like the project, zone and staging location.
@@ -401,6 +410,7 @@ class DataflowTemplatedJobStartOperator(BaseOperator):
poll_sleep: int = 10,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
environment: Optional[Dict] = None,
+ cancel_timeout: Optional[int] = 10 * 60,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -418,6 +428,7 @@ class DataflowTemplatedJobStartOperator(BaseOperator):
self.hook: Optional[DataflowHook] = None
self.impersonation_chain = impersonation_chain
self.environment = environment
+ self.cancel_timeout = cancel_timeout
def execute(self, context) -> dict:
self.hook = DataflowHook(
@@ -425,6 +436,7 @@ class DataflowTemplatedJobStartOperator(BaseOperator):
delegate_to=self.delegate_to,
poll_sleep=self.poll_sleep,
impersonation_chain=self.impersonation_chain,
+ cancel_timeout=self.cancel_timeout,
)
def set_current_job_id(job_id):
@@ -473,6 +485,9 @@ class DataflowStartFlexTemplateOperator(BaseOperator):
instead of canceling during killing task instance. See:
https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline
:type drain_pipeline: bool
+ :param cancel_timeout: How long (in seconds) operator should wait for the
pipeline to be
+ successfully cancelled when task is being killed.
+ :type cancel_timeout: Optional[int]
"""
template_fields = ["body", "location", "project_id", "gcp_conn_id"]
@@ -486,6 +501,7 @@ class DataflowStartFlexTemplateOperator(BaseOperator):
gcp_conn_id: str = "google_cloud_default",
delegate_to: Optional[str] = None,
drain_pipeline: bool = False,
+ cancel_timeout: Optional[int] = 10 * 60,
*args,
**kwargs,
) -> None:
@@ -495,15 +511,17 @@ class DataflowStartFlexTemplateOperator(BaseOperator):
self.project_id = project_id
self.gcp_conn_id = gcp_conn_id
self.delegate_to = delegate_to
+ self.drain_pipeline = drain_pipeline
+ self.cancel_timeout = cancel_timeout
self.job_id = None
self.hook: Optional[DataflowHook] = None
- self.drain_pipeline = drain_pipeline
def execute(self, context):
self.hook = DataflowHook(
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
drain_pipeline=self.drain_pipeline,
+ cancel_timeout=self.cancel_timeout,
)
def set_current_job_id(job_id):
@@ -692,6 +710,9 @@ class DataflowCreatePythonJobOperator(BaseOperator):
instead of canceling during killing task instance. See:
https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline
:type drain_pipeline: bool
+ :param cancel_timeout: How long (in seconds) operator should wait for the
pipeline to be
+ successfully cancelled when task is being killed.
+ :type cancel_timeout: Optional[int]
"""
template_fields = ["options", "dataflow_default_options", "job_name",
"py_file"]
@@ -714,6 +735,7 @@ class DataflowCreatePythonJobOperator(BaseOperator):
delegate_to: Optional[str] = None,
poll_sleep: int = 10,
drain_pipeline: bool = False,
+ cancel_timeout: Optional[int] = 10 * 60,
**kwargs,
) -> None:
@@ -736,6 +758,7 @@ class DataflowCreatePythonJobOperator(BaseOperator):
self.delegate_to = delegate_to
self.poll_sleep = poll_sleep
self.drain_pipeline = drain_pipeline
+ self.cancel_timeout = cancel_timeout
self.job_id = None
self.hook = None
@@ -754,6 +777,7 @@ class DataflowCreatePythonJobOperator(BaseOperator):
delegate_to=self.delegate_to,
poll_sleep=self.poll_sleep,
drain_pipeline=self.drain_pipeline,
+ cancel_timeout=self.cancel_timeout,
)
dataflow_options = self.dataflow_default_options.copy()
dataflow_options.update(self.options)
diff --git a/tests/providers/google/cloud/hooks/test_dataflow.py
b/tests/providers/google/cloud/hooks/test_dataflow.py
index 6dbe1f3..3a84cf3 100644
--- a/tests/providers/google/cloud/hooks/test_dataflow.py
+++ b/tests/providers/google/cloud/hooks/test_dataflow.py
@@ -91,6 +91,7 @@ TEST_FLEX_PARAMETERS = {
},
}
TEST_PROJECT_ID = 'test-project-id'
+
TEST_SQL_JOB_NAME = 'test-sql-job-name'
TEST_DATASET = 'test-dataset'
TEST_SQL_OPTIONS = {
@@ -109,6 +110,8 @@ GROUP BY sales_region;
"""
TEST_SQL_JOB_ID = 'test-job-id'
+DEFAULT_CANCEL_TIMEOUT = 5 * 60
+
class TestFallbackToVariables(unittest.TestCase):
def test_support_project_id_parameter(self):
@@ -675,6 +678,7 @@ class TestDataflowTemplateHook(unittest.TestCase):
project_number=TEST_PROJECT,
location=DEFAULT_DATAFLOW_LOCATION,
drain_pipeline=False,
+ cancel_timeout=DEFAULT_CANCEL_TIMEOUT,
)
mock_controller.return_value.wait_for_done.assert_called_once()
@@ -712,6 +716,7 @@ class TestDataflowTemplateHook(unittest.TestCase):
project_number=TEST_PROJECT,
location=TEST_LOCATION,
drain_pipeline=False,
+ cancel_timeout=DEFAULT_CANCEL_TIMEOUT,
)
mock_controller.return_value.wait_for_done.assert_called_once()
@@ -751,6 +756,7 @@ class TestDataflowTemplateHook(unittest.TestCase):
project_number=TEST_PROJECT,
location=TEST_LOCATION,
drain_pipeline=False,
+ cancel_timeout=DEFAULT_CANCEL_TIMEOUT,
)
mock_controller.return_value.wait_for_done.assert_called_once()
@@ -794,6 +800,7 @@ class TestDataflowTemplateHook(unittest.TestCase):
poll_sleep=10,
project_number=TEST_PROJECT,
drain_pipeline=False,
+ cancel_timeout=DEFAULT_CANCEL_TIMEOUT,
)
mock_uuid.assert_called_once_with()
@@ -841,6 +848,7 @@ class TestDataflowTemplateHook(unittest.TestCase):
poll_sleep=10,
project_number=TEST_PROJECT,
drain_pipeline=False,
+ cancel_timeout=DEFAULT_CANCEL_TIMEOUT,
)
mock_uuid.assert_called_once_with()
@@ -872,6 +880,7 @@ class TestDataflowTemplateHook(unittest.TestCase):
location=TEST_LOCATION,
poll_sleep=self.dataflow_hook.poll_sleep,
num_retries=self.dataflow_hook.num_retries,
+ cancel_timeout=DEFAULT_CANCEL_TIMEOUT,
)
mock_controller.return_value.get_jobs.wait_for_done.assert_called_once_with()
mock_controller.return_value.get_jobs.assert_called_once_with()
@@ -893,6 +902,7 @@ class TestDataflowTemplateHook(unittest.TestCase):
project_number=TEST_PROJECT,
num_retries=5,
drain_pipeline=False,
+ cancel_timeout=DEFAULT_CANCEL_TIMEOUT,
)
jobs_controller.cancel()
@@ -1290,6 +1300,51 @@ class TestDataflowJob(unittest.TestCase):
)
mock_batch.add.assert_called_once_with(mock_update.return_value)
+ @mock.patch("airflow.providers.google.cloud.hooks.dataflow.timeout")
+ @mock.patch("time.sleep")
+ def test_dataflow_job_cancel_job_cancel_timeout(self, mock_sleep,
mock_timeout):
+ mock_jobs =
self.mock_dataflow.projects.return_value.locations.return_value.jobs
+ get_method = mock_jobs.return_value.get
+ get_method.return_value.execute.side_effect = [
+ {"id": TEST_JOB_ID, "name": JOB_NAME, "currentState":
DataflowJobStatus.JOB_STATE_CANCELLING},
+ {"id": TEST_JOB_ID, "name": JOB_NAME, "currentState":
DataflowJobStatus.JOB_STATE_CANCELLING},
+ {"id": TEST_JOB_ID, "name": JOB_NAME, "currentState":
DataflowJobStatus.JOB_STATE_CANCELLING},
+ {"id": TEST_JOB_ID, "name": JOB_NAME, "currentState":
DataflowJobStatus.JOB_STATE_CANCELLING},
+ {"id": TEST_JOB_ID, "name": JOB_NAME, "currentState":
DataflowJobStatus.JOB_STATE_CANCELLED},
+ ]
+
+ mock_jobs.return_value.list_next.return_value = None
+ dataflow_job = _DataflowJobsController(
+ dataflow=self.mock_dataflow,
+ project_number=TEST_PROJECT,
+ name=UNIQUE_JOB_NAME,
+ location=TEST_LOCATION,
+ poll_sleep=4,
+ job_id=TEST_JOB_ID,
+ num_retries=20,
+ multiple_jobs=False,
+ cancel_timeout=10,
+ )
+ dataflow_job.cancel()
+
+ get_method.assert_called_with(jobId=TEST_JOB_ID,
location=TEST_LOCATION, projectId=TEST_PROJECT)
+ get_method.return_value.execute.assert_called_with(num_retries=20)
+
+ self.mock_dataflow.new_batch_http_request.assert_called_once_with()
+ mock_batch = self.mock_dataflow.new_batch_http_request.return_value
+ mock_update = mock_jobs.return_value.update
+ mock_update.assert_called_once_with(
+ body={'requestedState': 'JOB_STATE_CANCELLED'},
+ jobId='test-job-id',
+ location=TEST_LOCATION,
+ projectId='test-project',
+ )
+ mock_batch.add.assert_called_once_with(mock_update.return_value)
+ mock_sleep.assert_has_calls([mock.call(4), mock.call(4), mock.call(4)])
+ mock_timeout.assert_called_once_with(
+ seconds=10, error_message='Canceling jobs failed due to timeout
(10s): test-job-id'
+ )
+
@parameterized.expand(
[
(False, "JOB_TYPE_BATCH", "JOB_STATE_CANCELLED"),
@@ -1324,6 +1379,7 @@ class TestDataflowJob(unittest.TestCase):
num_retries=20,
multiple_jobs=False,
drain_pipeline=drain_pipeline,
+ cancel_timeout=None,
)
dataflow_job.cancel()
diff --git a/tests/providers/google/cloud/operators/test_mlengine_utils.py
b/tests/providers/google/cloud/operators/test_mlengine_utils.py
index 6c13315..9a3a142 100644
--- a/tests/providers/google/cloud/operators/test_mlengine_utils.py
+++ b/tests/providers/google/cloud/operators/test_mlengine_utils.py
@@ -109,7 +109,11 @@ class TestCreateEvaluateOps(unittest.TestCase):
hook_instance.start_python_dataflow.return_value = None
summary.execute(None)
mock_dataflow_hook.assert_called_once_with(
- gcp_conn_id='google_cloud_default', delegate_to=None,
poll_sleep=10, drain_pipeline=False
+ gcp_conn_id='google_cloud_default',
+ delegate_to=None,
+ poll_sleep=10,
+ drain_pipeline=False,
+ cancel_timeout=600,
)
hook_instance.start_python_dataflow.assert_called_once_with(
job_name='{{task.task_id}}',