This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new ab14c432dad Support for EnvoyRateLimiter in Beam Python SDK (#37135)
ab14c432dad is described below

commit ab14c432dada70f835156e78ef9582d8c0a1e906
Author: Tarun Annapareddy <[email protected]>
AuthorDate: Mon Jan 5 18:51:54 2026 +0530

    Support for EnvoyRateLimiter in Beam Python SDK (#37135)
    
    * Support for EnvoyRateLimiter in Apache Beam
    
    * fix format issues
    
    * fix test formatting
    
    * Fix test and syntax
    
    * fix lint
    
    * Add dependency based on python version
    
    * revert setup to separate pr
    
    * fix lint
    
    * fix formatting
    
    * resolve comments
---
 .../apache_beam/examples/rate_limiter_simple.py    |  93 +++++++++
 .../apache_beam/io/components/rate_limiter.py      | 226 +++++++++++++++++++++
 .../apache_beam/io/components/rate_limiter_test.py | 143 +++++++++++++
 3 files changed, 462 insertions(+)

diff --git a/sdks/python/apache_beam/examples/rate_limiter_simple.py b/sdks/python/apache_beam/examples/rate_limiter_simple.py
new file mode 100644
index 00000000000..ea469006f2b
--- /dev/null
+++ b/sdks/python/apache_beam/examples/rate_limiter_simple.py
@@ -0,0 +1,93 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""A simple example demonstrating usage of the EnvoyRateLimiter in a Beam
+pipeline.
+"""
+
+import argparse
+import logging
+import time
+
+import apache_beam as beam
+from apache_beam.io.components.rate_limiter import EnvoyRateLimiter
+from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.utils import shared
+
+
+class SampleApiDoFn(beam.DoFn):
+  """DoFn that throttles a simulated external API call through Envoy RLS.
+
+  Args:
+    rls_address: Address of the Envoy Rate Limit Service (host:port).
+    domain: Rate limit domain handed to EnvoyRateLimiter.
+    descriptors: Descriptor dicts (key -> value) handed to EnvoyRateLimiter.
+  """
+  def __init__(self, rls_address, domain, descriptors):
+    self.rls_address = rls_address
+    self.domain = domain
+    self.descriptors = descriptors
+    # Created at construction time so the handle is serialized with the DoFn;
+    # setup() acquires the limiter through it.
+    self._shared = shared.Shared()
+    # Populated in setup().
+    self.rate_limiter = None
+
+  def _build_rate_limiter(self):
+    """Constructor handed to Shared.acquire() to create the limiter."""
+    logging.info("Connecting to Envoy RLS at %s", self.rls_address)
+    return EnvoyRateLimiter(
+        service_address=self.rls_address,
+        domain=self.domain,
+        descriptors=self.descriptors,
+        namespace='example_pipeline')
+
+  def setup(self):
+    # Going through shared.Shared() means only one RateLimiter is created per
+    # worker and it is reused by every thread there.
+    self.rate_limiter = self._shared.acquire(self._build_rate_limiter)
+
+  def process(self, element):
+    # With EnvoyRateLimiter's default block_until_allowed=True this waits
+    # until the RLS admits the request.
+    self.rate_limiter.throttle()
+
+    # Stand-in for the real external API call.
+    logging.info("Processing element: %s", element)
+    time.sleep(0.1)
+    yield element
+
+
+def parse_known_args(argv):
+  """Separates this example's own flags from the pipeline flags.
+
+  Returns:
+    The (known_args, pipeline_args) pair produced by
+    argparse.ArgumentParser.parse_known_args.
+  """
+  arg_parser = argparse.ArgumentParser()
+  arg_parser.add_argument(
+      '--rls_address',
+      help='Address of the Envoy Rate Limit Service',
+      default='localhost:8081')
+  known_and_rest = arg_parser.parse_known_args(argv)
+  return known_and_rest
+
+
+def run(argv=None):
+  """Runs a pipeline that pushes 100 integers through the rate-limited DoFn.
+
+  Args:
+    argv: Command-line arguments; unknown flags become pipeline options.
+  """
+  known_args, pipeline_args = parse_known_args(argv)
+  options = PipelineOptions(pipeline_args)
+  limited_fn = SampleApiDoFn(
+      rls_address=known_args.rls_address,
+      domain="mongo_cps",
+      descriptors=[{
+          "database": "users"
+      }])
+
+  with beam.Pipeline(options=options) as pipeline:
+    elements = pipeline | 'Create' >> beam.Create(range(100))
+    _ = elements | 'RateLimit' >> beam.ParDo(limited_fn)
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  run()
diff --git a/sdks/python/apache_beam/io/components/rate_limiter.py b/sdks/python/apache_beam/io/components/rate_limiter.py
new file mode 100644
index 00000000000..3de39ddd935
--- /dev/null
+++ b/sdks/python/apache_beam/io/components/rate_limiter.py
@@ -0,0 +1,226 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Rate Limiter classes for controlling access to external resources.
+"""
+
+import abc
+import logging
+import math
+import random
+import threading
+import time
+from typing import Dict
+from typing import List
+
+import grpc
+from envoy_data_plane.envoy.extensions.common.ratelimit.v3 import 
RateLimitDescriptor
+from envoy_data_plane.envoy.extensions.common.ratelimit.v3 import 
RateLimitDescriptorEntry
+from envoy_data_plane.envoy.service.ratelimit.v3 import RateLimitRequest
+from envoy_data_plane.envoy.service.ratelimit.v3 import RateLimitResponse
+from envoy_data_plane.envoy.service.ratelimit.v3 import RateLimitResponseCode
+
+from apache_beam.io.components import adaptive_throttler
+from apache_beam.metrics import Metrics
+
+_LOGGER = logging.getLogger(__name__)
+
+# Number of attempts for a single ShouldRateLimit RPC when the call itself
+# fails with grpc.RpcError; the last failure is re-raised. This is separate
+# from EnvoyRateLimiter's `retries`, which counts OVER_LIMIT answers.
+_RPC_MAX_RETRIES = 5
+# Fixed delay, in seconds, between those failed RPC attempts.
+_RPC_RETRY_DELAY_SECONDS = 10
+
+
+class RateLimiter(abc.ABC):
+  """Abstract base class for RateLimiters.
+
+  The base class only creates the metrics shared by all implementations;
+  subclasses implement throttle().
+  """
+  def __init__(self, namespace: str = ""):
+    """
+    Args:
+      namespace: Namespace for the metrics below and for the throttling
+        signaler.
+    """
+    # Metrics collected from the RateLimiter
+    # Metric updates are thread safe
+    self.throttling_signaler = adaptive_throttler.ThrottlingSignaler(
+        namespace=namespace)
+    self.requests_counter = Metrics.counter(
+        namespace, 'RatelimitRequestsTotal')
+    self.requests_allowed = Metrics.counter(
+        namespace, 'RatelimitRequestsAllowed')
+    self.requests_throttled = Metrics.counter(
+        namespace, 'RatelimitRequestsThrottled')
+    self.rpc_errors = Metrics.counter(namespace, 'RatelimitRpcErrors')
+    self.rpc_retries = Metrics.counter(namespace, 'RatelimitRpcRetries')
+    self.rpc_latency = Metrics.distribution(namespace, 'RatelimitRpcLatencyMs')
+
+  @abc.abstractmethod
+  def throttle(self, **kwargs) -> bool:
+    """Ask for permission to issue a request; may block, per implementation.
+
+    Args:
+      **kwargs: Keyword arguments specific to the RateLimiter implementation.
+
+    Returns:
+      bool: True if the request is allowed, False if it was not allowed
+        (e.g. retries exceeded).
+
+    Raises:
+      Exception: If an underlying infrastructure error occurs (e.g. RPC
+        failure).
+    """
+    pass
+
+
+class EnvoyRateLimiter(RateLimiter):
+  """
+  Rate limiter implementation that uses an external Envoy Rate Limit Service.
+
+  Each throttle() call sends a ShouldRateLimit RPC. It returns True once the
+  service answers OK. With block_until_allowed=False it returns False after
+  1 + retries OVER_LIMIT answers.
+  """
+  # Back-off used after an OVER_LIMIT answer that reports no positive
+  # duration_until_reset (e.g. an empty `statuses` list). Without it the
+  # limiter would re-query the RLS immediately, in a tight loop when
+  # block_until_allowed is True.
+  _MIN_OVER_LIMIT_SLEEP_SECONDS = 1.0
+
+  def __init__(
+      self,
+      service_address: str,
+      domain: str,
+      descriptors: List[Dict[str, str]],
+      timeout: float = 5.0,
+      block_until_allowed: bool = True,
+      retries: int = 3,
+      namespace: str = ""):
+    """
+    Args:
+      service_address: Address of the Envoy RLS (e.g., 'localhost:8081').
+      domain: The rate limit domain.
+      descriptors: List of descriptors; each dict becomes one
+        RateLimitDescriptor whose entries are its key-value pairs.
+      timeout: gRPC timeout in seconds.
+      block_until_allowed: If True, keep waiting and re-asking until the RLS
+        grants the request.
+      retries: Number of additional attempts after an OVER_LIMIT answer.
+        Only respected if block_until_allowed is False.
+      namespace: the namespace to use for metrics and for signaling that
+        throttling is occurring.
+    """
+    super().__init__(namespace=namespace)
+
+    self.service_address = service_address
+    self.domain = domain
+    self.descriptors = descriptors
+    self.retries = retries
+    self.timeout = timeout
+    self.block_until_allowed = block_until_allowed
+    # gRPC stub, created lazily by init_connection().
+    self._stub = None
+    self._lock = threading.Lock()
+
+  class RateLimitServiceStub(object):
+    """
+    Wrapper for gRPC stub to be compatible with envoy_data_plane messages.
+
+    The envoy-data-plane package uses 'betterproto' which generates async stubs
+    for 'grpclib'. As Beam uses standard synchronous 'grpcio',
+    RateLimitServiceStub is a bridge class to use the betterproto Message types
+    (RateLimitRequest) with a standard grpcio Channel.
+    """
+    def __init__(self, channel):
+      self.ShouldRateLimit = channel.unary_unary(
+          '/envoy.service.ratelimit.v3.RateLimitService/ShouldRateLimit',
+          request_serializer=RateLimitRequest.SerializeToString,
+          response_deserializer=RateLimitResponse.FromString,
+      )
+
+  def init_connection(self):
+    """Lazily creates the gRPC channel and stub on first use."""
+    if self._stub is None:
+      # Acquire lock to safeguard against multiple DoFn threads sharing the
+      # same RateLimiter instance, which is the case when using Shared().
+      with self._lock:
+        # Re-check: another thread may have created the stub while this one
+        # was waiting for the lock.
+        if self._stub is None:
+          channel = grpc.insecure_channel(self.service_address)
+          self._stub = EnvoyRateLimiter.RateLimitServiceStub(channel)
+
+  def _build_request(self, hits_added: int) -> RateLimitRequest:
+    """Converts self.descriptors into a RateLimitRequest proto."""
+    proto_descriptors = [
+        RateLimitDescriptor(
+            entries=[
+                RateLimitDescriptorEntry(key=k, value=v) for k, v in d.items()
+            ]) for d in self.descriptors
+    ]
+    return RateLimitRequest(
+        domain=self.domain,
+        descriptors=proto_descriptors,
+        hits_addend=hits_added)
+
+  def _should_rate_limit(self, request: RateLimitRequest) -> RateLimitResponse:
+    """Sends one ShouldRateLimit RPC, retrying on grpc.RpcError.
+
+    Raises:
+      grpc.RpcError: If all _RPC_MAX_RETRIES attempts fail.
+    """
+    for retry_attempt in range(_RPC_MAX_RETRIES):
+      try:
+        start_time = time.monotonic()
+        response = self._stub.ShouldRateLimit(request, timeout=self.timeout)
+      except grpc.RpcError as e:
+        if retry_attempt == _RPC_MAX_RETRIES - 1:
+          _LOGGER.error(
+              "[EnvoyRateLimiter] ratelimit service call failed: %s", e)
+          self.rpc_errors.inc()
+          raise
+        self.rpc_retries.inc()
+        _LOGGER.warning(
+            "[EnvoyRateLimiter] ratelimit service call failed, retrying: %s",
+            e)
+        time.sleep(_RPC_RETRY_DELAY_SECONDS)
+      else:
+        # monotonic() is immune to wall-clock adjustments during the call.
+        self.rpc_latency.update(int((time.monotonic() - start_time) * 1000))
+        return response
+    # Only reachable if _RPC_MAX_RETRIES < 1.
+    raise RuntimeError('_RPC_MAX_RETRIES must be at least 1')
+
+  def _over_limit_sleep_seconds(self, response: RateLimitResponse) -> float:
+    """Returns how long to back off after an OVER_LIMIT response.
+
+    Multiple rules can be set in the RLS config, so the longest
+    duration_until_reset among the OVER_LIMIT statuses is used. If none is
+    positive, _MIN_OVER_LIMIT_SLEEP_SECONDS is used instead. Up to 1%
+    additive jitter is added.
+    """
+    sleep_s = 0.0
+    for status in response.statuses or []:
+      if status.code != RateLimitResponseCode.OVER_LIMIT:
+        continue
+      # duration_until_reset is converted to timedelta by betterproto
+      dur = status.duration_until_reset
+      if dur is not None:
+        sleep_s = max(sleep_s, dur.total_seconds())
+
+    if sleep_s <= 0:
+      sleep_s = self._MIN_OVER_LIMIT_SLEEP_SECONDS
+
+    # Add 1% additive jitter to prevent thundering herd
+    return sleep_s + random.uniform(0, 0.01 * sleep_s)
+
+  def throttle(self, hits_added: int = 1) -> bool:
+    """Calls the Envoy RLS to check for rate limits.
+
+    On OVER_LIMIT, sleeps for the reported reset duration and asks again.
+    With block_until_allowed it keeps asking until allowed. Otherwise it
+    gives up after 1 + self.retries OVER_LIMIT answers.
+
+    Args:
+      hits_added: Number of hits to add to the rate limit.
+
+    Returns:
+      bool: True if the request is allowed. False if retries were exceeded or
+        the RLS returned an unexpected code.
+
+    Raises:
+      grpc.RpcError: If the RLS call still fails after _RPC_MAX_RETRIES
+        attempts.
+    """
+    self.init_connection()
+    request = self._build_request(hits_added)
+
+    self.requests_counter.inc()
+    attempt = 0
+    while self.block_until_allowed or attempt <= self.retries:
+      response = self._should_rate_limit(request)
+
+      if response.overall_code == RateLimitResponseCode.OK:
+        self.requests_allowed.inc()
+        return True
+
+      if response.overall_code != RateLimitResponseCode.OVER_LIMIT:
+        _LOGGER.error(
+            "[EnvoyRateLimiter] Unknown code from RLS: %s",
+            response.overall_code)
+        return False
+
+      # Ratelimit exceeded: sleep until the reset and ask again.
+      self.requests_throttled.inc()
+      sleep_s = self._over_limit_sleep_seconds(response)
+      _LOGGER.warning("[EnvoyRateLimiter] Throttled for %s seconds", sleep_s)
+      # signal throttled time to backend
+      self.throttling_signaler.signal_throttled(math.ceil(sleep_s))
+      time.sleep(sleep_s)
+      attempt += 1
+    return False
diff --git a/sdks/python/apache_beam/io/components/rate_limiter_test.py b/sdks/python/apache_beam/io/components/rate_limiter_test.py
new file mode 100644
index 00000000000..7c3e7b82aad
--- /dev/null
+++ b/sdks/python/apache_beam/io/components/rate_limiter_test.py
@@ -0,0 +1,143 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from datetime import timedelta
+from unittest import mock
+
+import grpc
+from envoy_data_plane.envoy.service.ratelimit.v3 import RateLimitResponse
+from envoy_data_plane.envoy.service.ratelimit.v3 import RateLimitResponseCode
+from envoy_data_plane.envoy.service.ratelimit.v3 import 
RateLimitResponseDescriptorStatus
+
+from apache_beam.io.components import rate_limiter
+
+
+class EnvoyRateLimiterTest(unittest.TestCase):
+  """Unit tests for EnvoyRateLimiter with the gRPC stub replaced by a mock."""
+  def setUp(self):
+    self.service_address = 'localhost:8081'
+    self.domain = 'test_domain'
+    self.descriptors = [{'key': 'value'}]
+    # Non-blocking limiter: gives up after 1 + retries OVER_LIMIT answers.
+    self.limiter = rate_limiter.EnvoyRateLimiter(
+        self.service_address,
+        self.domain,
+        self.descriptors,
+        timeout=0.1,  # Fast timeout for tests
+        block_until_allowed=False,
+        retries=2,
+        namespace='test_namespace')
+
+  @mock.patch('grpc.insecure_channel')
+  def test_throttle_allowed(self, mock_channel):
+    # An OK response is allowed on the first call.
+    mock_stub = mock.Mock()
+    mock_stub.ShouldRateLimit.return_value = RateLimitResponse(
+        overall_code=RateLimitResponseCode.OK)
+
+    # Inject mock stub so init_connection() does not open a channel.
+    self.limiter._stub = mock_stub
+
+    allowed = self.limiter.throttle()
+
+    self.assertTrue(allowed)
+    mock_stub.ShouldRateLimit.assert_called_once()
+
+  @mock.patch('grpc.insecure_channel')
+  def test_throttle_over_limit_retries_exceeded(self, mock_channel):
+    # Every call answers OVER_LIMIT.
+    mock_stub = mock.Mock()
+    mock_stub.ShouldRateLimit.return_value = RateLimitResponse(
+        overall_code=RateLimitResponseCode.OVER_LIMIT)
+    self.limiter._stub = mock_stub
+
+    # time.sleep is mocked so the back-off between attempts is instant.
+    with mock.patch('time.sleep') as mock_sleep:
+      allowed = self.limiter.throttle()
+
+    # block_until_allowed is False, so the limiter gives up and reports False.
+    self.assertFalse(allowed)
+    # retries=2: one initial attempt plus two retries. Each OVER_LIMIT answer
+    # is followed by a back-off sleep before the retry budget is checked.
+    self.assertEqual(mock_stub.ShouldRateLimit.call_count, 3)
+    self.assertEqual(mock_sleep.call_count, 3)
+
+  @mock.patch('grpc.insecure_channel')
+  def test_throttle_rpc_error_retry(self, mock_channel):
+    # Two transient RpcErrors followed by an OK response.
+    mock_stub = mock.Mock()
+    error = grpc.RpcError()
+    mock_stub.ShouldRateLimit.side_effect = [
+        error, error, RateLimitResponse(overall_code=RateLimitResponseCode.OK)
+    ]
+    self.limiter._stub = mock_stub
+
+    with mock.patch('time.sleep'):
+      allowed = self.limiter.throttle()
+
+    # RPC failures are retried by the inner loop without consuming the
+    # rate-limit retry budget.
+    self.assertTrue(allowed)
+    self.assertEqual(mock_stub.ShouldRateLimit.call_count, 3)
+
+  @mock.patch('grpc.insecure_channel')
+  def test_throttle_rpc_error_fail(self, mock_channel):
+    # Persistent RpcError.
+    mock_stub = mock.Mock()
+    mock_stub.ShouldRateLimit.side_effect = grpc.RpcError()
+    self.limiter._stub = mock_stub
+
+    with mock.patch('time.sleep'):
+      with self.assertRaises(grpc.RpcError):
+        self.limiter.throttle()
+
+    # The inner RPC loop makes _RPC_MAX_RETRIES attempts, then re-raises.
+    self.assertEqual(
+        mock_stub.ShouldRateLimit.call_count, rate_limiter._RPC_MAX_RETRIES)
+
+  @mock.patch('grpc.insecure_channel')
+  @mock.patch('random.uniform', return_value=0.0)
+  def test_extract_duration_from_response(self, mock_random, mock_channel):
+    # OVER_LIMIT with a reported reset duration of 5 seconds.
+    mock_stub = mock.Mock()
+    status = RateLimitResponseDescriptorStatus(
+        code=RateLimitResponseCode.OVER_LIMIT,
+        duration_until_reset=timedelta(seconds=5))
+    mock_stub.ShouldRateLimit.return_value = RateLimitResponse(
+        overall_code=RateLimitResponseCode.OVER_LIMIT, statuses=[status])
+    self.limiter._stub = mock_stub
+    self.limiter.retries = 0  # Single attempt
+
+    with mock.patch('time.sleep') as mock_sleep:
+      allowed = self.limiter.throttle()
+
+    self.assertFalse(allowed)
+    # Exactly one back-off of the reported 5 seconds (jitter patched to 0.0).
+    mock_sleep.assert_called_once_with(5.0)
+
+
+if __name__ == '__main__':
+  unittest.main()

Reply via email to