This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 9d9351728c Add deferrable mode to DataprocCreateBatchOperator (#28457)
9d9351728c is described below
commit 9d9351728cac9f9ed3bea0504dcfa8da15a7461b
Author: Beata Kossakowska <[email protected]>
AuthorDate: Mon Jan 30 12:16:28 2023 +0100
Add deferrable mode to DataprocCreateBatchOperator (#28457)
---
.../providers/google/cloud/operators/dataproc.py | 62 +++++++++--
.../providers/google/cloud/triggers/dataproc.py | 72 ++++++++++++-
.../operators/cloud/dataproc.rst | 8 ++
.../google/cloud/operators/test_dataproc.py | 47 +++++++-
.../google/cloud/triggers/test_dataproc.py | 118 ++++++++++++++++++++-
.../dataproc/example_dataproc_batch_deferrable.py | 90 ++++++++++++++++
6 files changed, 385 insertions(+), 12 deletions(-)
diff --git a/airflow/providers/google/cloud/operators/dataproc.py
b/airflow/providers/google/cloud/operators/dataproc.py
index 316a15981f..d2fc241a28 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -50,7 +50,11 @@ from airflow.providers.google.cloud.links.dataproc import (
DataprocLink,
DataprocListLink,
)
-from airflow.providers.google.cloud.triggers.dataproc import
DataprocClusterTrigger, DataprocSubmitTrigger
+from airflow.providers.google.cloud.triggers.dataproc import (
+ DataprocBatchTrigger,
+ DataprocClusterTrigger,
+ DataprocSubmitTrigger,
+)
from airflow.utils import timezone
if TYPE_CHECKING:
@@ -2134,6 +2138,8 @@ class DataprocCreateBatchOperator(BaseOperator):
:param asynchronous: Flag to return after creating batch to the Dataproc
API.
This is useful for creating long-running batches and
waiting on them asynchronously using the DataprocBatchSensor
+ :param deferrable: Run operator in the deferrable mode.
+ :param polling_interval_seconds: Time (seconds) to wait between calls to
check the run status.
"""
template_fields: Sequence[str] = (
@@ -2151,7 +2157,7 @@ class DataprocCreateBatchOperator(BaseOperator):
region: str | None = None,
project_id: str | None = None,
batch: dict | Batch,
- batch_id: str | None = None,
+ batch_id: str,
request_id: str | None = None,
retry: Retry | _MethodDefault = DEFAULT,
timeout: float | None = None,
@@ -2160,9 +2166,13 @@ class DataprocCreateBatchOperator(BaseOperator):
impersonation_chain: str | Sequence[str] | None = None,
result_retry: Retry | _MethodDefault = DEFAULT,
asynchronous: bool = False,
+ deferrable: bool = False,
+ polling_interval_seconds: int = 5,
**kwargs,
):
super().__init__(**kwargs)
+ if deferrable and polling_interval_seconds <= 0:
+ raise ValueError("Invalid value for polling_interval_seconds.
Expected value greater than 0")
self.region = region
self.project_id = project_id
self.batch = batch
@@ -2176,6 +2186,8 @@ class DataprocCreateBatchOperator(BaseOperator):
self.impersonation_chain = impersonation_chain
self.operation: operation.Operation | None = None
self.asynchronous = asynchronous
+ self.deferrable = deferrable
+ self.polling_interval_seconds = polling_interval_seconds
def execute(self, context: Context):
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain)
@@ -2195,13 +2207,30 @@ class DataprocCreateBatchOperator(BaseOperator):
)
if self.operation is None:
raise RuntimeError("The operation should be set here!")
- if not self.asynchronous:
- result = hook.wait_for_operation(
- timeout=self.timeout, result_retry=self.result_retry,
operation=self.operation
- )
- self.log.info("Batch %s created", self.batch_id)
+
+ if not self.deferrable:
+ if not self.asynchronous:
+ result = hook.wait_for_operation(
+ timeout=self.timeout, result_retry=self.result_retry,
operation=self.operation
+ )
+ self.log.info("Batch %s created", self.batch_id)
+
+ else:
+ return self.operation.operation.name
+
else:
- return self.operation.operation.name
+ self.defer(
+ trigger=DataprocBatchTrigger(
+ batch_id=self.batch_id,
+ project_id=self.project_id,
+ region=self.region,
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ polling_interval_seconds=self.polling_interval_seconds,
+ ),
+ method_name="execute_complete",
+ )
+
except AlreadyExists:
self.log.info("Batch with given id already exists")
if self.batch_id is None:
@@ -2233,6 +2262,23 @@ class DataprocCreateBatchOperator(BaseOperator):
DataprocLink.persist(context=context, task_instance=self,
url=DATAPROC_BATCH_LINK, resource=batch_id)
return Batch.to_dict(result)
+ def execute_complete(self, context, event=None) -> None:
+ """
+ Callback for when the trigger fires - returns immediately.
+ Relies on trigger to throw an exception, otherwise it assumes
execution was
+ successful.
+ """
+ if event is None:
+ raise AirflowException("Batch failed.")
+ batch_state = event["batch_state"]
+ batch_id = event["batch_id"]
+
+ if batch_state == Batch.State.FAILED:
+ raise AirflowException(f"Batch failed:\n{batch_id}")
+ if batch_state == Batch.State.CANCELLED:
+ raise AirflowException(f"Batch was cancelled:\n{batch_id}")
+ self.log.info("%s completed successfully.", self.task_id)
+
def on_kill(self):
if self.operation:
self.operation.cancel()
diff --git a/airflow/providers/google/cloud/triggers/dataproc.py
b/airflow/providers/google/cloud/triggers/dataproc.py
index 48d3666f4a..486434e9fa 100644
--- a/airflow/providers/google/cloud/triggers/dataproc.py
+++ b/airflow/providers/google/cloud/triggers/dataproc.py
@@ -22,7 +22,7 @@ import asyncio
import warnings
from typing import Any, AsyncIterator, Sequence
-from google.cloud.dataproc_v1 import ClusterStatus, JobStatus
+from google.cloud.dataproc_v1 import Batch, ClusterStatus, JobStatus
from airflow import AirflowException
from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook
@@ -149,3 +149,73 @@ class DataprocClusterTrigger(BaseTrigger):
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
+
+
+class DataprocBatchTrigger(BaseTrigger):
+ """
+    DataprocBatchTrigger runs on the trigger worker to perform the create
Batch operation
+
+    :param batch_id: The ID of the batch.
+ :param project_id: Google Cloud Project where the job is running
+ :param region: The Cloud Dataproc region in which to handle the request.
+ :param gcp_conn_id: Optional, the connection ID used to connect to Google
Cloud Platform.
+ :param impersonation_chain: Optional service account to impersonate using
short-term
+ credentials, or chained list of accounts required to get the
access_token
+ of the last account in the list, which will be impersonated in the
request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding
identity, with first
+ account from the list granting this role to the originating account
(templated).
+ :param polling_interval_seconds: polling period in seconds to check for
the status
+ """
+
+ def __init__(
+ self,
+ batch_id: str,
+ region: str,
+ project_id: str | None,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ polling_interval_seconds: float = 5.0,
+ ):
+ super().__init__()
+ self.batch_id = batch_id
+ self.project_id = project_id
+ self.region = region
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+ self.polling_interval_seconds = polling_interval_seconds
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ """Serializes DataprocBatchTrigger arguments and classpath."""
+ return (
+
"airflow.providers.google.cloud.triggers.dataproc.DataprocBatchTrigger",
+ {
+ "batch_id": self.batch_id,
+ "project_id": self.project_id,
+ "region": self.region,
+ "gcp_conn_id": self.gcp_conn_id,
+ "impersonation_chain": self.impersonation_chain,
+ "polling_interval_seconds": self.polling_interval_seconds,
+ },
+ )
+
+ async def run(self):
+ hook = DataprocAsyncHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+
+ while True:
+ batch = await hook.get_batch(
+ project_id=self.project_id, region=self.region,
batch_id=self.batch_id
+ )
+ state = batch.state
+
+ if state in (Batch.State.FAILED, Batch.State.SUCCEEDED,
Batch.State.CANCELLED):
+ break
+ self.log.info("Current state is %s", state)
+ self.log.info("Sleeping for %s seconds.",
self.polling_interval_seconds)
+ await asyncio.sleep(self.polling_interval_seconds)
+ yield TriggerEvent({"batch_id": self.batch_id, "batch_state": state})
diff --git a/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst
b/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst
index d613031042..8816bb4254 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst
@@ -305,6 +305,14 @@ To check if operation succeeded you can use
:start-after: [START how_to_cloud_dataproc_batch_async_sensor]
:end-before: [END how_to_cloud_dataproc_batch_async_sensor]
+Also, for all these actions, you can use the operator in deferrable mode:
+
+.. exampleinclude::
/../../tests/system/providers/google/cloud/dataproc/example_dataproc_batch_deferrable.py
+ :language: python
+ :dedent: 4
+ :start-after: [START how_to_cloud_dataproc_create_batch_operator_async]
+ :end-before: [END how_to_cloud_dataproc_create_batch_operator_async]
+
Get a Batch
-----------
diff --git a/tests/providers/google/cloud/operators/test_dataproc.py
b/tests/providers/google/cloud/operators/test_dataproc.py
index 70ebd991a8..929b1dd232 100644
--- a/tests/providers/google/cloud/operators/test_dataproc.py
+++ b/tests/providers/google/cloud/operators/test_dataproc.py
@@ -54,7 +54,11 @@ from airflow.providers.google.cloud.operators.dataproc
import (
DataprocSubmitSparkSqlJobOperator,
DataprocUpdateClusterOperator,
)
-from airflow.providers.google.cloud.triggers.dataproc import
DataprocClusterTrigger, DataprocSubmitTrigger
+from airflow.providers.google.cloud.triggers.dataproc import (
+ DataprocBatchTrigger,
+ DataprocClusterTrigger,
+ DataprocSubmitTrigger,
+)
from airflow.providers.google.common.consts import
GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
from airflow.serialization.serialized_objects import SerializedDAG
from airflow.utils.timezone import datetime
@@ -2032,3 +2036,44 @@ class TestDataprocListBatchesOperator:
timeout=TIMEOUT,
metadata=METADATA,
)
+
+ @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+ @mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook"))
+ def test_execute_deferrable(self, mock_trigger_hook, mock_hook):
+ mock_hook.return_value.submit_job.return_value.reference.job_id =
TEST_JOB_ID
+
+ op = DataprocCreateBatchOperator(
+ task_id=TASK_ID,
+ region=GCP_REGION,
+ project_id=GCP_PROJECT,
+ batch=BATCH,
+ batch_id="batch_id",
+ gcp_conn_id=GCP_CONN_ID,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ request_id=REQUEST_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ deferrable=True,
+ )
+ with pytest.raises(TaskDeferred) as exc:
+ op.execute(mock.MagicMock())
+
+ mock_hook.assert_called_once_with(
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+ mock_hook.return_value.create_batch.assert_called_once_with(
+ region=GCP_REGION,
+ project_id=GCP_PROJECT,
+ batch_id="batch_id",
+ batch=BATCH,
+ request_id=REQUEST_ID,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ )
+ mock_hook.return_value.wait_for_job.assert_not_called()
+
+ assert isinstance(exc.value.trigger, DataprocBatchTrigger)
+ assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
diff --git a/tests/providers/google/cloud/triggers/test_dataproc.py
b/tests/providers/google/cloud/triggers/test_dataproc.py
index 854c02c0de..90ea36c31d 100644
--- a/tests/providers/google/cloud/triggers/test_dataproc.py
+++ b/tests/providers/google/cloud/triggers/test_dataproc.py
@@ -18,16 +18,24 @@ from __future__ import annotations
import asyncio
import logging
+from asyncio import Future
import pytest
-from google.cloud.dataproc_v1 import ClusterStatus
+from google.cloud.dataproc_v1 import Batch, ClusterStatus
-from airflow.providers.google.cloud.triggers.dataproc import
DataprocClusterTrigger
+from airflow.providers.google.cloud.triggers.dataproc import
DataprocBatchTrigger, DataprocClusterTrigger
from airflow.triggers.base import TriggerEvent
from tests.providers.google.cloud.utils.compat import async_mock
TEST_PROJECT_ID = "project-id"
TEST_REGION = "region"
+TEST_BATCH_ID = "batch-id"
+BATCH_CONFIG = {
+ "spark_batch": {
+ "jar_file_uris":
["file:///usr/lib/spark/examples/jars/spark-examples.jar"],
+ "main_class": "org.apache.spark.examples.SparkPi",
+ },
+}
TEST_CLUSTER_NAME = "cluster_name"
TEST_POLL_INTERVAL = 5
TEST_GCP_CONN_ID = "google_cloud_default"
@@ -45,6 +53,19 @@ def trigger():
)
[email protected]
+def batch_trigger():
+ trigger = DataprocBatchTrigger(
+ batch_id=TEST_BATCH_ID,
+ project_id=TEST_PROJECT_ID,
+ region=TEST_REGION,
+ gcp_conn_id=TEST_GCP_CONN_ID,
+ impersonation_chain=None,
+ polling_interval_seconds=TEST_POLL_INTERVAL,
+ )
+ return trigger
+
+
@pytest.fixture()
def async_get_cluster():
def func(**kwargs):
@@ -57,6 +78,18 @@ def async_get_cluster():
return func
[email protected]()
+def async_get_batch():
+ def func(**kwargs):
+ m = async_mock.MagicMock()
+ m.configure_mock(**kwargs)
+ f = Future()
+ f.set_result(m)
+ return f
+
+ return func
+
+
class TestDataprocClusterTrigger:
def
test_async_cluster_trigger_serialization_should_execute_successfully(self,
trigger):
classpath, kwargs = trigger.serialize()
@@ -134,3 +167,84 @@ class TestDataprocClusterTrigger:
assert not task.done()
assert f"Current state is: {ClusterStatus.State.CREATING}"
assert f"Sleeping for {TEST_POLL_INTERVAL} seconds."
+
+
+class TestDataprocBatchTrigger:
+ def
test_async_create_batch_trigger_serialization_should_execute_successfully(self,
batch_trigger):
+ """
+ Asserts that the DataprocBatchTrigger correctly serializes its
arguments
+ and classpath.
+ """
+
+ classpath, kwargs = batch_trigger.serialize()
+ assert classpath ==
"airflow.providers.google.cloud.triggers.dataproc.DataprocBatchTrigger"
+ assert kwargs == {
+ "batch_id": TEST_BATCH_ID,
+ "project_id": TEST_PROJECT_ID,
+ "region": TEST_REGION,
+ "gcp_conn_id": TEST_GCP_CONN_ID,
+ "impersonation_chain": None,
+ "polling_interval_seconds": TEST_POLL_INTERVAL,
+ }
+
+ @pytest.mark.asyncio
+
@async_mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_batch")
+ async def
test_async_create_batch_trigger_triggers_on_success_should_execute_successfully(
+ self, mock_hook, batch_trigger, async_get_batch
+ ):
+ """
+ Tests the DataprocBatchTrigger only fires once the batch execution
reaches a successful state.
+ """
+
+ mock_hook.return_value = async_get_batch(state=Batch.State.SUCCEEDED,
batch_id=TEST_BATCH_ID)
+
+ expected_event = TriggerEvent(
+ {
+ "batch_id": TEST_BATCH_ID,
+ "batch_state": Batch.State.SUCCEEDED,
+ }
+ )
+
+ actual_event = await (batch_trigger.run()).asend(None)
+ await asyncio.sleep(0.5)
+ assert expected_event == actual_event
+
+ @pytest.mark.asyncio
+
@async_mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_batch")
+ async def test_async_create_batch_trigger_run_returns_failed_event(
+ self, mock_hook, batch_trigger, async_get_batch
+ ):
+ mock_hook.return_value = async_get_batch(state=Batch.State.FAILED,
batch_id=TEST_BATCH_ID)
+
+ expected_event = TriggerEvent({"batch_id": TEST_BATCH_ID,
"batch_state": Batch.State.FAILED})
+
+ actual_event = await (batch_trigger.run()).asend(None)
+ await asyncio.sleep(0.5)
+ assert expected_event == actual_event
+
+ @pytest.mark.asyncio
+
@async_mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_batch")
+ async def test_create_batch_run_returns_cancelled_event(self, mock_hook,
batch_trigger, async_get_batch):
+ mock_hook.return_value = async_get_batch(state=Batch.State.CANCELLED,
batch_id=TEST_BATCH_ID)
+
+ expected_event = TriggerEvent({"batch_id": TEST_BATCH_ID,
"batch_state": Batch.State.CANCELLED})
+
+ actual_event = await (batch_trigger.run()).asend(None)
+ await asyncio.sleep(0.5)
+ assert expected_event == actual_event
+
+ @pytest.mark.asyncio
+
@async_mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_batch")
+ async def test_create_batch_run_loop_is_still_running(
+ self, mock_hook, batch_trigger, caplog, async_get_batch
+ ):
+ mock_hook.return_value = async_get_batch(state=Batch.State.RUNNING)
+
+ caplog.set_level(logging.INFO)
+
+ task = asyncio.create_task(batch_trigger.run().__anext__())
+ await asyncio.sleep(0.5)
+
+ assert not task.done()
+ assert f"Current state is: {Batch.State.RUNNING}"
+ assert f"Sleeping for {TEST_POLL_INTERVAL} seconds."
diff --git
a/tests/system/providers/google/cloud/dataproc/example_dataproc_batch_deferrable.py
b/tests/system/providers/google/cloud/dataproc/example_dataproc_batch_deferrable.py
new file mode 100644
index 0000000000..2d363328c4
--- /dev/null
+++
b/tests/system/providers/google/cloud/dataproc/example_dataproc_batch_deferrable.py
@@ -0,0 +1,90 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example Airflow DAG for DataprocSubmitJobOperator with spark job
+in deferrable mode.
+"""
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+from airflow import models
+from airflow.providers.google.cloud.operators.dataproc import (
+ DataprocCreateBatchOperator,
+ DataprocDeleteBatchOperator,
+ DataprocGetBatchOperator,
+)
+from airflow.utils.trigger_rule import TriggerRule
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
+DAG_ID = "dataproc_batch_deferrable"
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT")
+REGION = "europe-west1"
+BATCH_ID = f"test-def-batch-id-{ENV_ID}"
+BATCH_CONFIG = {
+ "spark_batch": {
+ "jar_file_uris":
["file:///usr/lib/spark/examples/jars/spark-examples.jar"],
+ "main_class": "org.apache.spark.examples.SparkPi",
+ },
+}
+
+
+with models.DAG(
+ DAG_ID,
+ schedule="@once",
+ start_date=datetime(2021, 1, 1),
+ catchup=False,
+ tags=["example", "dataproc"],
+) as dag:
+ # [START how_to_cloud_dataproc_create_batch_operator_async]
+ create_batch = DataprocCreateBatchOperator(
+ task_id="create_batch",
+ project_id=PROJECT_ID,
+ region=REGION,
+ batch=BATCH_CONFIG,
+ batch_id=BATCH_ID,
+ deferrable=True,
+ )
+ # [END how_to_cloud_dataproc_create_batch_operator_async]
+
+ get_batch = DataprocGetBatchOperator(
+ task_id="get_batch", project_id=PROJECT_ID, region=REGION,
batch_id=BATCH_ID
+ )
+
+ delete_batch = DataprocDeleteBatchOperator(
+ task_id="delete_batch",
+ project_id=PROJECT_ID,
+ region=REGION,
+ batch_id=BATCH_ID,
+ )
+ delete_batch.trigger_rule = TriggerRule.ALL_DONE
+
+ (create_batch >> get_batch >> delete_batch)
+
+ from tests.system.utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when "teardown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see:
tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)