This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a2581cc625b fix dataproc trigger (#53485)
a2581cc625b is described below

commit a2581cc625b06eab851962c796794b459d1cab84
Author: VladaZakharova <[email protected]>
AuthorDate: Mon Jul 21 13:09:26 2025 +0000

    fix dataproc trigger (#53485)
    
    # Conflicts:
    #       providers/google/tests/unit/google/cloud/triggers/test_dataproc.py
---
 .../providers/google/cloud/operators/dataproc.py   |  2 +-
 .../providers/google/cloud/triggers/dataproc.py    | 15 ++--
 .../unit/google/cloud/triggers/test_dataproc.py    | 86 ++++++++++++----------
 3 files changed, 58 insertions(+), 45 deletions(-)

diff --git a/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py b/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
index 349250150a5..647101bf33f 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
@@ -907,7 +907,7 @@ class DataprocCreateClusterOperator(GoogleCloudBaseOperator):
         cluster_state = event["cluster_state"]
         cluster_name = event["cluster_name"]
 
-        if cluster_state == ClusterStatus.State.ERROR:
+        if cluster_state == ClusterStatus.State(ClusterStatus.State.DELETING).name:
             raise AirflowException(f"Cluster is in ERROR state:\n{cluster_name}")
 
         self.log.info("%s completed successfully.", self.task_id)
diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py
index 956e80efd00..c81fbaef922 100644
--- a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py
+++ b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py
@@ -316,8 +316,8 @@ class DataprocClusterTrigger(DataprocBaseTrigger):
                     yield TriggerEvent(
                         {
                             "cluster_name": self.cluster_name,
-                            "cluster_state": ClusterStatus.State.DELETING,
-                            "cluster": cluster,
+                            "cluster_state": ClusterStatus.State(ClusterStatus.State.DELETING).name,
+                            "cluster": Cluster.to_dict(cluster),
                         }
                     )
                     return
@@ -325,14 +325,15 @@ class DataprocClusterTrigger(DataprocBaseTrigger):
                     yield TriggerEvent(
                         {
                             "cluster_name": self.cluster_name,
-                            "cluster_state": state,
-                            "cluster": cluster,
+                            "cluster_state": ClusterStatus.State(state).name,
+                            "cluster": Cluster.to_dict(cluster),
                         }
                     )
                     return
-                self.log.info("Current state is %s", state)
-                self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
-                await asyncio.sleep(self.polling_interval_seconds)
+                else:
+                    self.log.info("Current state is %s", state)
+                    self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
+                    await asyncio.sleep(self.polling_interval_seconds)
         except asyncio.CancelledError:
             try:
                 if self.delete_on_error and await self.safe_to_cancel():
diff --git a/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py b/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py
index 9d459c0fd69..690055903bb 100644
--- a/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py
+++ b/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py
@@ -18,6 +18,7 @@ from __future__ import annotations
 
 import asyncio
 import contextlib
+import logging
 from asyncio import CancelledError, Future, sleep
 from unittest import mock
 
@@ -50,6 +51,14 @@ TEST_POLL_INTERVAL = 5
 TEST_GCP_CONN_ID = "google_cloud_default"
 TEST_OPERATION_NAME = "name"
 TEST_JOB_ID = "test-job-id"
+TEST_RUNNING_CLUSTER = Cluster(
+    cluster_name=TEST_CLUSTER_NAME,
+    status=ClusterStatus(state=ClusterStatus.State.RUNNING),
+)
+TEST_ERROR_CLUSTER = Cluster(
+    cluster_name=TEST_CLUSTER_NAME,
+    status=ClusterStatus(state=ClusterStatus.State.ERROR),
+)
 
 
 @pytest.fixture
@@ -158,28 +167,56 @@ class TestDataprocClusterTrigger:
     @pytest.mark.db_test
     @pytest.mark.asyncio
     @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_async_hook")
-    @mock.patch.object(DataprocClusterTrigger, "log")
+    async def test_async_cluster_triggers_on_success_should_execute_successfully(
+        self, mock_get_async_hook, cluster_trigger
+    ):
+        future = asyncio.Future()
+        future.set_result(TEST_RUNNING_CLUSTER)
+        mock_get_async_hook.return_value.get_cluster.return_value = future
+
+        generator = cluster_trigger.run()
+        actual_event = await generator.asend(None)
+
+        expected_event = TriggerEvent(
+            {
+                "cluster_name": TEST_CLUSTER_NAME,
+                "cluster_state": ClusterStatus.State(ClusterStatus.State.RUNNING).name,
+                "cluster": actual_event.payload["cluster"],
+            }
+        )
+        assert expected_event == actual_event
+
+    @pytest.mark.db_test
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.fetch_cluster")
+    @mock.patch(
+        "airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.delete_cluster",
+        return_value=asyncio.Future(),
+    )
+    @mock.patch("google.auth.default")
     async def test_async_cluster_trigger_run_returns_error_event(
-        self, mock_log, mock_get_async_hook, cluster_trigger
+        self, mock_auth, mock_delete_cluster, mock_fetch_cluster, cluster_trigger, async_get_cluster, caplog
     ):
-        # Mock delete_cluster to return a Future
-        mock_delete_future = asyncio.Future()
-        mock_delete_future.set_result(None)
-        mock_get_async_hook.return_value.delete_cluster.return_value = mock_delete_future
+        mock_credentials = mock.MagicMock()
+        mock_credentials.universe_domain = "googleapis.com"
 
-        mock_cluster = mock.MagicMock()
-        mock_cluster.status = ClusterStatus(state=ClusterStatus.State.ERROR)
+        mock_auth.return_value = (mock_credentials, "project-id")
 
-        future = asyncio.Future()
-        future.set_result(mock_cluster)
-        mock_get_async_hook.return_value.get_cluster.return_value = future
+        mock_delete_cluster.return_value = asyncio.Future()
+        mock_delete_cluster.return_value.set_result(None)
+
+        mock_fetch_cluster.return_value = TEST_ERROR_CLUSTER
+
+        caplog.set_level(logging.INFO)
 
         trigger_event = None
         async for event in cluster_trigger.run():
             trigger_event = event
 
         assert trigger_event.payload["cluster_name"] == TEST_CLUSTER_NAME
-        assert trigger_event.payload["cluster_state"] == ClusterStatus.State.DELETING
+        assert (
+            trigger_event.payload["cluster_state"] == ClusterStatus.State(ClusterStatus.State.DELETING).name
+        )
 
     @pytest.mark.db_test
     @pytest.mark.asyncio
@@ -321,31 +358,6 @@ class TestDataprocClusterTrigger:
         assert mock_delete_cluster.call_count == 0
         mock_delete_cluster.assert_not_called()
 
-    @pytest.mark.db_test
-    @pytest.mark.asyncio
-    @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_async_hook")
-    async def test_async_cluster_triggers_on_success_should_execute_successfully(
-        self, mock_get_async_hook, cluster_trigger
-    ):
-        mock_cluster = mock.MagicMock()
-        mock_cluster.status = ClusterStatus(state=ClusterStatus.State.RUNNING)
-
-        future = asyncio.Future()
-        future.set_result(mock_cluster)
-        mock_get_async_hook.return_value.get_cluster.return_value = future
-
-        generator = cluster_trigger.run()
-        actual_event = await generator.asend(None)
-
-        expected_event = TriggerEvent(
-            {
-                "cluster_name": TEST_CLUSTER_NAME,
-                "cluster_state": ClusterStatus.State.RUNNING,
-                "cluster": actual_event.payload["cluster"],
-            }
-        )
-        assert expected_event == actual_event
-
 
 class TestDataprocBatchTrigger:
     def test_async_create_batch_trigger_serialization_should_execute_successfully(self, batch_trigger):

Reply via email to