This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new ece685b Asynchronous execution of Dataproc jobs with a Sensor (#10673)
ece685b is described below
commit ece685b5b895ad1175440b49bf9e620dffd8248d
Author: Varun Dhussa <[email protected]>
AuthorDate: Sat Sep 5 17:41:37 2020 +0530
Asynchronous execution of Dataproc jobs with a Sensor (#10673)
---
.../google/cloud/example_dags/example_dataproc.py | 16 +++
.../providers/google/cloud/operators/dataproc.py | 32 +++++-
airflow/providers/google/cloud/sensors/dataproc.py | 81 +++++++++++++
docs/operators-and-hooks-ref.rst | 2 +-
.../google/cloud/operators/test_dataproc.py | 36 ++++++
.../google/cloud/sensors/test_dataproc.py | 128 +++++++++++++++++++++
6 files changed, 289 insertions(+), 6 deletions(-)
diff --git a/airflow/providers/google/cloud/example_dags/example_dataproc.py
b/airflow/providers/google/cloud/example_dags/example_dataproc.py
index fd463dc..494844c 100644
--- a/airflow/providers/google/cloud/example_dags/example_dataproc.py
+++ b/airflow/providers/google/cloud/example_dags/example_dataproc.py
@@ -29,6 +29,7 @@ from airflow.providers.google.cloud.operators.dataproc import
(
DataprocSubmitJobOperator,
DataprocUpdateClusterOperator,
)
+from airflow.providers.google.cloud.sensors.dataproc import DataprocJobSensor
from airflow.utils.dates import days_ago
PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "an-id")
@@ -170,6 +171,20 @@ with models.DAG("example_gcp_dataproc",
start_date=days_ago(1), schedule_interva
task_id="spark_task", job=SPARK_JOB, location=REGION,
project_id=PROJECT_ID
)
+ # [START cloud_dataproc_async_submit_sensor]
+ spark_task_async = DataprocSubmitJobOperator(
+ task_id="spark_task_async", job=SPARK_JOB, location=REGION,
project_id=PROJECT_ID, asynchronous=True
+ )
+
+ spark_task_async_sensor = DataprocJobSensor(
+ task_id='spark_task_async_sensor_task',
+ location=REGION,
+ project_id=PROJECT_ID,
+
dataproc_job_id="{{task_instance.xcom_pull(task_ids='spark_task_async')}}",
+ poke_interval=10,
+ )
+ # [END cloud_dataproc_async_submit_sensor]
+
# [START how_to_cloud_dataproc_submit_job_to_cluster_operator]
pyspark_task = DataprocSubmitJobOperator(
task_id="pyspark_task", job=PYSPARK_JOB, location=REGION,
project_id=PROJECT_ID
@@ -199,6 +214,7 @@ with models.DAG("example_gcp_dataproc",
start_date=days_ago(1), schedule_interva
scale_cluster >> pig_task >> delete_cluster
scale_cluster >> spark_sql_task >> delete_cluster
scale_cluster >> spark_task >> delete_cluster
+ scale_cluster >> spark_task_async >> spark_task_async_sensor >>
delete_cluster
scale_cluster >> pyspark_task >> delete_cluster
scale_cluster >> sparkr_task >> delete_cluster
scale_cluster >> hadoop_task >> delete_cluster
diff --git a/airflow/providers/google/cloud/operators/dataproc.py
b/airflow/providers/google/cloud/operators/dataproc.py
index 403fd36..1438103 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -906,6 +906,10 @@ class DataprocJobBaseOperator(BaseOperator):
Service Account Token Creator IAM role to the directly preceding
identity, with first
account from the list granting this role to the originating account
(templated).
:type impersonation_chain: Union[str, Sequence[str]]
+ :param asynchronous: Flag to return after submitting the job to the
Dataproc API.
+ This is useful for submitting long-running jobs and
+ waiting on them asynchronously using the DataprocJobSensor.
+ :type asynchronous: bool
:var dataproc_job_id: The actual "jobId" as submitted to the Dataproc API.
This is useful for identifying or linking to the job in the Google
Cloud Console
@@ -930,6 +934,7 @@ class DataprocJobBaseOperator(BaseOperator):
region: str = 'global',
job_error_states: Optional[Set[str]] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+ asynchronous: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -949,6 +954,7 @@ class DataprocJobBaseOperator(BaseOperator):
self.job_template = None
self.job = None
self.dataproc_job_id = None
+ self.asynchronous = asynchronous
def create_job_template(self):
"""
@@ -980,8 +986,13 @@ class DataprocJobBaseOperator(BaseOperator):
project_id=self.project_id, job=self.job["job"],
location=self.region,
)
job_id = job_object.reference.job_id
- self.hook.wait_for_job(job_id=job_id, location=self.region,
project_id=self.project_id)
- self.log.info('Job executed correctly.')
+ self.log.info('Job %s submitted successfully.', job_id)
+
+ if not self.asynchronous:
+ self.log.info('Waiting for job %s to complete', job_id)
+ self.hook.wait_for_job(job_id=job_id, location=self.region,
project_id=self.project_id)
+ self.log.info('Job %s completed successfully.', job_id)
+ return job_id
else:
raise AirflowException("Create a job template before")
@@ -1785,6 +1796,10 @@ class DataprocSubmitJobOperator(BaseOperator):
Service Account Token Creator IAM role to the directly preceding
identity, with first
account from the list granting this role to the originating account
(templated).
:type impersonation_chain: Union[str, Sequence[str]]
+ :param asynchronous: Flag to return after submitting the job to the
Dataproc API.
+ This is useful for submitting long-running jobs and
+ waiting on them asynchronously using the DataprocJobSensor.
+ :type asynchronous: bool
"""
template_fields = (
@@ -1807,6 +1822,7 @@ class DataprocSubmitJobOperator(BaseOperator):
metadata: Optional[Sequence[Tuple[str, str]]] = None,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+ asynchronous: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -1819,6 +1835,7 @@ class DataprocSubmitJobOperator(BaseOperator):
self.metadata = metadata
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
+ self.asynchronous = asynchronous
def execute(self, context: Dict):
self.log.info("Submitting job")
@@ -1833,9 +1850,14 @@ class DataprocSubmitJobOperator(BaseOperator):
metadata=self.metadata,
)
job_id = job_object.reference.job_id
- self.log.info("Waiting for job %s to complete", job_id)
- hook.wait_for_job(job_id=job_id, project_id=self.project_id,
location=self.location)
- self.log.info("Job completed successfully.")
+ self.log.info('Job %s submitted successfully.', job_id)
+
+ if not self.asynchronous:
+ self.log.info('Waiting for job %s to complete', job_id)
+ hook.wait_for_job(job_id=job_id, location=self.location,
project_id=self.project_id)
+ self.log.info('Job %s completed successfully.', job_id)
+
+ return job_id
class DataprocUpdateClusterOperator(BaseOperator):
diff --git a/airflow/providers/google/cloud/sensors/dataproc.py
b/airflow/providers/google/cloud/sensors/dataproc.py
new file mode 100644
index 0000000..f84d63b
--- /dev/null
+++ b/airflow/providers/google/cloud/sensors/dataproc.py
@@ -0,0 +1,81 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+This module contains a Dataproc Job sensor.
+"""
+# pylint: disable=C0302
+
+from google.cloud.dataproc_v1beta2.types import JobStatus
+
+from airflow.providers.google.cloud.hooks.dataproc import DataprocHook
+from airflow.sensors.base_sensor_operator import BaseSensorOperator
+from airflow.utils.decorators import apply_defaults
+from airflow.exceptions import AirflowException
+
+
+class DataprocJobSensor(BaseSensorOperator):
+ """
+ Check for the state of a previously submitted Dataproc job.
+
+ :param project_id: The ID of the Google Cloud project in which
+ the Dataproc job is located. (templated)
+ :type project_id: str
+ :param dataproc_job_id: The Dataproc job ID to poll. (templated)
+ :type dataproc_job_id: str
+ :param location: Required. The Cloud Dataproc region in which to handle
the request. (templated)
+ :type location: str
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud
Platform.
+ :type gcp_conn_id: str
+ """
+
+ template_fields = ('project_id', 'location', 'dataproc_job_id')
+ ui_color = '#f0eee4'
+
+ @apply_defaults
+ def __init__(
+ self,
+ *,
+ project_id: str,
+ dataproc_job_id: str,
+ location: str,
+ gcp_conn_id: str = 'google_cloud_default',
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.project_id = project_id
+ self.gcp_conn_id = gcp_conn_id
+ self.dataproc_job_id = dataproc_job_id
+ self.location = location
+
+ def poke(self, context):
+ hook = DataprocHook(gcp_conn_id=self.gcp_conn_id)
+ job = hook.get_job(job_id=self.dataproc_job_id,
location=self.location, project_id=self.project_id)
+ state = job.status.state
+
+ if state == JobStatus.ERROR:
+ raise AirflowException('Job failed:\n{}'.format(job))
+ elif state in {JobStatus.CANCELLED, JobStatus.CANCEL_PENDING,
JobStatus.CANCEL_STARTED}:
+ raise AirflowException('Job was cancelled:\n{}'.format(job))
+ elif JobStatus.DONE == state:
+ self.log.debug("Job %s completed successfully.",
self.dataproc_job_id)
+ return True
+ elif JobStatus.ATTEMPT_FAILURE == state:
+ self.log.debug("Job %s attempt has failed.", self.dataproc_job_id)
+
+ self.log.info("Waiting for job %s to complete.", self.dataproc_job_id)
+ return False
diff --git a/docs/operators-and-hooks-ref.rst b/docs/operators-and-hooks-ref.rst
index 1cc71e5..4f7c9e7 100644
--- a/docs/operators-and-hooks-ref.rst
+++ b/docs/operators-and-hooks-ref.rst
@@ -764,7 +764,7 @@ These integrations allow you to perform various operations
within the Google Clo
- :doc:`How to use <howto/operator/google/cloud/dataproc>`
- :mod:`airflow.providers.google.cloud.hooks.dataproc`
- :mod:`airflow.providers.google.cloud.operators.dataproc`
- -
+ - :mod:`airflow.providers.google.cloud.sensors.dataproc`
* - `Datastore <https://cloud.google.com/datastore/>`__
- :doc:`How to use <howto/operator/google/cloud/datastore>`
diff --git a/tests/providers/google/cloud/operators/test_dataproc.py
b/tests/providers/google/cloud/operators/test_dataproc.py
index ac705c1..69f8e9d 100644
--- a/tests/providers/google/cloud/operators/test_dataproc.py
+++ b/tests/providers/google/cloud/operators/test_dataproc.py
@@ -452,6 +452,42 @@ class TestDataprocSubmitJobOperator(unittest.TestCase):
job_id=job_id, project_id=GCP_PROJECT, location=GCP_LOCATION
)
+ @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+ def test_execute_async(self, mock_hook):
+ job = {}
+ job_id = "job_id"
+ mock_hook.return_value.wait_for_job.return_value = None
+ mock_hook.return_value.submit_job.return_value.reference.job_id =
job_id
+
+ op = DataprocSubmitJobOperator(
+ task_id=TASK_ID,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ job=job,
+ gcp_conn_id=GCP_CONN_ID,
+ retry=RETRY,
+ asynchronous=True,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ request_id=REQUEST_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+ op.execute(context={})
+
+ mock_hook.assert_called_once_with(
+ gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN,
+ )
+ mock_hook.return_value.submit_job.assert_called_once_with(
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ job=job,
+ request_id=REQUEST_ID,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ )
+ mock_hook.return_value.wait_for_job.assert_not_called()
+
class TestDataprocUpdateClusterOperator(unittest.TestCase):
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
diff --git a/tests/providers/google/cloud/sensors/test_dataproc.py
b/tests/providers/google/cloud/sensors/test_dataproc.py
new file mode 100644
index 0000000..f9b4003
--- /dev/null
+++ b/tests/providers/google/cloud/sensors/test_dataproc.py
@@ -0,0 +1,128 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+from unittest import mock
+
+from google.cloud.dataproc_v1beta2.types import JobStatus
+from airflow import AirflowException
+from airflow.providers.google.cloud.sensors.dataproc import DataprocJobSensor
+
+from airflow.version import version as airflow_version
+
+AIRFLOW_VERSION = "v" + airflow_version.replace(".", "-").replace("+", "-")
+
+DATAPROC_PATH = "airflow.providers.google.cloud.sensors.dataproc.{}"
+
+TASK_ID = "task-id"
+GCP_PROJECT = "test-project"
+GCP_LOCATION = "test-location"
+GCP_CONN_ID = "test-conn"
+TIMEOUT = 120
+
+
+class TestDataprocJobSensor(unittest.TestCase):
+ def create_job(self, state: int):
+ job = mock.Mock()
+ job.status = mock.Mock()
+ job.status.state = state
+ return job
+
+ @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+ def test_done(self, mock_hook):
+ job = self.create_job(JobStatus.DONE)
+ job_id = "job_id"
+ mock_hook.return_value.get_job.return_value = job
+
+ sensor = DataprocJobSensor(
+ task_id=TASK_ID,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ dataproc_job_id=job_id,
+ gcp_conn_id=GCP_CONN_ID,
+ timeout=TIMEOUT,
+ )
+ ret = sensor.poke(context={})
+
+ mock_hook.return_value.get_job.assert_called_once_with(
+ job_id=job_id, location=GCP_LOCATION, project_id=GCP_PROJECT
+ )
+ self.assertTrue(ret)
+
+ @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+ def test_error(self, mock_hook):
+ job = self.create_job(JobStatus.ERROR)
+ job_id = "job_id"
+ mock_hook.return_value.get_job.return_value = job
+
+ sensor = DataprocJobSensor(
+ task_id=TASK_ID,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ dataproc_job_id=job_id,
+ gcp_conn_id=GCP_CONN_ID,
+ timeout=TIMEOUT,
+ )
+
+ with self.assertRaisesRegex(AirflowException, "Job failed"):
+ sensor.poke(context={})
+
+ mock_hook.return_value.get_job.assert_called_once_with(
+ job_id=job_id, location=GCP_LOCATION, project_id=GCP_PROJECT
+ )
+
+ @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+ def test_wait(self, mock_hook):
+ job = self.create_job(JobStatus.RUNNING)
+ job_id = "job_id"
+ mock_hook.return_value.get_job.return_value = job
+
+ sensor = DataprocJobSensor(
+ task_id=TASK_ID,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ dataproc_job_id=job_id,
+ gcp_conn_id=GCP_CONN_ID,
+ timeout=TIMEOUT,
+ )
+ ret = sensor.poke(context={})
+
+ mock_hook.return_value.get_job.assert_called_once_with(
+ job_id=job_id, location=GCP_LOCATION, project_id=GCP_PROJECT
+ )
+ self.assertFalse(ret)
+
+ @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+ def test_cancelled(self, mock_hook):
+ job = self.create_job(JobStatus.CANCELLED)
+ job_id = "job_id"
+ mock_hook.return_value.get_job.return_value = job
+
+ sensor = DataprocJobSensor(
+ task_id=TASK_ID,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ dataproc_job_id=job_id,
+ gcp_conn_id=GCP_CONN_ID,
+ timeout=TIMEOUT,
+ )
+ with self.assertRaisesRegex(AirflowException, "Job was cancelled"):
+ sensor.poke(context={})
+
+ mock_hook.return_value.get_job.assert_called_once_with(
+ job_id=job_id, location=GCP_LOCATION, project_id=GCP_PROJECT
+ )