gemini-code-assist[bot] commented on code in PR #37573:
URL: https://github.com/apache/beam/pull/37573#discussion_r2801184550
##########
sdks/java/io/components/src/main/java/org/apache/beam/sdk/io/components/ratelimiter/EnvoyRateLimiterFactory.java:
##########
@@ -0,0 +1,238 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.components.ratelimiter;
+
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
+
+import io.envoyproxy.envoy.extensions.common.ratelimit.v3.RateLimitDescriptor;
+import io.envoyproxy.envoy.service.ratelimit.v3.RateLimitRequest;
+import io.envoyproxy.envoy.service.ratelimit.v3.RateLimitResponse;
+import io.envoyproxy.envoy.service.ratelimit.v3.RateLimitServiceGrpc;
+import io.grpc.StatusRuntimeException;
+import java.io.IOException;
+import java.util.Map;
+import javax.annotation.Nullable;
+import org.apache.beam.sdk.io.components.throttling.ThrottlingSignaler;
+import org.apache.beam.sdk.metrics.Counter;
+import org.apache.beam.sdk.metrics.Distribution;
+import org.apache.beam.sdk.metrics.Metrics;
+import org.apache.beam.sdk.util.Sleeper;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** A {@link RateLimiterFactory} for Envoy Rate Limit Service. */
+public class EnvoyRateLimiterFactory implements RateLimiterFactory {
+  private static final Logger LOG = LoggerFactory.getLogger(EnvoyRateLimiterFactory.class);
+  private static final int RPC_RETRY_COUNT = 3;
+  private static final long RPC_RETRY_DELAY_MILLIS = 5000;
+
+  private final RateLimiterOptions options;
+
+  private transient volatile @Nullable RateLimitServiceGrpc.RateLimitServiceBlockingStub stub;
+  private transient @Nullable RateLimiterClientCache clientCache;
+  private final ThrottlingSignaler throttlingSignaler;
+  private final Sleeper sleeper;
+
+  private final Counter requestsTotal;
+  private final Counter requestsAllowed;
+  private final Counter requestsThrottled;
+  private final Counter rpcErrors;
+  private final Counter rpcRetries;
+  private final Distribution rpcLatency;
+
+  public EnvoyRateLimiterFactory(RateLimiterOptions options) {
+    this(options, Sleeper.DEFAULT);
+  }
+
+  @VisibleForTesting
+  EnvoyRateLimiterFactory(RateLimiterOptions options, Sleeper sleeper) {
+    this.options = options;
+    this.sleeper = sleeper;
+    String namespace = EnvoyRateLimiterFactory.class.getName();
+    this.throttlingSignaler = new ThrottlingSignaler(namespace);
+    this.requestsTotal = Metrics.counter(namespace, "ratelimit-requests-total");
+    this.requestsAllowed = Metrics.counter(namespace, "ratelimit-requests-allowed");
+    this.requestsThrottled = Metrics.counter(namespace, "ratelimit-requests-throttled");
+    this.rpcErrors = Metrics.counter(namespace, "ratelimit-rpc-errors");
+    this.rpcRetries = Metrics.counter(namespace, "ratelimit-rpc-retries");
+    this.rpcLatency = Metrics.distribution(namespace, "ratelimit-rpc-latency-ms");
+  }
+
+  @Override
+  public synchronized void close() {
+    if (clientCache != null) {
+      clientCache.release();
+      clientCache = null;
+    }
+    stub = null;
+  }
+
+  private void init() {
+    if (stub != null) {
+      return;
+    }
+    synchronized (this) {
+      if (stub == null) {
+        RateLimiterClientCache cache = RateLimiterClientCache.getOrCreate(options.getAddress());
+        this.clientCache = cache;
+        stub = RateLimitServiceGrpc.newBlockingStub(cache.getChannel());
+      }
+    }
+  }
+
+  @VisibleForTesting
+  void setStub(RateLimitServiceGrpc.RateLimitServiceBlockingStub stub) {
+    this.stub = stub;
+  }
+
+  @Override
+  public RateLimiter getLimiter(RateLimiterContext context) {
+    if (!(context instanceof EnvoyRateLimiterContext)) {
+      throw new IllegalArgumentException(
+          "EnvoyRateLimiterFactory requires EnvoyRateLimiterContext");
+    }
+    return new EnvoyRateLimiter(this, (EnvoyRateLimiterContext) context);
+  }
+
+  @Override
+  public boolean allow(RateLimiterContext context, int permits)
+      throws IOException, InterruptedException {
+    if (permits == 0) {
+      return true;
+    }
+    if (!(context instanceof EnvoyRateLimiterContext)) {
+      throw new IllegalArgumentException(
+          "EnvoyRateLimiterFactory requires EnvoyRateLimiterContext, got: "
+              + context.getClass().getName());
+    }
+    checkArgument(permits >= 0, "Permits must be non-negative");
+    EnvoyRateLimiterContext envoyContext = (EnvoyRateLimiterContext) context;
+    return fetchTokens(envoyContext, permits);
+  }
+
+  private boolean fetchTokens(EnvoyRateLimiterContext context, int tokens)
+      throws IOException, InterruptedException {
+
+    init();
+    RateLimitServiceGrpc.RateLimitServiceBlockingStub currentStub = stub;
+    if (currentStub == null) {
+      throw new IllegalStateException("RateLimitServiceStub is null");
+    }
+
+    Map<String, String> descriptors = context.getDescriptors();
+    RateLimitDescriptor.Builder descriptorBuilder = RateLimitDescriptor.newBuilder();
+
+    for (Map.Entry<String, String> entry : descriptors.entrySet()) {
+      descriptorBuilder.addEntries(
+          RateLimitDescriptor.Entry.newBuilder()
+              .setKey(entry.getKey())
+              .setValue(entry.getValue())
+              .build());
+    }
+
+    RateLimitRequest request =
+        RateLimitRequest.newBuilder()
+            .setDomain(context.getDomain())
+            .setHitsAddend(tokens)
+            .addDescriptors(descriptorBuilder.build())
+            .build();
+
+    Integer maxRetries = options.getMaxRetries();
+    long timeoutMillis = options.getTimeout().toMillis();
+
+    requestsTotal.inc();
+    int attempt = 0;
+    while (true) {
+      if (maxRetries != null && attempt > maxRetries) {
+        return false;
+      }
+
+      // RPC Retry Loop
+      RateLimitResponse response = null;
+      long startTime = System.currentTimeMillis();
+      for (int i = 0; i < RPC_RETRY_COUNT; i++) {
+        try {
+          response =
+              currentStub
+                  .withDeadlineAfter(timeoutMillis, java.util.concurrent.TimeUnit.MILLISECONDS)
+                  .shouldRateLimit(request);
+          long endTime = System.currentTimeMillis();
+          rpcLatency.update(endTime - startTime);
+          break;
+        } catch (StatusRuntimeException e) {
+          rpcErrors.inc();
+          if (i == RPC_RETRY_COUNT - 1) {
+            LOG.error("RateLimitService call failed after {} attempts", RPC_RETRY_COUNT, e);
+            throw new IOException("Failed to call Rate Limit Service", e);
+          }
+          rpcRetries.inc();
+          LOG.warn("RateLimitService call failed, retrying", e);
+          if (sleeper != null) {
+            sleeper.sleep(RPC_RETRY_DELAY_MILLIS);
+          }
+        }
+      }
+
+      if (response == null) {
+        throw new IOException("Failed to get response from Rate Limit Service");
+      }
+
+      if (response.getOverallCode() == RateLimitResponse.Code.OK) {
+        requestsAllowed.inc();
+        return true;
+      } else if (response.getOverallCode() == RateLimitResponse.Code.OVER_LIMIT) {
+        long sleepMillis = 0;
+        for (RateLimitResponse.DescriptorStatus status : response.getStatusesList()) {
+          if (status.getCode() == RateLimitResponse.Code.OVER_LIMIT
+              && status.hasDurationUntilReset()) {
+            long durationMillis =
+                status.getDurationUntilReset().getSeconds() * 1000
+                    + status.getDurationUntilReset().getNanos() / 1_000_000;
+            if (durationMillis > sleepMillis) {
+              sleepMillis = durationMillis;
+            }
+          }
+        }
+
+        if (sleepMillis == 0) {
+          sleepMillis = 1000;
+        }
+
+        long jitter =
+            (long)
+                (java.util.concurrent.ThreadLocalRandom.current().nextDouble()
+                    * (0.01 * sleepMillis));

Review Comment:
The current jitter calculation adds a random value between 0 and 1% of the sleep duration. This is a very small amount of jitter and may not be sufficient to prevent a "thundering herd" problem where many clients retry simultaneously after being throttled. It's common practice to use a larger jitter, for example, by adding a random value up to 10% or 20% of the backoff time.

```suggestion
        long jitter =
            (long)
                (java.util.concurrent.ThreadLocalRandom.current().nextDouble()
                    * (0.1 * sleepMillis));
```
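For illustration only (not code from this PR): a minimal sketch of a jitter helper that makes the jitter fraction explicit. The class and method names (`Jitter`, `withJitter`) are hypothetical; the only real API involved is the JDK's `ThreadLocalRandom`, which the code above already uses. A fraction of 0.1 matches the 10% suggested here, and larger fractions spread retries more widely.

```java
import java.util.concurrent.ThreadLocalRandom;

final class Jitter {
  private Jitter() {}

  // Returns the base delay plus a uniformly random jitter in [0, fraction * baseMillis).
  // fraction = 0.01 reproduces the current 1% behaviour; 0.1 matches the suggestion above.
  static long withJitter(long baseMillis, double fraction) {
    long jitter = (long) (ThreadLocalRandom.current().nextDouble() * fraction * baseMillis);
    return baseMillis + jitter;
  }

  public static void main(String[] args) {
    // Example: a 1000 ms backoff with 10% jitter sleeps somewhere in [1000, 1100) ms.
    System.out.println(withJitter(1000, 0.1));
  }
}
```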
##########
examples/java/src/main/java/org/apache/beam/examples/RateLimiterSimple.java:
##########
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.examples;
+
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.io.components.ratelimiter.EnvoyRateLimiterContext;
+import org.apache.beam.sdk.io.components.ratelimiter.EnvoyRateLimiterFactory;
+import org.apache.beam.sdk.io.components.ratelimiter.RateLimiter;
+import org.apache.beam.sdk.io.components.ratelimiter.RateLimiterContext;
+import org.apache.beam.sdk.io.components.ratelimiter.RateLimiterFactory;
+import org.apache.beam.sdk.io.components.ratelimiter.RateLimiterOptions;
+import org.apache.beam.sdk.options.Description;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A simple example demonstrating how to use the {@link RateLimiter} in a custom {@link DoFn}.
+ *
+ * <p>This pipeline creates a small set of elements and processes them using a DoFn that calls an
+ * external service (simulated). The processing is rate-limited using an Envoy Rate Limit Service.
+ *
+ * <p>To run this example, you need a running Envoy Rate Limit Service.
+ */
+public class RateLimiterSimple {
+
+  public interface Options extends PipelineOptions {
+    @Description("Address of the Envoy Rate Limit Service(eg:localhost:8081)")
+    String getRateLimiterAddress();
+
+    void setRateLimiterAddress(String value);
+
+    @Description("Domain for the Rate Limit Service(eg:mydomain)")
+    String getRateLimiterDomain();
+
+    void setRateLimiterDomain(String value);
+  }
+
+  static class CallExternalServiceFn extends DoFn<String, String> {
+    private final String rlsAddress;
+    private final String rlsDomain;
+    private transient @Nullable RateLimiter rateLimiter;
+    private static final Logger LOG = LoggerFactory.getLogger(CallExternalServiceFn.class);
+
+    public CallExternalServiceFn(String rlsAddress, String rlsDomain) {
+      this.rlsAddress = rlsAddress;
+      this.rlsDomain = rlsDomain;
+    }
+
+    @Setup
+    public void setup() {
+      // Create the RateLimiterOptions.
+      RateLimiterOptions options = RateLimiterOptions.builder().setAddress(rlsAddress).build();
+
+      // Static RateLimtier with pre-configured domain and descriptors
+      RateLimiterFactory factory = new EnvoyRateLimiterFactory(options);
+      RateLimiterContext context =
+          EnvoyRateLimiterContext.builder()
+              .setDomain(rlsDomain)
+              .addDescriptor("database", "users")
+              .build();
+      this.rateLimiter = factory.getLimiter(context);
+    }
+
+    @Teardown
+    public void teardown() {
+      if (rateLimiter != null) {
+        try {
+          rateLimiter.close();
+        } catch (Exception e) {
+          throw new RuntimeException("Failed to close RateLimiter", e);
+        }

Review Comment:
In a `@Teardown` method, it's generally better to avoid throwing a `RuntimeException` for cleanup failures. A failure here could cause the entire worker to be torn down and restarted, which might be an overly aggressive response to a failed `close()` on a rate limiter. Consider logging the error instead to ensure that teardown is as robust as possible.

```suggestion
        try {
          rateLimiter.close();
        } catch (Exception e) {
          LOG.warn("Failed to close RateLimiter", e);
        }
```
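As a follow-up to the comment above, and purely as a hypothetical sketch (the `QuietCloser`/`closeQuietly` names are illustrative, not an existing Beam utility): teardown can be kept robust with a small "close quietly" helper, assuming the resource is an `AutoCloseable` or can be adapted to one.

```java
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

final class QuietCloser {
  private static final Logger LOG = LoggerFactory.getLogger(QuietCloser.class);

  private QuietCloser() {}

  // Closes a resource without propagating failures, so a failed close() in @Teardown
  // does not tear down the whole worker. Assumes the resource is AutoCloseable (or can
  // be wrapped as one); adapt the parameter type if RateLimiter does not implement it.
  static void closeQuietly(AutoCloseable resource, String name) {
    if (resource == null) {
      return;
    }
    try {
      resource.close();
    } catch (Exception e) {
      LOG.warn("Failed to close {}", name, e);
    }
  }
}
```

With something like this, the `teardown()` above reduces to a single `QuietCloser.closeQuietly(rateLimiter, "RateLimiter")` call.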
##########
sdks/java/io/components/src/main/java/org/apache/beam/sdk/io/components/ratelimiter/RateLimiterClientCache.java:
##########
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.components.ratelimiter;
+
+import io.grpc.ManagedChannel;
+import io.grpc.ManagedChannelBuilder;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.TimeUnit;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A static cache for {@link ManagedChannel}s to Rate Limit Service.
+ *
+ * <p>This class ensures that multiple DoFn instances (threads) in the same Worker sharing the same
+ * RLS address will share a single {@link ManagedChannel}.
+ *
+ * <p>It uses reference counting to close the channel when it is no longer in use by any RateLimiter
+ * instance.
+ */
+public class RateLimiterClientCache {
+  private static final Logger LOG = LoggerFactory.getLogger(RateLimiterClientCache.class);
+  private static final Map<String, RateLimiterClientCache> CACHE = new ConcurrentHashMap<>();
+
+  private final ManagedChannel channel;
+  private final String address;
+  private int refCount = 0;
+
+  private RateLimiterClientCache(String address) {
+    this.address = address;
+    LOG.info("Creating new ManagedChannel for RLS at {}", address);
+    this.channel = ManagedChannelBuilder.forTarget(address).usePlaintext().build();
+  }
+
+  /**
+   * Gets or creates a cached client for the given address. Increments the reference count.
+   * Synchronized on the class to prevent race conditions when multiple instances call getOrCreate()
+   * simultaneously
+   */
+  public static synchronized RateLimiterClientCache getOrCreate(String address) {
+    RateLimiterClientCache client = CACHE.get(address);
+    if (client == null) {
+      client = new RateLimiterClientCache(address);
+      CACHE.put(address, client);
+    }
+    client.refCount++;
+    LOG.debug("Referenced RLS Channel for {}. New RefCount: {}", address, client.refCount);
+    return client;
+  }
+
+  public ManagedChannel getChannel() {
+    return channel;
+  }
+
+  /**
+   * Releases the client. Decrements the reference count. If reference count reaches 0, the channel
+   * is shut down and removed from the cache. Synchronized on the class to prevent race conditions
+   * when multiple threads call release() simultaneously and to prevent race conditions between
+   * getOrCreate() and release() calls.
+   */
+  public void release() {
+    synchronized (RateLimiterClientCache.class) {
+      refCount--;
+      LOG.debug("Released RLS Channel for {}. New RefCount: {}", address, refCount);
+      if (refCount <= 0) {
+        LOG.info("Closing ManagedChannel for RLS at {}", address);
+        CACHE.remove(address);
+        channel.shutdown();
+        try {
+          channel.awaitTermination(10, TimeUnit.SECONDS);
+        } catch (InterruptedException e) {
+          LOG.error("Couldn't gracefully close gRPC channel={}", channel, e);
+        }

Review Comment:
When catching an `InterruptedException`, it's a good practice to restore the interrupted status of the thread by calling `Thread.currentThread().interrupt()`. This allows upstream code to be aware that an interruption occurred and handle it appropriately.

```suggestion
        try {
          channel.awaitTermination(10, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
          LOG.error("Couldn't gracefully close gRPC channel={}", channel, e);
          Thread.currentThread().interrupt();
        }
```
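Also for illustration (not part of this PR): the usual gRPC shutdown sequence combines the interrupt handling suggested above with a bounded wait and a forced shutdown fallback. The helper name `shutdownGracefully` and the `ChannelShutdown` class are hypothetical; `shutdown()`, `awaitTermination(long, TimeUnit)` and `shutdownNow()` are the real `io.grpc.ManagedChannel` methods.

```java
import io.grpc.ManagedChannel;
import java.util.concurrent.TimeUnit;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

final class ChannelShutdown {
  private static final Logger LOG = LoggerFactory.getLogger(ChannelShutdown.class);

  private ChannelShutdown() {}

  // Requests a graceful shutdown, waits up to the given timeout, then forces termination.
  // On interruption it forces termination and restores the thread's interrupt flag so
  // upstream code can still observe the interrupt.
  static void shutdownGracefully(ManagedChannel channel, long timeoutSeconds) {
    channel.shutdown();
    try {
      if (!channel.awaitTermination(timeoutSeconds, TimeUnit.SECONDS)) {
        LOG.warn("gRPC channel did not terminate within {}s; forcing shutdown", timeoutSeconds);
        channel.shutdownNow();
      }
    } catch (InterruptedException e) {
      LOG.error("Interrupted while closing gRPC channel={}", channel, e);
      channel.shutdownNow();
      Thread.currentThread().interrupt();
    }
  }
}
```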
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]