gemini-code-assist[bot] commented on code in PR #37573:
URL: https://github.com/apache/beam/pull/37573#discussion_r2801184550


##########
sdks/java/io/components/src/main/java/org/apache/beam/sdk/io/components/ratelimiter/EnvoyRateLimiterFactory.java:
##########
@@ -0,0 +1,238 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.components.ratelimiter;
+
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
+
+import io.envoyproxy.envoy.extensions.common.ratelimit.v3.RateLimitDescriptor;
+import io.envoyproxy.envoy.service.ratelimit.v3.RateLimitRequest;
+import io.envoyproxy.envoy.service.ratelimit.v3.RateLimitResponse;
+import io.envoyproxy.envoy.service.ratelimit.v3.RateLimitServiceGrpc;
+import io.grpc.StatusRuntimeException;
+import java.io.IOException;
+import java.util.Map;
+import javax.annotation.Nullable;
+import org.apache.beam.sdk.io.components.throttling.ThrottlingSignaler;
+import org.apache.beam.sdk.metrics.Counter;
+import org.apache.beam.sdk.metrics.Distribution;
+import org.apache.beam.sdk.metrics.Metrics;
+import org.apache.beam.sdk.util.Sleeper;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** A {@link RateLimiterFactory} for the Envoy Rate Limit Service. */
+public class EnvoyRateLimiterFactory implements RateLimiterFactory {
+  private static final Logger LOG = LoggerFactory.getLogger(EnvoyRateLimiterFactory.class);
+  private static final int RPC_RETRY_COUNT = 3;
+  private static final long RPC_RETRY_DELAY_MILLIS = 5000;
+
+  private final RateLimiterOptions options;
+
+  private transient volatile @Nullable RateLimitServiceGrpc.RateLimitServiceBlockingStub stub;
+  private transient @Nullable RateLimiterClientCache clientCache;
+  private final ThrottlingSignaler throttlingSignaler;
+  private final Sleeper sleeper;
+
+  private final Counter requestsTotal;
+  private final Counter requestsAllowed;
+  private final Counter requestsThrottled;
+  private final Counter rpcErrors;
+  private final Counter rpcRetries;
+  private final Distribution rpcLatency;
+
+  public EnvoyRateLimiterFactory(RateLimiterOptions options) {
+    this(options, Sleeper.DEFAULT);
+  }
+
+  @VisibleForTesting
+  EnvoyRateLimiterFactory(RateLimiterOptions options, Sleeper sleeper) {
+    this.options = options;
+    this.sleeper = sleeper;
+    String namespace = EnvoyRateLimiterFactory.class.getName();
+    this.throttlingSignaler = new ThrottlingSignaler(namespace);
+    this.requestsTotal = Metrics.counter(namespace, "ratelimit-requests-total");
+    this.requestsAllowed = Metrics.counter(namespace, "ratelimit-requests-allowed");
+    this.requestsThrottled = Metrics.counter(namespace, "ratelimit-requests-throttled");
+    this.rpcErrors = Metrics.counter(namespace, "ratelimit-rpc-errors");
+    this.rpcRetries = Metrics.counter(namespace, "ratelimit-rpc-retries");
+    this.rpcLatency = Metrics.distribution(namespace, "ratelimit-rpc-latency-ms");
+  }
+
+  @Override
+  public synchronized void close() {
+    if (clientCache != null) {
+      clientCache.release();
+      clientCache = null;
+    }
+    stub = null;
+  }
+
+  private void init() {
+    if (stub != null) {
+      return;
+    }
+    synchronized (this) {
+      if (stub == null) {
+        RateLimiterClientCache cache = RateLimiterClientCache.getOrCreate(options.getAddress());
+        this.clientCache = cache;
+        stub = RateLimitServiceGrpc.newBlockingStub(cache.getChannel());
+      }
+    }
+  }
+
+  @VisibleForTesting
+  void setStub(RateLimitServiceGrpc.RateLimitServiceBlockingStub stub) {
+    this.stub = stub;
+  }
+
+  @Override
+  public RateLimiter getLimiter(RateLimiterContext context) {
+    if (!(context instanceof EnvoyRateLimiterContext)) {
+      throw new IllegalArgumentException(
+          "EnvoyRateLimiterFactory requires EnvoyRateLimiterContext");
+    }
+    return new EnvoyRateLimiter(this, (EnvoyRateLimiterContext) context);
+  }
+
+  @Override
+  public boolean allow(RateLimiterContext context, int permits)
+      throws IOException, InterruptedException {
+    if (permits == 0) {
+      return true;
+    }
+    if (!(context instanceof EnvoyRateLimiterContext)) {
+      throw new IllegalArgumentException(
+          "EnvoyRateLimiterFactory requires EnvoyRateLimiterContext, got: "
+              + context.getClass().getName());
+    }
+    checkArgument(permits >= 0, "Permits must be non-negative");
+    EnvoyRateLimiterContext envoyContext = (EnvoyRateLimiterContext) context;
+    return fetchTokens(envoyContext, permits);
+  }
+
+  private boolean fetchTokens(EnvoyRateLimiterContext context, int tokens)
+      throws IOException, InterruptedException {
+
+    init();
+    RateLimitServiceGrpc.RateLimitServiceBlockingStub currentStub = stub;
+    if (currentStub == null) {
+      throw new IllegalStateException("RateLimitServiceStub is null");
+    }
+
+    Map<String, String> descriptors = context.getDescriptors();
+    RateLimitDescriptor.Builder descriptorBuilder = RateLimitDescriptor.newBuilder();
+
+    for (Map.Entry<String, String> entry : descriptors.entrySet()) {
+      descriptorBuilder.addEntries(
+          RateLimitDescriptor.Entry.newBuilder()
+              .setKey(entry.getKey())
+              .setValue(entry.getValue())
+              .build());
+    }
+
+    RateLimitRequest request =
+        RateLimitRequest.newBuilder()
+            .setDomain(context.getDomain())
+            .setHitsAddend(tokens)
+            .addDescriptors(descriptorBuilder.build())
+            .build();
+
+    Integer maxRetries = options.getMaxRetries();
+    long timeoutMillis = options.getTimeout().toMillis();
+
+    requestsTotal.inc();
+    int attempt = 0;
+    while (true) {
+      if (maxRetries != null && attempt > maxRetries) {
+        return false;
+      }
+
+      // RPC Retry Loop
+      RateLimitResponse response = null;
+      long startTime = System.currentTimeMillis();
+      for (int i = 0; i < RPC_RETRY_COUNT; i++) {
+        try {
+          response =
+              currentStub
+                  .withDeadlineAfter(timeoutMillis, java.util.concurrent.TimeUnit.MILLISECONDS)
+                  .shouldRateLimit(request);
+          long endTime = System.currentTimeMillis();
+          rpcLatency.update(endTime - startTime);
+          break;
+        } catch (StatusRuntimeException e) {
+          rpcErrors.inc();
+          if (i == RPC_RETRY_COUNT - 1) {
+            LOG.error("RateLimitService call failed after {} attempts", RPC_RETRY_COUNT, e);
+            throw new IOException("Failed to call Rate Limit Service", e);
+          }
+          rpcRetries.inc();
+          LOG.warn("RateLimitService call failed, retrying", e);
+          if (sleeper != null) {
+            sleeper.sleep(RPC_RETRY_DELAY_MILLIS);
+          }
+        }
+      }
+
+      if (response == null) {
+        throw new IOException("Failed to get response from Rate Limit 
Service");
+      }
+
+      if (response.getOverallCode() == RateLimitResponse.Code.OK) {
+        requestsAllowed.inc();
+        return true;
+      } else if (response.getOverallCode() == RateLimitResponse.Code.OVER_LIMIT) {
+        long sleepMillis = 0;
+        for (RateLimitResponse.DescriptorStatus status : response.getStatusesList()) {
+          if (status.getCode() == RateLimitResponse.Code.OVER_LIMIT
+              && status.hasDurationUntilReset()) {
+            long durationMillis =
+                status.getDurationUntilReset().getSeconds() * 1000
+                    + status.getDurationUntilReset().getNanos() / 1_000_000;
+            if (durationMillis > sleepMillis) {
+              sleepMillis = durationMillis;
+            }
+          }
+        }
+
+        if (sleepMillis == 0) {
+          sleepMillis = 1000;
+        }
+
+        long jitter =
+            (long)
+                (java.util.concurrent.ThreadLocalRandom.current().nextDouble()
+                    * (0.01 * sleepMillis));

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The current jitter calculation adds a random value between 0 and 1% of the sleep duration. This is a very small amount of jitter and may not be sufficient to prevent a "thundering herd" problem, where many clients retry simultaneously after being throttled. It's common practice to use a larger jitter, for example by adding a random value up to 10% or 20% of the backoff time.
   
   ```suggestion
           long jitter =
               (long)
                   (java.util.concurrent.ThreadLocalRandom.current().nextDouble()
                       * (0.1 * sleepMillis));
   ```
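
   An alternative worth considering is "full jitter", where the sleep is drawn uniformly from the entire backoff window instead of perturbing it by a small percentage; this decorrelates retrying clients more aggressively. A minimal sketch against the variables already in scope here (`sleepMillis` is positive at this point, since the preceding `if (sleepMillis == 0)` guard substitutes a default):

   ```java
   // Full jitter: pick the sleep uniformly from [0, sleepMillis].
   // ThreadLocalRandom.nextLong(bound) treats the bound as exclusive, hence +1.
   long sleepWithFullJitter =
       java.util.concurrent.ThreadLocalRandom.current().nextLong(sleepMillis + 1);
   ```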



##########
examples/java/src/main/java/org/apache/beam/examples/RateLimiterSimple.java:
##########
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.examples;
+
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.io.components.ratelimiter.EnvoyRateLimiterContext;
+import org.apache.beam.sdk.io.components.ratelimiter.EnvoyRateLimiterFactory;
+import org.apache.beam.sdk.io.components.ratelimiter.RateLimiter;
+import org.apache.beam.sdk.io.components.ratelimiter.RateLimiterContext;
+import org.apache.beam.sdk.io.components.ratelimiter.RateLimiterFactory;
+import org.apache.beam.sdk.io.components.ratelimiter.RateLimiterOptions;
+import org.apache.beam.sdk.options.Description;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A simple example demonstrating how to use the {@link RateLimiter} in a custom {@link DoFn}.
+ *
+ * <p>This pipeline creates a small set of elements and processes them using a DoFn that calls an
+ * external service (simulated). The processing is rate-limited using an Envoy Rate Limit Service.
+ *
+ * <p>To run this example, you need a running Envoy Rate Limit Service.
+ */
+public class RateLimiterSimple {
+
+  public interface Options extends PipelineOptions {
+    @Description("Address of the Envoy Rate Limit Service(eg:localhost:8081)")
+    String getRateLimiterAddress();
+
+    void setRateLimiterAddress(String value);
+
+    @Description("Domain for the Rate Limit Service(eg:mydomain)")
+    String getRateLimiterDomain();
+
+    void setRateLimiterDomain(String value);
+  }
+
+  static class CallExternalServiceFn extends DoFn<String, String> {
+    private final String rlsAddress;
+    private final String rlsDomain;
+    private transient @Nullable RateLimiter rateLimiter;
+    private static final Logger LOG = LoggerFactory.getLogger(CallExternalServiceFn.class);
+
+    public CallExternalServiceFn(String rlsAddress, String rlsDomain) {
+      this.rlsAddress = rlsAddress;
+      this.rlsDomain = rlsDomain;
+    }
+
+    @Setup
+    public void setup() {
+      // Create the RateLimiterOptions.
+      RateLimiterOptions options = RateLimiterOptions.builder().setAddress(rlsAddress).build();
+
+      // Static RateLimiter with a pre-configured domain and descriptors.
+      RateLimiterFactory factory = new EnvoyRateLimiterFactory(options);
+      RateLimiterContext context =
+          EnvoyRateLimiterContext.builder()
+              .setDomain(rlsDomain)
+              .addDescriptor("database", "users")
+              .build();
+      this.rateLimiter = factory.getLimiter(context);
+    }
+
+    @Teardown
+    public void teardown() {
+      if (rateLimiter != null) {
+        try {
+          rateLimiter.close();
+        } catch (Exception e) {
+          throw new RuntimeException("Failed to close RateLimiter", e);
+        }

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   In a `@Teardown` method, it's generally better to avoid throwing a `RuntimeException` for cleanup failures. A failure here could cause the entire worker to be torn down and restarted, which might be an overly aggressive response to a failed `close()` on a rate limiter. Consider logging the error instead to ensure that teardown is as robust as possible.
   
   ```suggestion
           try {
             rateLimiter.close();
           } catch (Exception e) {
             LOG.warn("Failed to close RateLimiter", e);
           }
   ```
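
   As a follow-on, if the DoFn later holds more than one resource, the same log-don't-throw pattern keeps each cleanup isolated so a failure in one cannot skip the others. A hedged sketch, where the extra `channel` field is hypothetical and purely for illustration:

   ```java
   @Teardown
   public void teardown() {
     // Close each resource independently; a failure in one must not prevent
     // the others from being released, nor crash the worker.
     if (rateLimiter != null) {
       try {
         rateLimiter.close();
       } catch (Exception e) {
         LOG.warn("Failed to close RateLimiter", e);
       }
     }
     if (channel != null) { // hypothetical second resource, for illustration
       try {
         channel.shutdownNow();
       } catch (RuntimeException e) {
         LOG.warn("Failed to shut down channel", e);
       }
     }
   }
   ```

   Note also that `@Teardown` is best-effort in Beam; runners do not guarantee it runs (for example, on worker crash), which is another reason not to let it throw.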



##########
sdks/java/io/components/src/main/java/org/apache/beam/sdk/io/components/ratelimiter/RateLimiterClientCache.java:
##########
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.components.ratelimiter;
+
+import io.grpc.ManagedChannel;
+import io.grpc.ManagedChannelBuilder;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.TimeUnit;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A static cache for {@link ManagedChannel}s to Rate Limit Service.
+ *
+ * <p>This class ensures that multiple DoFn instances (threads) in the same worker sharing the
+ * same RLS address will share a single {@link ManagedChannel}.
+ *
+ * <p>It uses reference counting to close the channel when it is no longer in use by any
+ * RateLimiter instance.
+ */
+public class RateLimiterClientCache {
+  private static final Logger LOG = LoggerFactory.getLogger(RateLimiterClientCache.class);
+  private static final Map<String, RateLimiterClientCache> CACHE = new ConcurrentHashMap<>();
+
+  private final ManagedChannel channel;
+  private final String address;
+  private int refCount = 0;
+
+  private RateLimiterClientCache(String address) {
+    this.address = address;
+    LOG.info("Creating new ManagedChannel for RLS at {}", address);
+    this.channel = ManagedChannelBuilder.forTarget(address).usePlaintext().build();
+  }
+
+  /**
+   * Gets or creates a cached client for the given address. Increments the reference count.
+   * Synchronized on the class to prevent race conditions when multiple instances call
+   * getOrCreate() simultaneously.
+   */
+  public static synchronized RateLimiterClientCache getOrCreate(String address) {
+    RateLimiterClientCache client = CACHE.get(address);
+    if (client == null) {
+      client = new RateLimiterClientCache(address);
+      CACHE.put(address, client);
+    }
+    client.refCount++;
+    LOG.debug("Referenced RLS Channel for {}. New RefCount: {}", address, 
client.refCount);
+    return client;
+  }
+
+  public ManagedChannel getChannel() {
+    return channel;
+  }
+
+  /**
+   * Releases the client. Decrements the reference count. If the reference count reaches 0, the
+   * channel is shut down and removed from the cache. Synchronized on the class to prevent race
+   * conditions when multiple threads call release() simultaneously, and between getOrCreate()
+   * and release() calls.
+   */
+  public void release() {
+    synchronized (RateLimiterClientCache.class) {
+      refCount--;
+      LOG.debug("Released RLS Channel for {}. New RefCount: {}", address, 
refCount);
+      if (refCount <= 0) {
+        LOG.info("Closing ManagedChannel for RLS at {}", address);
+        CACHE.remove(address);
+        channel.shutdown();
+        try {
+          channel.awaitTermination(10, TimeUnit.SECONDS);
+        } catch (InterruptedException e) {
+          LOG.error("Couldn't gracefully close gRPC channel={}", channel, e);
+        }

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   When catching an `InterruptedException`, it's good practice to restore the thread's interrupted status by calling `Thread.currentThread().interrupt()`. This allows upstream code to be aware that an interruption occurred and handle it appropriately.
   
   ```suggestion
           try {
             channel.awaitTermination(10, TimeUnit.SECONDS);
           } catch (InterruptedException e) {
             LOG.error("Couldn't gracefully close gRPC channel={}", channel, e);
             Thread.currentThread().interrupt();
           }
   ```
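
   Relatedly, `awaitTermination` returns a boolean indicating whether the channel actually terminated within the deadline; checking it and falling back to `shutdownNow()` makes the close path more deterministic. A sketch of that variant:

   ```java
   channel.shutdown();
   try {
     if (!channel.awaitTermination(10, TimeUnit.SECONDS)) {
       // Grace period elapsed without termination; force-close remaining calls.
       channel.shutdownNow();
     }
   } catch (InterruptedException e) {
     LOG.error("Couldn't gracefully close gRPC channel={}", channel, e);
     channel.shutdownNow();
     Thread.currentThread().interrupt(); // restore the interrupt flag
   }
   ```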



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
