This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a2581cc625b fix dataproc trigger (#53485)
a2581cc625b is described below
commit a2581cc625b06eab851962c796794b459d1cab84
Author: VladaZakharova <[email protected]>
AuthorDate: Mon Jul 21 13:09:26 2025 +0000
fix dataproc trigger (#53485)
    # Conflicts:
    #	providers/google/tests/unit/google/cloud/triggers/test_dataproc.py
---
.../providers/google/cloud/operators/dataproc.py | 2 +-
.../providers/google/cloud/triggers/dataproc.py | 15 ++--
.../unit/google/cloud/triggers/test_dataproc.py | 86 ++++++++++++----------
3 files changed, 58 insertions(+), 45 deletions(-)
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py b/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
index 349250150a5..647101bf33f 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
@@ -907,7 +907,7 @@ class DataprocCreateClusterOperator(GoogleCloudBaseOperator):
             cluster_state = event["cluster_state"]
             cluster_name = event["cluster_name"]
-            if cluster_state == ClusterStatus.State.ERROR:
+            if cluster_state == ClusterStatus.State(ClusterStatus.State.DELETING).name:
                 raise AirflowException(f"Cluster is in ERROR state:\n{cluster_name}")
         self.log.info("%s completed successfully.", self.task_id)
self.log.info("%s completed successfully.", self.task_id)
diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py
index 956e80efd00..c81fbaef922 100644
--- a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py
+++ b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py
@@ -316,8 +316,8 @@ class DataprocClusterTrigger(DataprocBaseTrigger):
                         yield TriggerEvent(
                             {
                                 "cluster_name": self.cluster_name,
-                                "cluster_state": ClusterStatus.State.DELETING,
-                                "cluster": cluster,
+                                "cluster_state": ClusterStatus.State(ClusterStatus.State.DELETING).name,
+                                "cluster": Cluster.to_dict(cluster),
                             }
                         )
                         return
@@ -325,14 +325,15 @@ class DataprocClusterTrigger(DataprocBaseTrigger):
                     yield TriggerEvent(
                         {
                             "cluster_name": self.cluster_name,
-                            "cluster_state": state,
-                            "cluster": cluster,
+                            "cluster_state": ClusterStatus.State(state).name,
+                            "cluster": Cluster.to_dict(cluster),
                         }
                     )
                     return
-                self.log.info("Current state is %s", state)
-                self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
-                await asyncio.sleep(self.polling_interval_seconds)
+                else:
+                    self.log.info("Current state is %s", state)
+                    self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
+                    await asyncio.sleep(self.polling_interval_seconds)
except asyncio.CancelledError:
try:
if self.delete_on_error and await self.safe_to_cancel():
diff --git a/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py b/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py
index 9d459c0fd69..690055903bb 100644
--- a/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py
+++ b/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py
@@ -18,6 +18,7 @@ from __future__ import annotations
import asyncio
import contextlib
+import logging
from asyncio import CancelledError, Future, sleep
from unittest import mock
@@ -50,6 +51,14 @@ TEST_POLL_INTERVAL = 5
TEST_GCP_CONN_ID = "google_cloud_default"
TEST_OPERATION_NAME = "name"
TEST_JOB_ID = "test-job-id"
+TEST_RUNNING_CLUSTER = Cluster(
+ cluster_name=TEST_CLUSTER_NAME,
+ status=ClusterStatus(state=ClusterStatus.State.RUNNING),
+)
+TEST_ERROR_CLUSTER = Cluster(
+ cluster_name=TEST_CLUSTER_NAME,
+ status=ClusterStatus(state=ClusterStatus.State.ERROR),
+)
@pytest.fixture
@@ -158,28 +167,56 @@ class TestDataprocClusterTrigger:
     @pytest.mark.db_test
     @pytest.mark.asyncio
     @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_async_hook")
-    @mock.patch.object(DataprocClusterTrigger, "log")
+    async def test_async_cluster_triggers_on_success_should_execute_successfully(
+        self, mock_get_async_hook, cluster_trigger
+    ):
+        future = asyncio.Future()
+        future.set_result(TEST_RUNNING_CLUSTER)
+        mock_get_async_hook.return_value.get_cluster.return_value = future
+
+        generator = cluster_trigger.run()
+        actual_event = await generator.asend(None)
+
+        expected_event = TriggerEvent(
+            {
+                "cluster_name": TEST_CLUSTER_NAME,
+                "cluster_state": ClusterStatus.State(ClusterStatus.State.RUNNING).name,
+                "cluster": actual_event.payload["cluster"],
+            }
+        )
+        assert expected_event == actual_event
+
+    @pytest.mark.db_test
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.fetch_cluster")
+    @mock.patch(
+        "airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.delete_cluster",
+        return_value=asyncio.Future(),
+    )
+    @mock.patch("google.auth.default")
     async def test_async_cluster_trigger_run_returns_error_event(
-        self, mock_log, mock_get_async_hook, cluster_trigger
+        self, mock_auth, mock_delete_cluster, mock_fetch_cluster, cluster_trigger, async_get_cluster, caplog
     ):
-        # Mock delete_cluster to return a Future
-        mock_delete_future = asyncio.Future()
-        mock_delete_future.set_result(None)
-        mock_get_async_hook.return_value.delete_cluster.return_value = mock_delete_future
+        mock_credentials = mock.MagicMock()
+        mock_credentials.universe_domain = "googleapis.com"
-        mock_cluster = mock.MagicMock()
-        mock_cluster.status = ClusterStatus(state=ClusterStatus.State.ERROR)
+        mock_auth.return_value = (mock_credentials, "project-id")
-        future = asyncio.Future()
-        future.set_result(mock_cluster)
-        mock_get_async_hook.return_value.get_cluster.return_value = future
+        mock_delete_cluster.return_value = asyncio.Future()
+        mock_delete_cluster.return_value.set_result(None)
+
+        mock_fetch_cluster.return_value = TEST_ERROR_CLUSTER
+
+        caplog.set_level(logging.INFO)
         trigger_event = None
         async for event in cluster_trigger.run():
             trigger_event = event
         assert trigger_event.payload["cluster_name"] == TEST_CLUSTER_NAME
-        assert trigger_event.payload["cluster_state"] == ClusterStatus.State.DELETING
+        assert (
+            trigger_event.payload["cluster_state"] == ClusterStatus.State(ClusterStatus.State.DELETING).name
+        )
@pytest.mark.db_test
@pytest.mark.asyncio
@@ -321,31 +358,6 @@ class TestDataprocClusterTrigger:
assert mock_delete_cluster.call_count == 0
mock_delete_cluster.assert_not_called()
-    @pytest.mark.db_test
-    @pytest.mark.asyncio
-    @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_async_hook")
-    async def test_async_cluster_triggers_on_success_should_execute_successfully(
- self, mock_get_async_hook, cluster_trigger
- ):
- mock_cluster = mock.MagicMock()
- mock_cluster.status = ClusterStatus(state=ClusterStatus.State.RUNNING)
-
- future = asyncio.Future()
- future.set_result(mock_cluster)
- mock_get_async_hook.return_value.get_cluster.return_value = future
-
- generator = cluster_trigger.run()
- actual_event = await generator.asend(None)
-
- expected_event = TriggerEvent(
- {
- "cluster_name": TEST_CLUSTER_NAME,
- "cluster_state": ClusterStatus.State.RUNNING,
- "cluster": actual_event.payload["cluster"],
- }
- )
- assert expected_event == actual_event
-
class TestDataprocBatchTrigger:
     def test_async_create_batch_trigger_serialization_should_execute_successfully(self, batch_trigger):