This is an automated email from the ASF dual-hosted git repository.
damccorm pushed a commit to branch release-2.71
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/release-2.71 by this push:
new 0c4d81bc49e Cherry Pick RateLimiter SDK changes to Beam 2.71 release (#37306)
0c4d81bc49e is described below
commit 0c4d81bc49e5244ef324baead1f6a9dd430e099f
Author: Tarun Annapareddy <[email protected]>
AuthorDate: Wed Jan 14 22:00:47 2026 +0530
Cherry Pick RateLimiter SDK changes to Beam 2.71 release (#37306)
* Support for RateLimiter in Beam Remote Model Handler (#37218)
* Support for EnvoyRateLimiter in Apache Beam
* fix format issues
* fix test formatting
* Fix test and syntax
* fix lint
* Add dependency based on python version
* revert setup to separate PR
* fix lint
* fix formatting
* resolve comments
* Support Ratelimiter through RemoteModelHandler
* fix lint
* fix lint
* fix comments
* Add custom RateLimited Exception
* fix doc
* fix test
* fix lint
* update RateLimiter execution function name (#37287)
* Catch breaking import error (#37295)
* Catch Import Error
* import order
---
.../examples/inference/rate_limiter_vertex_ai.py | 85 ++++++++++++++++++++++
.../apache_beam/examples/rate_limiter_simple.py | 2 +-
.../apache_beam/io/components/rate_limiter.py | 44 +++++++++--
.../apache_beam/io/components/rate_limiter_test.py | 24 +++---
sdks/python/apache_beam/ml/inference/base.py | 30 +++++++-
sdks/python/apache_beam/ml/inference/base_test.py | 61 ++++++++++++++++
6 files changed, 225 insertions(+), 21 deletions(-)
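
For context, a minimal sketch of how the new rate_limiter hook is wired into a
remote handler, modeled on the FakeRateLimiter and ConcreteRemoteModelHandler
added to base_test.py in this patch. The NoopRateLimiter and MyRemoteHandler
names and the 'example' namespace are illustrative only, not part of the commit.

    from apache_beam.io.components.rate_limiter import RateLimiter
    from apache_beam.ml.inference.base import RemoteModelHandler

    class NoopRateLimiter(RateLimiter):
        """Illustrative limiter that records each attempt and always allows it."""
        def __init__(self):
            super().__init__(namespace='example')

        def allow(self, **kwargs) -> bool:
            # Counter provided by the RateLimiter base class.
            self.requests_counter.inc()
            return True

    class MyRemoteHandler(RemoteModelHandler):
        def create_client(self):
            ...  # build the remote client here

        def request(self, batch, model, inference_args=None):
            ...  # call the remote service for the whole batch

    # run_inference() calls allow(hits_added=len(batch)) once per batch and
    # raises RateLimitExceeded if the limiter returns False.
    handler = MyRemoteHandler(namespace='example', rate_limiter=NoopRateLimiter())
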
diff --git a/sdks/python/apache_beam/examples/inference/rate_limiter_vertex_ai.py b/sdks/python/apache_beam/examples/inference/rate_limiter_vertex_ai.py
new file mode 100644
index 00000000000..11ec02fbd54
--- /dev/null
+++ b/sdks/python/apache_beam/examples/inference/rate_limiter_vertex_ai.py
@@ -0,0 +1,85 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""A simple example demonstrating usage of the EnvoyRateLimiter with Vertex AI.
+"""
+
+import argparse
+import logging
+
+import apache_beam as beam
+from apache_beam.io.components.rate_limiter import EnvoyRateLimiter
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.inference.vertex_ai_inference import VertexAIModelHandlerJSON
+from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options import SetupOptions
+
+
+def run(argv=None):
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--project',
+ dest='project',
+ help='The Google Cloud project ID for Vertex AI.')
+ parser.add_argument(
+ '--location',
+ dest='location',
+ help='The Google Cloud location (e.g. us-central1) for Vertex AI.')
+ parser.add_argument(
+ '--endpoint_id',
+ dest='endpoint_id',
+ help='The ID of the Vertex AI endpoint.')
+ parser.add_argument(
+ '--rls_address',
+ dest='rls_address',
+ help='The address of the Envoy Rate Limit Service (e.g. localhost:8081).')
+
+ known_args, pipeline_args = parser.parse_known_args(argv)
+ pipeline_options = PipelineOptions(pipeline_args)
+ pipeline_options.view_as(SetupOptions).save_main_session = True
+
+ # Initialize the EnvoyRateLimiter
+ rate_limiter = EnvoyRateLimiter(
+ service_address=known_args.rls_address,
+ domain="mongo_cps",
+ descriptors=[{
+ "database": "users"
+ }],
+ namespace='example_pipeline')
+
+ # Initialize the VertexAIModelHandler with the rate limiter
+ model_handler = VertexAIModelHandlerJSON(
+ endpoint_id=known_args.endpoint_id,
+ project=known_args.project,
+ location=known_args.location,
+ rate_limiter=rate_limiter)
+
+ # Input features for the model
+ features = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0],
+ [10.0, 11.0, 12.0], [13.0, 14.0, 15.0]]
+
+ with beam.Pipeline(options=pipeline_options) as p:
+ _ = (
+ p
+ | 'CreateInputs' >> beam.Create(features)
+ | 'RunInference' >> RunInference(model_handler)
+ | 'PrintPredictions' >> beam.Map(logging.info))
+
+
+if __name__ == '__main__':
+ logging.getLogger().setLevel(logging.INFO)
+ run()
diff --git a/sdks/python/apache_beam/examples/rate_limiter_simple.py b/sdks/python/apache_beam/examples/rate_limiter_simple.py
index ea469006f2b..8cdf1166aad 100644
--- a/sdks/python/apache_beam/examples/rate_limiter_simple.py
+++ b/sdks/python/apache_beam/examples/rate_limiter_simple.py
@@ -53,7 +53,7 @@ class SampleApiDoFn(beam.DoFn):
self.rate_limiter = self._shared.acquire(init_limiter)
def process(self, element):
- self.rate_limiter.throttle()
+ self.rate_limiter.allow()
# Process the element mock API call
logging.info("Processing element: %s", element)
diff --git a/sdks/python/apache_beam/io/components/rate_limiter.py b/sdks/python/apache_beam/io/components/rate_limiter.py
index 3de39ddd935..2dc8a5340fd 100644
--- a/sdks/python/apache_beam/io/components/rate_limiter.py
+++ b/sdks/python/apache_beam/io/components/rate_limiter.py
@@ -61,8 +61,13 @@ class RateLimiter(abc.ABC):
self.rpc_latency = Metrics.distribution(namespace, 'RatelimitRpcLatencyMs')
@abc.abstractmethod
- def throttle(self, **kwargs) -> bool:
- """Check if request should be throttled.
+ def allow(self, **kwargs) -> bool:
+ """Applies rate limiting to the request.
+
+ This method checks if the request is permitted by the rate limiting policy.
+ Depending on the implementation and configuration, it may block (sleep)
+ until the request is allowed, or return False if the retry limit is
+ exceeded.
Args:
**kwargs: Keyword arguments specific to the RateLimiter implementation.
@@ -78,8 +83,12 @@ class RateLimiter(abc.ABC):
class EnvoyRateLimiter(RateLimiter):
- """
- Rate limiter implementation that uses an external Envoy Rate Limit Service.
+ """Rate limiter implementation that uses an external Envoy Rate Limit
Service.
+
+ This limiter connects to a gRPC Envoy Rate Limit Service (RLS) to determine
+ whether a request should be allowed. It supports defining a domain and a
+ list of descriptors that correspond to the rate limit configuration in the
+ RLS.
"""
def __init__(
self,
@@ -89,7 +98,7 @@ class EnvoyRateLimiter(RateLimiter):
timeout: float = 5.0,
block_until_allowed: bool = True,
retries: int = 3,
- namespace: str = ""):
+ namespace: str = ''):
"""
Args:
service_address: Address of the Envoy RLS (e.g., 'localhost:8081').
@@ -139,8 +148,16 @@ class EnvoyRateLimiter(RateLimiter):
channel = grpc.insecure_channel(self.service_address)
self._stub = EnvoyRateLimiter.RateLimitServiceStub(channel)
- def throttle(self, hits_added: int = 1) -> bool:
- """Calls the Envoy RLS to check for rate limits.
+ def allow(self, hits_added: int = 1) -> bool:
+ """Calls the Envoy RLS to apply rate limits.
+
+ Sends a rate limit request to the configured Envoy Rate Limit Service.
+ If 'block_until_allowed' is True, this method will sleep and retry
+ if the limit is exceeded, effectively blocking until the request is
+ permitted.
+
+ If 'block_until_allowed' is False, it will return False after the retry
+ limit is exceeded.
Args:
hits_added: Number of hits to add to the rate limit.
@@ -224,3 +241,16 @@ class EnvoyRateLimiter(RateLimiter):
response.overall_code)
break
return throttled
+
+ def __getstate__(self):
+ state = self.__dict__.copy()
+ if '_lock' in state:
+ del state['_lock']
+ if '_stub' in state:
+ del state['_stub']
+ return state
+
+ def __setstate__(self, state):
+ self.__dict__.update(state)
+ self._lock = threading.Lock()
+ self._stub = None
diff --git a/sdks/python/apache_beam/io/components/rate_limiter_test.py b/sdks/python/apache_beam/io/components/rate_limiter_test.py
index 7c3e7b82aad..24d30a1c5c9 100644
--- a/sdks/python/apache_beam/io/components/rate_limiter_test.py
+++ b/sdks/python/apache_beam/io/components/rate_limiter_test.py
@@ -42,7 +42,7 @@ class EnvoyRateLimiterTest(unittest.TestCase):
namespace='test_namespace')
@mock.patch('grpc.insecure_channel')
- def test_throttle_allowed(self, mock_channel):
+ def test_allow_success(self, mock_channel):
# Mock successful OK response
mock_stub = mock.Mock()
mock_response = RateLimitResponse(overall_code=RateLimitResponseCode.OK)
@@ -51,13 +51,13 @@ class EnvoyRateLimiterTest(unittest.TestCase):
# Inject mock stub
self.limiter._stub = mock_stub
- throttled = self.limiter.throttle()
+ allowed = self.limiter.allow()
- self.assertTrue(throttled)
+ self.assertTrue(allowed)
mock_stub.ShouldRateLimit.assert_called_once()
@mock.patch('grpc.insecure_channel')
- def test_throttle_over_limit_retries_exceeded(self, mock_channel):
+ def test_allow_over_limit_retries_exceeded(self, mock_channel):
# Mock OVER_LIMIT response
mock_stub = mock.Mock()
mock_response = RateLimitResponse(
@@ -69,9 +69,9 @@ class EnvoyRateLimiterTest(unittest.TestCase):
# We mock time.sleep to run fast
with mock.patch('time.sleep'):
- throttled = self.limiter.throttle()
+ allowed = self.limiter.allow()
- self.assertFalse(throttled)
+ self.assertFalse(allowed)
# Should be called 1 (initial) + 2 (retries) + 1 (last check > retries
# logic depends on loop)
# Logic: attempt starts at 0.
@@ -83,7 +83,7 @@ class EnvoyRateLimiterTest(unittest.TestCase):
self.assertEqual(mock_stub.ShouldRateLimit.call_count, 3)
@mock.patch('grpc.insecure_channel')
- def test_throttle_rpc_error_retry(self, mock_channel):
+ def test_allow_rpc_error_retry(self, mock_channel):
# Mock RpcError then Success
mock_stub = mock.Mock()
mock_response = RateLimitResponse(overall_code=RateLimitResponseCode.OK)
@@ -95,13 +95,13 @@ class EnvoyRateLimiterTest(unittest.TestCase):
self.limiter._stub = mock_stub
with mock.patch('time.sleep'):
- throttled = self.limiter.throttle()
+ allowed = self.limiter.allow()
- self.assertTrue(throttled)
+ self.assertTrue(allowed)
self.assertEqual(mock_stub.ShouldRateLimit.call_count, 3)
@mock.patch('grpc.insecure_channel')
- def test_throttle_rpc_error_fail(self, mock_channel):
+ def test_allow_rpc_error_fail(self, mock_channel):
# Mock Persistent RpcError
mock_stub = mock.Mock()
error = grpc.RpcError()
@@ -111,7 +111,7 @@ class EnvoyRateLimiterTest(unittest.TestCase):
with mock.patch('time.sleep'):
with self.assertRaises(grpc.RpcError):
- self.limiter.throttle()
+ self.limiter.allow()
# The inner loop tries 5 times for connection errors
self.assertEqual(mock_stub.ShouldRateLimit.call_count, 5)
@@ -134,7 +134,7 @@ class EnvoyRateLimiterTest(unittest.TestCase):
self.limiter.retries = 0 # Single attempt
with mock.patch('time.sleep') as mock_sleep:
- self.limiter.throttle()
+ self.limiter.allow()
# Should sleep for 5 seconds (jitter is 0.0)
mock_sleep.assert_called_with(5.0)
diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index d79565ee24d..e0f870669f7 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -60,6 +60,11 @@ from apache_beam.utils import multi_process_shared
from apache_beam.utils import retry
from apache_beam.utils import shared
+try:
+ from apache_beam.io.components.rate_limiter import RateLimiter
+except ImportError:
+ RateLimiter = None
+
try:
# pylint: disable=wrong-import-order, wrong-import-position
import resource
@@ -102,6 +107,11 @@ PredictionResult.inference.__doc__ = """Results for the inference on the model
PredictionResult.model_id.__doc__ = """Model ID used to run the prediction."""
+class RateLimitExceeded(RuntimeError):
+ """RateLimit Exceeded to process a batch of requests."""
+ pass
+
+
class ModelMetadata(NamedTuple):
model_id: str
model_name: str
@@ -349,7 +359,8 @@ class RemoteModelHandler(ABC, ModelHandler[ExampleT, PredictionT, ModelT]):
*,
window_ms: int = 1 * _MILLISECOND_TO_SECOND,
bucket_ms: int = 1 * _MILLISECOND_TO_SECOND,
- overload_ratio: float = 2):
+ overload_ratio: float = 2,
+ rate_limiter: Optional[RateLimiter] = None):
"""Initializes a ReactiveThrottler class for enabling
client-side throttling for remote calls to an inference service. Also wraps
provided calls to the service with retry logic.
@@ -372,6 +383,7 @@ class RemoteModelHandler(ABC, ModelHandler[ExampleT, PredictionT, ModelT]):
overload_ratio: the target ratio between requests sent and successful
requests. This is "K" in the formula in
https://landing.google.com/sre/book/chapters/handling-overload.html.
+ rate_limiter: A RateLimiter object for setting a global rate limit.
"""
# Configure ReactiveThrottler for client-side throttling behavior.
self.throttler = ReactiveThrottler(
@@ -383,6 +395,9 @@ class RemoteModelHandler(ABC, ModelHandler[ExampleT, PredictionT, ModelT]):
self.logger = logging.getLogger(namespace)
self.num_retries = num_retries
self.retry_filter = retry_filter
+ self._rate_limiter = rate_limiter
+ self._shared_rate_limiter = None
+ self._shared_handle = shared.Shared()
def __init_subclass__(cls):
if cls.load_model is not RemoteModelHandler.load_model:
@@ -431,6 +446,19 @@ class RemoteModelHandler(ABC, ModelHandler[ExampleT, PredictionT, ModelT]):
Returns:
An Iterable of Predictions.
"""
+ if self._rate_limiter:
+ if self._shared_rate_limiter is None:
+
+ def init_limiter():
+ return self._rate_limiter
+
+ self._shared_rate_limiter = self._shared_handle.acquire(init_limiter)
+
+ if not self._shared_rate_limiter.allow(hits_added=len(batch)):
+ raise RateLimitExceeded(
+ "Rate Limit Exceeded, "
+ "Could not process this batch.")
+
self.throttler.throttle()
try:
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index 574e71de89c..381bf545660 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -2071,6 +2071,67 @@ class RunInferenceRemoteTest(unittest.TestCase):
responses.append(model.predict(example))
return responses
+ def test_run_inference_with_rate_limiter(self):
+ class FakeRateLimiter(base.RateLimiter):
+ def __init__(self):
+ super().__init__(namespace='test_namespace')
+
+ def allow(self, hits_added=1):
+ self.requests_counter.inc()
+ return True
+
+ limiter = FakeRateLimiter()
+
+ with TestPipeline() as pipeline:
+ examples = [1, 5]
+
+ class ConcreteRemoteModelHandler(base.RemoteModelHandler):
+ def create_client(self):
+ return FakeModel()
+
+ def request(self, batch, model, inference_args=None):
+ return [model.predict(example) for example in batch]
+
+ model_handler = ConcreteRemoteModelHandler(
+ rate_limiter=limiter, namespace='test_namespace')
+
+ pcoll = pipeline | 'start' >> beam.Create(examples)
+ actual = pcoll | base.RunInference(model_handler)
+
+ expected = [2, 6]
+ assert_that(actual, equal_to(expected))
+
+ result = pipeline.run()
+ result.wait_until_finish()
+
+ metrics_filter = MetricsFilter().with_name(
+ 'RatelimitRequestsTotal').with_namespace('test_namespace')
+ metrics = result.metrics().query(metrics_filter)
+ self.assertGreaterEqual(metrics['counters'][0].committed, 0)
+
+ def test_run_inference_with_rate_limiter_exceeded(self):
+ class FakeRateLimiter(base.RateLimiter):
+ def __init__(self):
+ super().__init__(namespace='test_namespace')
+
+ def allow(self, hits_added=1):
+ return False
+
+ class ConcreteRemoteModelHandler(base.RemoteModelHandler):
+ def create_client(self):
+ return FakeModel()
+
+ def request(self, batch, model, inference_args=None):
+ return [model.predict(example) for example in batch]
+
+ model_handler = ConcreteRemoteModelHandler(
+ rate_limiter=FakeRateLimiter(),
+ namespace='test_namespace',
+ num_retries=0)
+
+ with self.assertRaises(base.RateLimitExceeded):
+ model_handler.run_inference([1], FakeModel())
+
if __name__ == '__main__':
unittest.main()
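
A companion sketch, not part of the patch, of the non-blocking path: with
block_until_allowed=False the EnvoyRateLimiter's allow() returns False once its
retries are exhausted, and run_inference surfaces that as RateLimitExceeded. The
service address, domain, and descriptors below are placeholders and must match
the Rate Limit Service configuration.

    from apache_beam.io.components.rate_limiter import EnvoyRateLimiter
    from apache_beam.ml.inference.base import RateLimitExceeded

    # Non-blocking configuration: allow() returns False once retries are
    # exhausted instead of sleeping until the RLS grants the request.
    limiter = EnvoyRateLimiter(
        service_address='localhost:8081',     # placeholder RLS address
        domain='example_domain',              # must match the RLS configuration
        descriptors=[{'database': 'users'}],
        block_until_allowed=False,
        retries=3,
        namespace='example_pipeline')

    # Any RemoteModelHandler constructed with rate_limiter=limiter raises
    # RateLimitExceeded from run_inference() when a batch is denied:
    #
    #   try:
    #       results = list(handler.run_inference(batch, model))
    #   except RateLimitExceeded:
    #       ...  # back off or re-enqueue the batch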