This is an automated email from the ASF dual-hosted git repository.

shahar pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 511dc0014de Fix the google cloud provider RayHook protobuf compatibility (#54014)
511dc0014de is described below

commit 511dc0014dea9153fb8480ade91ac234003d761c
Author: olegkachur-e <[email protected]>
AuthorDate: Fri Aug 1 11:21:29 2025 +0000

    Fix the google cloud provider RayHook protobuf compatibility (#54014)
    
    - The previously used `google._upb._message.ScalarMapContainer` is not
      available in protobuf >= 5.*, which breaks RayHook and operators.
    
    Co-authored-by: Oleg Kachur <[email protected]>
---
 .../providers/google/cloud/hooks/vertex_ai/ray.py  | 15 ++----
 .../google/cloud/operators/vertex_ai/ray.py        |  2 +-
 .../unit/google/cloud/hooks/vertex_ai/test_ray.py  | 53 ++++++++++++++++++++--
 3 files changed, 54 insertions(+), 16 deletions(-)

diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/ray.py b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/ray.py
index 5aede4a0465..76d94299d49 100644
--- a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/ray.py
+++ b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/ray.py
@@ -20,19 +20,10 @@
 from __future__ import annotations
 
 import dataclasses
+from collections.abc import MutableMapping
 from typing import Any
 
-from airflow.exceptions import AirflowOptionalProviderFeatureException
-
-try:
-    import vertex_ray
-    from google._upb._message import ScalarMapContainer  # type: ignore[attr-defined]
-except ImportError:
-    # Fallback for environments where the upb module is not available.
-    raise AirflowOptionalProviderFeatureException(
-        "google._upb._message.ScalarMapContainer is not available. "
-        "Please install the ray package to use this feature."
-    )
+import vertex_ray
 from google.cloud import aiplatform
 from google.cloud.aiplatform.vertex_ray.util import resources
 from google.cloud.aiplatform_v1 import (
@@ -59,7 +50,7 @@ class RayHook(GoogleBaseHook):
         def __encode_value(value: Any) -> Any:
             if isinstance(value, (list, Repeated)):
                 return [__encode_value(nested_value) for nested_value in value]
-            if isinstance(value, ScalarMapContainer):
+            if not isinstance(value, dict) and isinstance(value, MutableMapping):
                 return {key: __encode_value(nested_value) for key, nested_value in dict(value).items()}
             if dataclasses.is_dataclass(value):
                 return dataclasses.asdict(value)
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/ray.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/ray.py
index e06c18ea1c4..95a284c1def 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/ray.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/ray.py
@@ -282,7 +282,7 @@ class GetRayClusterOperator(RayBaseOperator):
                 location=self.location,
                 cluster_id=self.cluster_id,
             )
-            self.log.info("Cluster was gotten.")
+            self.log.info("Cluster data has been retrieved.")
             ray_cluster_dict = self.hook.serialize_cluster_obj(ray_cluster)
             return ray_cluster_dict
         except NotFound as not_found_err:
diff --git a/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_ray.py b/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_ray.py
index 34feb72e821..20e5e8829d6 100644
--- a/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_ray.py
+++ b/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_ray.py
@@ -19,9 +19,7 @@ from __future__ import annotations
 
 from unittest import mock
 
-import pytest
-
-ScalarMapContainer = pytest.importorskip("google._upb._message.ScalarMapContainer")
+from google.cloud.aiplatform.vertex_ray.util.resources import Cluster, Resources
 
 from airflow.providers.google.cloud.hooks.vertex_ai.ray import RayHook
 
@@ -168,6 +166,55 @@ class TestRayWithDefaultProjectIdHook:
         mock_aiplatform_init.assert_called_once()
         mock_list_ray_clusters.assert_called_once()
 
+    @mock.patch(RAY_STRING.format("aiplatform.init"))
+    def test_serialize_cluster_obj(self, mock_aiplatform_init) -> None:
+        RESOURCE_SAMPLE = {
+            "accelerator_count": 0,
+            "accelerator_type": None,
+            "autoscaling_spec": None,
+            "boot_disk_size_gb": 100,
+            "boot_disk_type": "pd-ssd",
+            "custom_image": None,
+            "machine_type": "n1-standard-16",
+            "node_count": 1,
+        }
+        SAMPLE_CLUSTER_SERIALIZED = {
+            "cluster_resource_name": TEST_CLUSTER_NAME,
+            "dashboard_address": "dashboard_addr",
+            "head_node_type": RESOURCE_SAMPLE,
+            "labels": {"label1": "val1"},
+            "network": "custom_network",
+            "psc_interface_config": None,
+            "python_version": TEST_PYTHON_VERSION,
+            "ray_logs_enabled": True,
+            "ray_metric_enabled": True,
+            "ray_version": TEST_RAY_VERSION,
+            "reserved_ip_ranges": [
+                "172.16.0.0/16",
+                "10.10.10.0/28",
+            ],
+            "service_account": None,
+            "state": "RUNNING",
+            "worker_node_types": [RESOURCE_SAMPLE, RESOURCE_SAMPLE],
+        }
+        cluster_obj = Cluster(
+            cluster_resource_name=TEST_CLUSTER_NAME,
+            state="RUNNING",  # type: ignore[arg-type]
+            network="custom_network",
+            reserved_ip_ranges=["172.16.0.0/16", "10.10.10.0/28"],
+            python_version=TEST_PYTHON_VERSION,
+            ray_version=TEST_RAY_VERSION,
+            head_node_type=Resources(**RESOURCE_SAMPLE),  # type: ignore[arg-type]
+            worker_node_types=[
+                Resources(**RESOURCE_SAMPLE),  # type: ignore[arg-type]
+                Resources(**RESOURCE_SAMPLE),  # type: ignore[arg-type]
+            ],
+            dashboard_address="dashboard_addr",
+            labels={"label1": "val1"},
+        )
+
+        assert self.hook.serialize_cluster_obj(cluster_obj) == SAMPLE_CLUSTER_SERIALIZED
+
 
 class TestRayWithoutDefaultProjectIdHook:
     def setup_method(self):

Reply via email to