This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new ece685b  Asynchronous execution of Dataproc jobs with a Sensor (#10673)
ece685b is described below

commit ece685b5b895ad1175440b49bf9e620dffd8248d
Author: Varun Dhussa <[email protected]>
AuthorDate: Sat Sep 5 17:41:37 2020 +0530

    Asynchronous execution of Dataproc jobs with a Sensor (#10673)
---
 .../google/cloud/example_dags/example_dataproc.py  |  16 +++
 .../providers/google/cloud/operators/dataproc.py   |  32 +++++-
 airflow/providers/google/cloud/sensors/dataproc.py |  81 +++++++++++++
 docs/operators-and-hooks-ref.rst                   |   2 +-
 .../google/cloud/operators/test_dataproc.py        |  36 ++++++
 .../google/cloud/sensors/test_dataproc.py          | 128 +++++++++++++++++++++
 6 files changed, 289 insertions(+), 6 deletions(-)

diff --git a/airflow/providers/google/cloud/example_dags/example_dataproc.py b/airflow/providers/google/cloud/example_dags/example_dataproc.py
index fd463dc..494844c 100644
--- a/airflow/providers/google/cloud/example_dags/example_dataproc.py
+++ b/airflow/providers/google/cloud/example_dags/example_dataproc.py
@@ -29,6 +29,7 @@ from airflow.providers.google.cloud.operators.dataproc import (
     DataprocSubmitJobOperator,
     DataprocUpdateClusterOperator,
 )
+from airflow.providers.google.cloud.sensors.dataproc import DataprocJobSensor
 from airflow.utils.dates import days_ago
 
 PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "an-id")
@@ -170,6 +171,20 @@ with models.DAG("example_gcp_dataproc", start_date=days_ago(1), schedule_interva
         task_id="spark_task", job=SPARK_JOB, location=REGION, project_id=PROJECT_ID
     )
 
+    # [START cloud_dataproc_async_submit_sensor]
+    spark_task_async = DataprocSubmitJobOperator(
+        task_id="spark_task_async", job=SPARK_JOB, location=REGION, 
project_id=PROJECT_ID, asynchronous=True
+    )
+
+    spark_task_async_sensor = DataprocJobSensor(
+        task_id='spark_task_async_sensor_task',
+        location=REGION,
+        project_id=PROJECT_ID,
+        dataproc_job_id="{{task_instance.xcom_pull(task_ids='spark_task_async')}}",
+        poke_interval=10,
+    )
+    # [END cloud_dataproc_async_submit_sensor]
+
     # [START how_to_cloud_dataproc_submit_job_to_cluster_operator]
     pyspark_task = DataprocSubmitJobOperator(
         task_id="pyspark_task", job=PYSPARK_JOB, location=REGION, 
project_id=PROJECT_ID
@@ -199,6 +214,7 @@ with models.DAG("example_gcp_dataproc", 
start_date=days_ago(1), schedule_interva
     scale_cluster >> pig_task >> delete_cluster
     scale_cluster >> spark_sql_task >> delete_cluster
     scale_cluster >> spark_task >> delete_cluster
+    scale_cluster >> spark_task_async >> spark_task_async_sensor >> delete_cluster
     scale_cluster >> pyspark_task >> delete_cluster
     scale_cluster >> sparkr_task >> delete_cluster
     scale_cluster >> hadoop_task >> delete_cluster
diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index 403fd36..1438103 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -906,6 +906,10 @@ class DataprocJobBaseOperator(BaseOperator):
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
     :type impersonation_chain: Union[str, Sequence[str]]
+    :param asynchronous: Flag to return after submitting the job to the Dataproc API.
+        This is useful for submitting long-running jobs and
+        waiting on them asynchronously using the DataprocJobSensor.
+    :type asynchronous: bool
 
     :var dataproc_job_id: The actual "jobId" as submitted to the Dataproc API.
         This is useful for identifying or linking to the job in the Google Cloud Console
@@ -930,6 +934,7 @@ class DataprocJobBaseOperator(BaseOperator):
         region: str = 'global',
         job_error_states: Optional[Set[str]] = None,
         impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        asynchronous: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -949,6 +954,7 @@ class DataprocJobBaseOperator(BaseOperator):
         self.job_template = None
         self.job = None
         self.dataproc_job_id = None
+        self.asynchronous = asynchronous
 
     def create_job_template(self):
         """
@@ -980,8 +986,13 @@ class DataprocJobBaseOperator(BaseOperator):
                 project_id=self.project_id, job=self.job["job"], location=self.region,
             )
             job_id = job_object.reference.job_id
-            self.hook.wait_for_job(job_id=job_id, location=self.region, project_id=self.project_id)
-            self.log.info('Job executed correctly.')
+            self.log.info('Job %s submitted successfully.', job_id)
+
+            if not self.asynchronous:
+                self.log.info('Waiting for job %s to complete', job_id)
+                self.hook.wait_for_job(job_id=job_id, location=self.region, project_id=self.project_id)
+                self.log.info('Job %s completed successfully.', job_id)
+            return job_id
         else:
             raise AirflowException("Create a job template before")
 
@@ -1785,6 +1796,10 @@ class DataprocSubmitJobOperator(BaseOperator):
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
     :type impersonation_chain: Union[str, Sequence[str]]
+    :param asynchronous: Flag to return after submitting the job to the Dataproc API.
+        This is useful for submitting long-running jobs and
+        waiting on them asynchronously using the DataprocJobSensor.
+    :type asynchronous: bool
     """
 
     template_fields = (
@@ -1807,6 +1822,7 @@ class DataprocSubmitJobOperator(BaseOperator):
         metadata: Optional[Sequence[Tuple[str, str]]] = None,
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        asynchronous: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -1819,6 +1835,7 @@ class DataprocSubmitJobOperator(BaseOperator):
         self.metadata = metadata
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
+        self.asynchronous = asynchronous
 
     def execute(self, context: Dict):
         self.log.info("Submitting job")
@@ -1833,9 +1850,14 @@ class DataprocSubmitJobOperator(BaseOperator):
             metadata=self.metadata,
         )
         job_id = job_object.reference.job_id
-        self.log.info("Waiting for job %s to complete", job_id)
-        hook.wait_for_job(job_id=job_id, project_id=self.project_id, location=self.location)
-        self.log.info("Job completed successfully.")
+        self.log.info('Job %s submitted successfully.', job_id)
+
+        if not self.asynchronous:
+            self.log.info('Waiting for job %s to complete', job_id)
+            hook.wait_for_job(job_id=job_id, location=self.location, project_id=self.project_id)
+            self.log.info('Job %s completed successfully.', job_id)
+
+        return job_id
 
 
 class DataprocUpdateClusterOperator(BaseOperator):
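
A note on the operator change above: execute() now returns the Dataproc job_id, and since Airflow pushes an operator's return value to XCom by default, a downstream task can recover the id without extra plumbing; this is what the example DAG's xcom_pull template relies on. A minimal sketch of an explicit pull, assuming the task_id 'spark_task_async' from the example DAG and placed inside its DAG block (the callable and task names here are illustrative only):

    from airflow.operators.python import PythonOperator

    def report_job_id(**context):
        # xcom_pull returns whatever execute() returned, i.e. the Dataproc job_id
        job_id = context["ti"].xcom_pull(task_ids="spark_task_async")
        print("Submitted Dataproc job: %s" % job_id)

    report = PythonOperator(task_id="report_job_id", python_callable=report_job_id)
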
diff --git a/airflow/providers/google/cloud/sensors/dataproc.py b/airflow/providers/google/cloud/sensors/dataproc.py
new file mode 100644
index 0000000..f84d63b
--- /dev/null
+++ b/airflow/providers/google/cloud/sensors/dataproc.py
@@ -0,0 +1,81 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+This module contains a Dataproc Job sensor.
+"""
+# pylint: disable=C0302
+
+from google.cloud.dataproc_v1beta2.types import JobStatus
+
+from airflow.exceptions import AirflowException
+from airflow.providers.google.cloud.hooks.dataproc import DataprocHook
+from airflow.sensors.base_sensor_operator import BaseSensorOperator
+from airflow.utils.decorators import apply_defaults
+
+
+class DataprocJobSensor(BaseSensorOperator):
+    """
+    Check for the state of a previously submitted Dataproc job.
+
+    :param project_id: The ID of the Google Cloud project in which
+        the Dataproc job is running. (templated)
+    :type project_id: str
+    :param dataproc_job_id: The Dataproc job ID to poll. (templated)
+    :type dataproc_job_id: str
+    :param location: Required. The Cloud Dataproc region in which to handle the request. (templated)
+    :type location: str
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud Platform.
+    :type gcp_conn_id: str
+    """
+
+    template_fields = ('project_id', 'location', 'dataproc_job_id')
+    ui_color = '#f0eee4'
+
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        project_id: str,
+        dataproc_job_id: str,
+        location: str,
+        gcp_conn_id: str = 'google_cloud_default',
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.gcp_conn_id = gcp_conn_id
+        self.dataproc_job_id = dataproc_job_id
+        self.location = location
+
+    def poke(self, context):
+        hook = DataprocHook(gcp_conn_id=self.gcp_conn_id)
+        job = hook.get_job(job_id=self.dataproc_job_id, location=self.location, project_id=self.project_id)
+        state = job.status.state
+
+        if state == JobStatus.ERROR:
+            raise AirflowException('Job failed:\n{}'.format(job))
+        elif state in {JobStatus.CANCELLED, JobStatus.CANCEL_PENDING, JobStatus.CANCEL_STARTED}:
+            raise AirflowException('Job was cancelled:\n{}'.format(job))
+        elif JobStatus.DONE == state:
+            self.log.debug("Job %s completed successfully.", self.dataproc_job_id)
+            return True
+        elif JobStatus.ATTEMPT_FAILURE == state:
+            self.log.debug("Job %s attempt has failed.", self.dataproc_job_id)
+
+        self.log.info("Waiting for job %s to complete.", self.dataproc_job_id)
+        return False
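
Since DataprocJobSensor subclasses BaseSensorOperator, it also accepts the standard sensor arguments. For long-running jobs it can run in reschedule mode, freeing its worker slot between pokes instead of sleeping on it. A hedged sketch, reusing REGION and PROJECT_ID from the example DAG above (mode and poke_interval are stock BaseSensorOperator parameters, not new to this commit):

    spark_job_sensor = DataprocJobSensor(
        task_id="spark_job_sensor",
        location=REGION,
        project_id=PROJECT_ID,
        dataproc_job_id="{{ task_instance.xcom_pull(task_ids='spark_task_async') }}",
        mode="reschedule",   # release the worker slot between pokes
        poke_interval=60,    # check the job state once a minute
    )
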
diff --git a/docs/operators-and-hooks-ref.rst b/docs/operators-and-hooks-ref.rst
index 1cc71e5..4f7c9e7 100644
--- a/docs/operators-and-hooks-ref.rst
+++ b/docs/operators-and-hooks-ref.rst
@@ -764,7 +764,7 @@ These integrations allow you to perform various operations within the Google Clo
      - :doc:`How to use <howto/operator/google/cloud/dataproc>`
      - :mod:`airflow.providers.google.cloud.hooks.dataproc`
      - :mod:`airflow.providers.google.cloud.operators.dataproc`
-     -
+     - :mod:`airflow.providers.google.cloud.sensors.dataproc`
 
    * - `Datastore <https://cloud.google.com/datastore/>`__
      - :doc:`How to use <howto/operator/google/cloud/datastore>`
diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py
index ac705c1..69f8e9d 100644
--- a/tests/providers/google/cloud/operators/test_dataproc.py
+++ b/tests/providers/google/cloud/operators/test_dataproc.py
@@ -452,6 +452,42 @@ class TestDataprocSubmitJobOperator(unittest.TestCase):
             job_id=job_id, project_id=GCP_PROJECT, location=GCP_LOCATION
         )
 
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_execute_async(self, mock_hook):
+        job = {}
+        job_id = "job_id"
+        mock_hook.return_value.wait_for_job.return_value = None
+        mock_hook.return_value.submit_job.return_value.reference.job_id = job_id
+
+        op = DataprocSubmitJobOperator(
+            task_id=TASK_ID,
+            location=GCP_LOCATION,
+            project_id=GCP_PROJECT,
+            job=job,
+            gcp_conn_id=GCP_CONN_ID,
+            retry=RETRY,
+            asynchronous=True,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+            request_id=REQUEST_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        op.execute(context={})
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        mock_hook.return_value.submit_job.assert_called_once_with(
+            project_id=GCP_PROJECT,
+            location=GCP_LOCATION,
+            job=job,
+            request_id=REQUEST_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+        mock_hook.return_value.wait_for_job.assert_not_called()
+
 
 class TestDataprocUpdateClusterOperator(unittest.TestCase):
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
diff --git a/tests/providers/google/cloud/sensors/test_dataproc.py b/tests/providers/google/cloud/sensors/test_dataproc.py
new file mode 100644
index 0000000..f9b4003
--- /dev/null
+++ b/tests/providers/google/cloud/sensors/test_dataproc.py
@@ -0,0 +1,128 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+from unittest import mock
+
+from google.cloud.dataproc_v1beta2.types import JobStatus
+
+from airflow import AirflowException
+from airflow.providers.google.cloud.sensors.dataproc import DataprocJobSensor
+from airflow.version import version as airflow_version
+
+AIRFLOW_VERSION = "v" + airflow_version.replace(".", "-").replace("+", "-")
+
+DATAPROC_PATH = "airflow.providers.google.cloud.sensors.dataproc.{}"
+
+TASK_ID = "task-id"
+GCP_PROJECT = "test-project"
+GCP_LOCATION = "test-location"
+GCP_CONN_ID = "test-conn"
+TIMEOUT = 120
+
+
+class TestDataprocJobSensor(unittest.TestCase):
+    def create_job(self, state: int):
+        job = mock.Mock()
+        job.status = mock.Mock()
+        job.status.state = state
+        return job
+
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_done(self, mock_hook):
+        job = self.create_job(JobStatus.DONE)
+        job_id = "job_id"
+        mock_hook.return_value.get_job.return_value = job
+
+        sensor = DataprocJobSensor(
+            task_id=TASK_ID,
+            location=GCP_LOCATION,
+            project_id=GCP_PROJECT,
+            dataproc_job_id=job_id,
+            gcp_conn_id=GCP_CONN_ID,
+            timeout=TIMEOUT,
+        )
+        ret = sensor.poke(context={})
+
+        mock_hook.return_value.get_job.assert_called_once_with(
+            job_id=job_id, location=GCP_LOCATION, project_id=GCP_PROJECT
+        )
+        self.assertTrue(ret)
+
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_error(self, mock_hook):
+        job = self.create_job(JobStatus.ERROR)
+        job_id = "job_id"
+        mock_hook.return_value.get_job.return_value = job
+
+        sensor = DataprocJobSensor(
+            task_id=TASK_ID,
+            location=GCP_LOCATION,
+            project_id=GCP_PROJECT,
+            dataproc_job_id=job_id,
+            gcp_conn_id=GCP_CONN_ID,
+            timeout=TIMEOUT,
+        )
+
+        with self.assertRaisesRegex(AirflowException, "Job failed"):
+            sensor.poke(context={})
+
+        mock_hook.return_value.get_job.assert_called_once_with(
+            job_id=job_id, location=GCP_LOCATION, project_id=GCP_PROJECT
+        )
+
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_wait(self, mock_hook):
+        job = self.create_job(JobStatus.RUNNING)
+        job_id = "job_id"
+        mock_hook.return_value.get_job.return_value = job
+
+        sensor = DataprocJobSensor(
+            task_id=TASK_ID,
+            location=GCP_LOCATION,
+            project_id=GCP_PROJECT,
+            dataproc_job_id=job_id,
+            gcp_conn_id=GCP_CONN_ID,
+            timeout=TIMEOUT,
+        )
+        ret = sensor.poke(context={})
+
+        mock_hook.return_value.get_job.assert_called_once_with(
+            job_id=job_id, location=GCP_LOCATION, project_id=GCP_PROJECT
+        )
+        self.assertFalse(ret)
+
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_cancelled(self, mock_hook):
+        job = self.create_job(JobStatus.CANCELLED)
+        job_id = "job_id"
+        mock_hook.return_value.get_job.return_value = job
+
+        sensor = DataprocJobSensor(
+            task_id=TASK_ID,
+            location=GCP_LOCATION,
+            project_id=GCP_PROJECT,
+            dataproc_job_id=job_id,
+            gcp_conn_id=GCP_CONN_ID,
+            timeout=TIMEOUT,
+        )
+        with self.assertRaisesRegex(AirflowException, "Job was cancelled"):
+            sensor.poke(context={})
+
+        mock_hook.return_value.get_job.assert_called_once_with(
+            job_id=job_id, location=GCP_LOCATION, project_id=GCP_PROJECT
+        )
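
One state the new tests leave uncovered is ATTEMPT_FAILURE, which the sensor logs and treats as transient (poke returns False). A hedged sketch of a companion test in the same style as those above; the method name is illustrative:

    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
    def test_attempt_failure(self, mock_hook):
        job = self.create_job(JobStatus.ATTEMPT_FAILURE)
        job_id = "job_id"
        mock_hook.return_value.get_job.return_value = job

        sensor = DataprocJobSensor(
            task_id=TASK_ID,
            location=GCP_LOCATION,
            project_id=GCP_PROJECT,
            dataproc_job_id=job_id,
            gcp_conn_id=GCP_CONN_ID,
            timeout=TIMEOUT,
        )
        # ATTEMPT_FAILURE is retryable, so poke() should report "not done yet"
        self.assertFalse(sensor.poke(context={}))
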
