This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9d9351728c Add deferrable mode to DataprocCreateBatchOperator (#28457)
9d9351728c is described below

commit 9d9351728cac9f9ed3bea0504dcfa8da15a7461b
Author: Beata Kossakowska <[email protected]>
AuthorDate: Mon Jan 30 12:16:28 2023 +0100

    Add deferrable mode to DataprocCreateBatchOperator (#28457)
---
 .../providers/google/cloud/operators/dataproc.py   |  62 +++++++++--
 .../providers/google/cloud/triggers/dataproc.py    |  72 ++++++++++++-
 .../operators/cloud/dataproc.rst                   |   8 ++
 .../google/cloud/operators/test_dataproc.py        |  47 +++++++-
 .../google/cloud/triggers/test_dataproc.py         | 118 ++++++++++++++++++++-
 .../dataproc/example_dataproc_batch_deferrable.py  |  90 ++++++++++++++++
 6 files changed, 385 insertions(+), 12 deletions(-)

diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index 316a15981f..d2fc241a28 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -50,7 +50,11 @@ from airflow.providers.google.cloud.links.dataproc import (
     DataprocLink,
     DataprocListLink,
 )
-from airflow.providers.google.cloud.triggers.dataproc import DataprocClusterTrigger, DataprocSubmitTrigger
+from airflow.providers.google.cloud.triggers.dataproc import (
+    DataprocBatchTrigger,
+    DataprocClusterTrigger,
+    DataprocSubmitTrigger,
+)
 from airflow.utils import timezone
 
 if TYPE_CHECKING:
@@ -2134,6 +2138,8 @@ class DataprocCreateBatchOperator(BaseOperator):
     :param asynchronous: Flag to return after creating batch to the Dataproc API.
         This is useful for creating long-running batches and
         waiting on them asynchronously using the DataprocBatchSensor
+    :param deferrable: Run the operator in deferrable mode.
+    :param polling_interval_seconds: Time (seconds) to wait between calls to check the run status.
     """
 
     template_fields: Sequence[str] = (
@@ -2151,7 +2157,7 @@ class DataprocCreateBatchOperator(BaseOperator):
         region: str | None = None,
         project_id: str | None = None,
         batch: dict | Batch,
-        batch_id: str | None = None,
+        batch_id: str,
         request_id: str | None = None,
         retry: Retry | _MethodDefault = DEFAULT,
         timeout: float | None = None,
@@ -2160,9 +2166,13 @@ class DataprocCreateBatchOperator(BaseOperator):
         impersonation_chain: str | Sequence[str] | None = None,
         result_retry: Retry | _MethodDefault = DEFAULT,
         asynchronous: bool = False,
+        deferrable: bool = False,
+        polling_interval_seconds: int = 5,
         **kwargs,
     ):
         super().__init__(**kwargs)
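+        # Deferrable mode polls on a fixed cadence, so the interval must be positive.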
+        if deferrable and polling_interval_seconds <= 0:
+            raise ValueError("Invalid value for polling_interval_seconds. Expected value greater than 0")
         self.region = region
         self.project_id = project_id
         self.batch = batch
@@ -2176,6 +2186,8 @@ class DataprocCreateBatchOperator(BaseOperator):
         self.impersonation_chain = impersonation_chain
         self.operation: operation.Operation | None = None
         self.asynchronous = asynchronous
+        self.deferrable = deferrable
+        self.polling_interval_seconds = polling_interval_seconds
 
     def execute(self, context: Context):
         hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
@@ -2195,13 +2207,30 @@ class DataprocCreateBatchOperator(BaseOperator):
             )
             if self.operation is None:
                 raise RuntimeError("The operation should be set here!")
-            if not self.asynchronous:
-                result = hook.wait_for_operation(
-                    timeout=self.timeout, result_retry=self.result_retry, operation=self.operation
-                )
-                self.log.info("Batch %s created", self.batch_id)
+
+            if not self.deferrable:
+                if not self.asynchronous:
+                    result = hook.wait_for_operation(
+                        timeout=self.timeout, result_retry=self.result_retry, operation=self.operation
+                    )
+                    self.log.info("Batch %s created", self.batch_id)
+
+                else:
+                    return self.operation.operation.name
+
             else:
-                return self.operation.operation.name
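+                # Free the worker slot: DataprocBatchTrigger polls the batch
+                # state on the triggerer and resumes the task in execute_complete.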
+                self.defer(
+                    trigger=DataprocBatchTrigger(
+                        batch_id=self.batch_id,
+                        project_id=self.project_id,
+                        region=self.region,
+                        gcp_conn_id=self.gcp_conn_id,
+                        impersonation_chain=self.impersonation_chain,
+                        polling_interval_seconds=self.polling_interval_seconds,
+                    ),
+                    method_name="execute_complete",
+                )
+
         except AlreadyExists:
             self.log.info("Batch with given id already exists")
             if self.batch_id is None:
@@ -2233,6 +2262,23 @@ class DataprocCreateBatchOperator(BaseOperator):
         DataprocLink.persist(context=context, task_instance=self, url=DATAPROC_BATCH_LINK, resource=batch_id)
         return Batch.to_dict(result)
 
+    def execute_complete(self, context, event=None) -> None:
+        """
+        Callback for when the trigger fires; returns immediately.
+        Relies on the trigger to raise an exception, otherwise it assumes
+        execution was successful.
+        """
+        if event is None:
+            raise AirflowException("Batch failed.")
+        batch_state = event["batch_state"]
+        batch_id = event["batch_id"]
+
+        if batch_state == Batch.State.FAILED:
+            raise AirflowException(f"Batch failed:\n{batch_id}")
+        if batch_state == Batch.State.CANCELLED:
+            raise AirflowException(f"Batch was cancelled:\n{batch_id}")
+        self.log.info("%s completed successfully.", self.task_id)
+
     def on_kill(self):
         if self.operation:
             self.operation.cancel()
diff --git a/airflow/providers/google/cloud/triggers/dataproc.py b/airflow/providers/google/cloud/triggers/dataproc.py
index 48d3666f4a..486434e9fa 100644
--- a/airflow/providers/google/cloud/triggers/dataproc.py
+++ b/airflow/providers/google/cloud/triggers/dataproc.py
@@ -22,7 +22,7 @@ import asyncio
 import warnings
 from typing import Any, AsyncIterator, Sequence
 
-from google.cloud.dataproc_v1 import ClusterStatus, JobStatus
+from google.cloud.dataproc_v1 import Batch, ClusterStatus, JobStatus
 
 from airflow import AirflowException
 from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook
@@ -149,3 +149,73 @@ class DataprocClusterTrigger(BaseTrigger):
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
         )
+
+
+class DataprocBatchTrigger(BaseTrigger):
+    """
+    DataprocBatchTrigger runs on the trigger worker to perform the create Batch operation.
+
+    :param batch_id: The ID of the batch.
+    :param project_id: Google Cloud Project where the job is running
+    :param region: The Cloud Dataproc region in which to handle the request.
+    :param gcp_conn_id: Optional, the connection ID used to connect to Google Cloud Platform.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :param polling_interval_seconds: Polling period in seconds to check for the batch status.
+    """
+
+    def __init__(
+        self,
+        batch_id: str,
+        region: str,
+        project_id: str | None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        polling_interval_seconds: float = 5.0,
+    ):
+        super().__init__()
+        self.batch_id = batch_id
+        self.project_id = project_id
+        self.region = region
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        self.polling_interval_seconds = polling_interval_seconds
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serializes DataprocBatchTrigger arguments and classpath."""
+        return (
+            "airflow.providers.google.cloud.triggers.dataproc.DataprocBatchTrigger",
+            {
+                "batch_id": self.batch_id,
+                "project_id": self.project_id,
+                "region": self.region,
+                "gcp_conn_id": self.gcp_conn_id,
+                "impersonation_chain": self.impersonation_chain,
+                "polling_interval_seconds": self.polling_interval_seconds,
+            },
+        )
+
+    async def run(self):
+        hook = DataprocAsyncHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+        while True:
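+            # Poll until the batch reaches a terminal state
+            # (SUCCEEDED, FAILED or CANCELLED), then emit a single event.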
+            batch = await hook.get_batch(
+                project_id=self.project_id, region=self.region, batch_id=self.batch_id
+            )
+            state = batch.state
+
+            if state in (Batch.State.FAILED, Batch.State.SUCCEEDED, Batch.State.CANCELLED):
+                break
+            self.log.info("Current state is %s", state)
+            self.log.info("Sleeping for %s seconds.", 
self.polling_interval_seconds)
+            await asyncio.sleep(self.polling_interval_seconds)
+        yield TriggerEvent({"batch_id": self.batch_id, "batch_state": state})
diff --git a/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst b/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst
index d613031042..8816bb4254 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst
@@ -305,6 +305,14 @@ To check if operation succeeded you can use
     :start-after: [START how_to_cloud_dataproc_batch_async_sensor]
     :end-before: [END how_to_cloud_dataproc_batch_async_sensor]
 
+You can also run this operator in deferrable mode:
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/dataproc/example_dataproc_batch_deferrable.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_dataproc_create_batch_operator_async]
+    :end-before: [END how_to_cloud_dataproc_create_batch_operator_async]
+
 Get a Batch
 -----------
 
diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py
index 70ebd991a8..929b1dd232 100644
--- a/tests/providers/google/cloud/operators/test_dataproc.py
+++ b/tests/providers/google/cloud/operators/test_dataproc.py
@@ -54,7 +54,11 @@ from airflow.providers.google.cloud.operators.dataproc import (
     DataprocSubmitSparkSqlJobOperator,
     DataprocUpdateClusterOperator,
 )
-from airflow.providers.google.cloud.triggers.dataproc import DataprocClusterTrigger, DataprocSubmitTrigger
+from airflow.providers.google.cloud.triggers.dataproc import (
+    DataprocBatchTrigger,
+    DataprocClusterTrigger,
+    DataprocSubmitTrigger,
+)
 from airflow.providers.google.common.consts import GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
 from airflow.serialization.serialized_objects import SerializedDAG
 from airflow.utils.timezone import datetime
@@ -2032,3 +2036,44 @@ class TestDataprocListBatchesOperator:
             timeout=TIMEOUT,
             metadata=METADATA,
         )
+
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    @mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook"))
+    def test_execute_deferrable(self, mock_trigger_hook, mock_hook):
+        mock_hook.return_value.submit_job.return_value.reference.job_id = TEST_JOB_ID
+
+        op = DataprocCreateBatchOperator(
+            task_id=TASK_ID,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            batch=BATCH,
+            batch_id="batch_id",
+            gcp_conn_id=GCP_CONN_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+            request_id=REQUEST_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            deferrable=True,
+        )
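+        # Deferring raises TaskDeferred; capture it to inspect the trigger below.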
+        with pytest.raises(TaskDeferred) as exc:
+            op.execute(mock.MagicMock())
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        mock_hook.return_value.create_batch.assert_called_once_with(
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            batch_id="batch_id",
+            batch=BATCH,
+            request_id=REQUEST_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+        mock_hook.return_value.wait_for_operation.assert_not_called()
+
+        assert isinstance(exc.value.trigger, DataprocBatchTrigger)
+        assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
diff --git a/tests/providers/google/cloud/triggers/test_dataproc.py b/tests/providers/google/cloud/triggers/test_dataproc.py
index 854c02c0de..90ea36c31d 100644
--- a/tests/providers/google/cloud/triggers/test_dataproc.py
+++ b/tests/providers/google/cloud/triggers/test_dataproc.py
@@ -18,16 +18,24 @@ from __future__ import annotations
 
 import asyncio
 import logging
+from asyncio import Future
 
 import pytest
-from google.cloud.dataproc_v1 import ClusterStatus
+from google.cloud.dataproc_v1 import Batch, ClusterStatus
 
-from airflow.providers.google.cloud.triggers.dataproc import DataprocClusterTrigger
+from airflow.providers.google.cloud.triggers.dataproc import DataprocBatchTrigger, DataprocClusterTrigger
 from airflow.triggers.base import TriggerEvent
 from tests.providers.google.cloud.utils.compat import async_mock
 
 TEST_PROJECT_ID = "project-id"
 TEST_REGION = "region"
+TEST_BATCH_ID = "batch-id"
+BATCH_CONFIG = {
+    "spark_batch": {
+        "jar_file_uris": 
["file:///usr/lib/spark/examples/jars/spark-examples.jar"],
+        "main_class": "org.apache.spark.examples.SparkPi",
+    },
+}
 TEST_CLUSTER_NAME = "cluster_name"
 TEST_POLL_INTERVAL = 5
 TEST_GCP_CONN_ID = "google_cloud_default"
@@ -45,6 +53,19 @@ def trigger():
     )
 
 
+@pytest.fixture
+def batch_trigger():
+    trigger = DataprocBatchTrigger(
+        batch_id=TEST_BATCH_ID,
+        project_id=TEST_PROJECT_ID,
+        region=TEST_REGION,
+        gcp_conn_id=TEST_GCP_CONN_ID,
+        impersonation_chain=None,
+        polling_interval_seconds=TEST_POLL_INTERVAL,
+    )
+    return trigger
+
+
 @pytest.fixture()
 def async_get_cluster():
     def func(**kwargs):
@@ -57,6 +78,18 @@ def async_get_cluster():
     return func
 
 
+@pytest.fixture()
+def async_get_batch():
+    def func(**kwargs):
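+        # Return an already-resolved Future so the patched async hook method
+        # can be awaited without a real API call.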
+        m = async_mock.MagicMock()
+        m.configure_mock(**kwargs)
+        f = Future()
+        f.set_result(m)
+        return f
+
+    return func
+
+
 class TestDataprocClusterTrigger:
     def test_async_cluster_trigger_serialization_should_execute_successfully(self, trigger):
         classpath, kwargs = trigger.serialize()
@@ -134,3 +167,84 @@ class TestDataprocClusterTrigger:
         assert not task.done()
         assert f"Current state is: {ClusterStatus.State.CREATING}"
         assert f"Sleeping for {TEST_POLL_INTERVAL} seconds."
+
+
+class TestDataprocBatchTrigger:
+    def test_async_create_batch_trigger_serialization_should_execute_successfully(self, batch_trigger):
+        """
+        Asserts that the DataprocBatchTrigger correctly serializes its arguments
+        and classpath.
+        """
+
+        classpath, kwargs = batch_trigger.serialize()
+        assert classpath == "airflow.providers.google.cloud.triggers.dataproc.DataprocBatchTrigger"
+        assert kwargs == {
+            "batch_id": TEST_BATCH_ID,
+            "project_id": TEST_PROJECT_ID,
+            "region": TEST_REGION,
+            "gcp_conn_id": TEST_GCP_CONN_ID,
+            "impersonation_chain": None,
+            "polling_interval_seconds": TEST_POLL_INTERVAL,
+        }
+
+    @pytest.mark.asyncio
+    @async_mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_batch")
+    async def test_async_create_batch_trigger_triggers_on_success_should_execute_successfully(
+        self, mock_hook, batch_trigger, async_get_batch
+    ):
+        """
+        Tests that the DataprocBatchTrigger only fires once the batch execution reaches a successful state.
+        """
+
+        mock_hook.return_value = async_get_batch(state=Batch.State.SUCCEEDED, batch_id=TEST_BATCH_ID)
+
+        expected_event = TriggerEvent(
+            {
+                "batch_id": TEST_BATCH_ID,
+                "batch_state": Batch.State.SUCCEEDED,
+            }
+        )
+
+        actual_event = await (batch_trigger.run()).asend(None)
+        await asyncio.sleep(0.5)
+        assert expected_event == actual_event
+
+    @pytest.mark.asyncio
+    @async_mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_batch")
+    async def test_async_create_batch_trigger_run_returns_failed_event(
+        self, mock_hook, batch_trigger, async_get_batch
+    ):
+        mock_hook.return_value = async_get_batch(state=Batch.State.FAILED, batch_id=TEST_BATCH_ID)
+
+        expected_event = TriggerEvent({"batch_id": TEST_BATCH_ID, 
"batch_state": Batch.State.FAILED})
+
+        actual_event = await (batch_trigger.run()).asend(None)
+        await asyncio.sleep(0.5)
+        assert expected_event == actual_event
+
+    @pytest.mark.asyncio
+    @async_mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_batch")
+    async def test_create_batch_run_returns_cancelled_event(self, mock_hook, batch_trigger, async_get_batch):
+        mock_hook.return_value = async_get_batch(state=Batch.State.CANCELLED, batch_id=TEST_BATCH_ID)
+
+        expected_event = TriggerEvent({"batch_id": TEST_BATCH_ID, 
"batch_state": Batch.State.CANCELLED})
+
+        actual_event = await (batch_trigger.run()).asend(None)
+        await asyncio.sleep(0.5)
+        assert expected_event == actual_event
+
+    @pytest.mark.asyncio
+    @async_mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_batch")
+    async def test_create_batch_run_loop_is_still_running(
+        self, mock_hook, batch_trigger, caplog, async_get_batch
+    ):
+        mock_hook.return_value = async_get_batch(state=Batch.State.RUNNING)
+
+        caplog.set_level(logging.INFO)
+
+        task = asyncio.create_task(batch_trigger.run().__anext__())
+        await asyncio.sleep(0.5)
+
+        assert not task.done()
+        assert f"Current state is: {Batch.State.RUNNING}"
+        assert f"Sleeping for {TEST_POLL_INTERVAL} seconds."
diff --git a/tests/system/providers/google/cloud/dataproc/example_dataproc_batch_deferrable.py b/tests/system/providers/google/cloud/dataproc/example_dataproc_batch_deferrable.py
new file mode 100644
index 0000000000..2d363328c4
--- /dev/null
+++ b/tests/system/providers/google/cloud/dataproc/example_dataproc_batch_deferrable.py
@@ -0,0 +1,90 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example Airflow DAG for DataprocCreateBatchOperator in deferrable mode.
+"""
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+from airflow import models
+from airflow.providers.google.cloud.operators.dataproc import (
+    DataprocCreateBatchOperator,
+    DataprocDeleteBatchOperator,
+    DataprocGetBatchOperator,
+)
+from airflow.utils.trigger_rule import TriggerRule
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
+DAG_ID = "dataproc_batch_deferrable"
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT")
+REGION = "europe-west1"
+BATCH_ID = f"test-def-batch-id-{ENV_ID}"
+BATCH_CONFIG = {
+    "spark_batch": {
+        "jar_file_uris": 
["file:///usr/lib/spark/examples/jars/spark-examples.jar"],
+        "main_class": "org.apache.spark.examples.SparkPi",
+    },
+}
+
+
+with models.DAG(
+    DAG_ID,
+    schedule="@once",
+    start_date=datetime(2021, 1, 1),
+    catchup=False,
+    tags=["example", "dataproc"],
+) as dag:
+    # [START how_to_cloud_dataproc_create_batch_operator_async]
+    create_batch = DataprocCreateBatchOperator(
+        task_id="create_batch",
+        project_id=PROJECT_ID,
+        region=REGION,
+        batch=BATCH_CONFIG,
+        batch_id=BATCH_ID,
+        deferrable=True,
+    )
+    # [END how_to_cloud_dataproc_create_batch_operator_async]
+
+    get_batch = DataprocGetBatchOperator(
+        task_id="get_batch", project_id=PROJECT_ID, region=REGION, 
batch_id=BATCH_ID
+    )
+
+    delete_batch = DataprocDeleteBatchOperator(
+        task_id="delete_batch",
+        project_id=PROJECT_ID,
+        region=REGION,
+        batch_id=BATCH_ID,
+    )
+    delete_batch.trigger_rule = TriggerRule.ALL_DONE
+
+    (create_batch >> get_batch >> delete_batch)
+
+    from tests.system.utils.watcher import watcher
+
+    # This test needs watcher in order to properly mark success/failure
+    # when "teardown" task with trigger rule is part of the DAG
+    list(dag.tasks) >> watcher()
+
+
+from tests.system.utils import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
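
For quick reference, a minimal sketch of the three waiting strategies the
operator now supports. The project, region and batch_id values below are
illustrative placeholders, not values from this commit:

    from airflow.providers.google.cloud.operators.dataproc import DataprocCreateBatchOperator

    BATCH_CONFIG = {
        "spark_batch": {
            "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"],
            "main_class": "org.apache.spark.examples.SparkPi",
        },
    }

    # 1. Default: block the worker slot until the batch finishes.
    sync_batch = DataprocCreateBatchOperator(
        task_id="create_batch_sync",
        project_id="my-project",  # placeholder
        region="europe-west1",  # placeholder
        batch=BATCH_CONFIG,
        batch_id="my-batch-sync",  # placeholder
    )

    # 2. asynchronous=True: return right after submission and watch the batch
    #    separately, e.g. with DataprocBatchSensor.
    fire_and_forget_batch = DataprocCreateBatchOperator(
        task_id="create_batch_async",
        project_id="my-project",
        region="europe-west1",
        batch=BATCH_CONFIG,
        batch_id="my-batch-async",
        asynchronous=True,
    )

    # 3. deferrable=True (added by this commit): release the worker slot and let
    #    DataprocBatchTrigger poll the batch state from the triggerer.
    deferrable_batch = DataprocCreateBatchOperator(
        task_id="create_batch_deferrable",
        project_id="my-project",
        region="europe-west1",
        batch=BATCH_CONFIG,
        batch_id="my-batch-deferrable",
        deferrable=True,
        polling_interval_seconds=10,
    )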
