This is an automated email from the ASF dual-hosted git repository.
shahar pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 511dc0014de Fix the google cloud provider RayHook protobuf compatibility (#54014)
511dc0014de is described below
commit 511dc0014dea9153fb8480ade91ac234003d761c
Author: olegkachur-e <[email protected]>
AuthorDate: Fri Aug 1 11:21:29 2025 +0000
Fix the google cloud provider RayHook protobuf compatibility (#54014)
- The previously used `google._upb._message.ScalarMapContainer` is not
available in protobuf >= 5.*, which breaks RayHook and operators.
Co-authored-by: Oleg Kachur <[email protected]>
---
.../providers/google/cloud/hooks/vertex_ai/ray.py | 15 ++----
.../google/cloud/operators/vertex_ai/ray.py | 2 +-
.../unit/google/cloud/hooks/vertex_ai/test_ray.py | 53 ++++++++++++++++++++--
3 files changed, 54 insertions(+), 16 deletions(-)
diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/ray.py b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/ray.py
index 5aede4a0465..76d94299d49 100644
--- a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/ray.py
+++ b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/ray.py
@@ -20,19 +20,10 @@
from __future__ import annotations
import dataclasses
+from collections.abc import MutableMapping
from typing import Any
-from airflow.exceptions import AirflowOptionalProviderFeatureException
-
-try:
- import vertex_ray
- from google._upb._message import ScalarMapContainer # type: ignore[attr-defined]
-except ImportError:
- # Fallback for environments where the upb module is not available.
- raise AirflowOptionalProviderFeatureException(
- "google._upb._message.ScalarMapContainer is not available. "
- "Please install the ray package to use this feature."
- )
+import vertex_ray
from google.cloud import aiplatform
from google.cloud.aiplatform.vertex_ray.util import resources
from google.cloud.aiplatform_v1 import (
@@ -59,7 +50,7 @@ class RayHook(GoogleBaseHook):
def __encode_value(value: Any) -> Any:
if isinstance(value, (list, Repeated)):
return [__encode_value(nested_value) for nested_value in value]
- if isinstance(value, ScalarMapContainer):
+ if not isinstance(value, dict) and isinstance(value, MutableMapping):
return {key: __encode_value(nested_value) for key, nested_value in dict(value).items()}
if dataclasses.is_dataclass(value):
return dataclasses.asdict(value)
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/ray.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/ray.py
index e06c18ea1c4..95a284c1def 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/ray.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/ray.py
@@ -282,7 +282,7 @@ class GetRayClusterOperator(RayBaseOperator):
location=self.location,
cluster_id=self.cluster_id,
)
- self.log.info("Cluster was gotten.")
+ self.log.info("Cluster data has been retrieved.")
ray_cluster_dict = self.hook.serialize_cluster_obj(ray_cluster)
return ray_cluster_dict
except NotFound as not_found_err:
diff --git a/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_ray.py b/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_ray.py
index 34feb72e821..20e5e8829d6 100644
--- a/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_ray.py
+++ b/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_ray.py
@@ -19,9 +19,7 @@ from __future__ import annotations
from unittest import mock
-import pytest
-
-ScalarMapContainer = pytest.importorskip("google._upb._message.ScalarMapContainer")
+from google.cloud.aiplatform.vertex_ray.util.resources import Cluster, Resources
from airflow.providers.google.cloud.hooks.vertex_ai.ray import RayHook
@@ -168,6 +166,55 @@ class TestRayWithDefaultProjectIdHook:
mock_aiplatform_init.assert_called_once()
mock_list_ray_clusters.assert_called_once()
+ @mock.patch(RAY_STRING.format("aiplatform.init"))
+ def test_serialize_cluster_obj(self, mock_aiplatform_init) -> None:
+ RESOURCE_SAMPLE = {
+ "accelerator_count": 0,
+ "accelerator_type": None,
+ "autoscaling_spec": None,
+ "boot_disk_size_gb": 100,
+ "boot_disk_type": "pd-ssd",
+ "custom_image": None,
+ "machine_type": "n1-standard-16",
+ "node_count": 1,
+ }
+ SAMPLE_CLUSTER_SERIALIZED = {
+ "cluster_resource_name": TEST_CLUSTER_NAME,
+ "dashboard_address": "dashboard_addr",
+ "head_node_type": RESOURCE_SAMPLE,
+ "labels": {"label1": "val1"},
+ "network": "custom_network",
+ "psc_interface_config": None,
+ "python_version": TEST_PYTHON_VERSION,
+ "ray_logs_enabled": True,
+ "ray_metric_enabled": True,
+ "ray_version": TEST_RAY_VERSION,
+ "reserved_ip_ranges": [
+ "172.16.0.0/16",
+ "10.10.10.0/28",
+ ],
+ "service_account": None,
+ "state": "RUNNING",
+ "worker_node_types": [RESOURCE_SAMPLE, RESOURCE_SAMPLE],
+ }
+ cluster_obj = Cluster(
+ cluster_resource_name=TEST_CLUSTER_NAME,
+ state="RUNNING", # type: ignore[arg-type]
+ network="custom_network",
+ reserved_ip_ranges=["172.16.0.0/16", "10.10.10.0/28"],
+ python_version=TEST_PYTHON_VERSION,
+ ray_version=TEST_RAY_VERSION,
+ head_node_type=Resources(**RESOURCE_SAMPLE), # type: ignore[arg-type]
+ worker_node_types=[
+ Resources(**RESOURCE_SAMPLE), # type: ignore[arg-type]
+ Resources(**RESOURCE_SAMPLE), # type: ignore[arg-type]
+ ],
+ dashboard_address="dashboard_addr",
+ labels={"label1": "val1"},
+ )
+
+ assert self.hook.serialize_cluster_obj(cluster_obj) == SAMPLE_CLUSTER_SERIALIZED
+
class TestRayWithoutDefaultProjectIdHook:
def setup_method(self):