This is an automated email from the ASF dual-hosted git repository.
scwhittle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 68f1543b6bf Simplify budget distribution logic and new worker metadata
consumption (#32775)
68f1543b6bf is described below
commit 68f1543b6bfe4ceaa752c7f16fc2cae7393211fd
Author: martin trieu <[email protected]>
AuthorDate: Mon Oct 21 03:05:43 2024 -0600
Simplify budget distribution logic and new worker metadata consumption
(#32775)
---
.../FanOutStreamingEngineWorkerHarness.java | 379 +++++++++----------
.../streaming/harness/GlobalDataStreamSender.java | 63 ++++
...tionState.java => StreamingEngineBackends.java} | 30 +-
.../streaming/harness/WindmillStreamSender.java | 25 +-
.../worker/windmill/WindmillEndpoints.java | 28 +-
.../worker/windmill/WindmillServiceAddress.java | 22 +-
.../worker/windmill/client/WindmillStream.java | 7 +-
.../client/grpc/GrpcDirectGetWorkStream.java | 286 ++++++++++-----
.../windmill/client/grpc/GrpcGetDataStream.java | 2 +-
.../windmill/client/grpc/GrpcGetWorkStream.java | 10 +-
.../client/grpc/GrpcWindmillStreamFactory.java | 6 +-
.../client/grpc/stubs/WindmillChannelFactory.java | 17 +-
.../work/budget/EvenGetWorkBudgetDistributor.java | 59 +--
.../work/budget/GetWorkBudgetDistributors.java | 6 +-
.../windmill/work/budget/GetWorkBudgetSpender.java | 8 +-
.../dataflow/worker/FakeWindmillServer.java | 10 +-
.../FanOutStreamingEngineWorkerHarnessTest.java | 111 ++----
.../harness/WindmillStreamSenderTest.java | 4 +-
.../client/grpc/GrpcDirectGetWorkStreamTest.java | 405 +++++++++++++++++++++
.../budget/EvenGetWorkBudgetDistributorTest.java | 186 ++--------
20 files changed, 998 insertions(+), 666 deletions(-)
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java
index 3556b7ce291..458cf57ca8e 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java
@@ -20,20 +20,25 @@ package
org.apache.beam.runners.dataflow.worker.streaming.harness;
import static
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap;
import static
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet.toImmutableSet;
-import java.util.Collection;
-import java.util.List;
+import java.io.Closeable;
+import java.util.HashSet;
import java.util.Map.Entry;
+import java.util.NoSuchElementException;
import java.util.Optional;
-import java.util.Queue;
-import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CompletionStage;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
-import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
-import java.util.function.Supplier;
+import java.util.stream.Collectors;
import javax.annotation.CheckReturnValue;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.GuardedBy;
import javax.annotation.concurrent.ThreadSafe;
+import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair;
import
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader;
@@ -54,18 +59,14 @@ import
org.apache.beam.runners.dataflow.worker.windmill.client.throttling.Thrott
import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler;
import
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
import
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetDistributor;
-import
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetRefresher;
import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.sdk.util.MoreFutures;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
-import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers;
-import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.EvictingQueue;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
-import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Queues;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
-import org.joda.time.Instant;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -80,32 +81,39 @@ import org.slf4j.LoggerFactory;
public final class FanOutStreamingEngineWorkerHarness implements
StreamingWorkerHarness {
private static final Logger LOG =
LoggerFactory.getLogger(FanOutStreamingEngineWorkerHarness.class);
- private static final String PUBLISH_NEW_WORKER_METADATA_THREAD =
"PublishNewWorkerMetadataThread";
- private static final String CONSUME_NEW_WORKER_METADATA_THREAD =
"ConsumeNewWorkerMetadataThread";
+ private static final String WORKER_METADATA_CONSUMER_THREAD_NAME =
+ "WindmillWorkerMetadataConsumerThread";
+ private static final String STREAM_MANAGER_THREAD_NAME =
"WindmillStreamManager-%d";
private final JobHeader jobHeader;
private final GrpcWindmillStreamFactory streamFactory;
private final WorkItemScheduler workItemScheduler;
private final ChannelCachingStubFactory channelCachingStubFactory;
private final GrpcDispatcherClient dispatcherClient;
- private final AtomicBoolean isBudgetRefreshPaused;
- private final GetWorkBudgetRefresher getWorkBudgetRefresher;
- private final AtomicReference<Instant> lastBudgetRefresh;
+ private final GetWorkBudgetDistributor getWorkBudgetDistributor;
+ private final GetWorkBudget totalGetWorkBudget;
private final ThrottleTimer getWorkerMetadataThrottleTimer;
- private final ExecutorService newWorkerMetadataPublisher;
- private final ExecutorService newWorkerMetadataConsumer;
- private final long clientId;
- private final Supplier<GetWorkerMetadataStream> getWorkerMetadataStream;
- private final Queue<WindmillEndpoints> newWindmillEndpoints;
private final Function<WindmillStream.CommitWorkStream, WorkCommitter>
workCommitterFactory;
private final ThrottlingGetDataMetricTracker getDataMetricTracker;
+ private final ExecutorService windmillStreamManager;
+ private final ExecutorService workerMetadataConsumer;
+ private final Object metadataLock = new Object();
/** Writes are guarded by synchronization, reads are lock free. */
- private final AtomicReference<StreamingEngineConnectionState> connections;
+ private final AtomicReference<StreamingEngineBackends> backends;
- private volatile boolean started;
+ @GuardedBy("this")
+ private long activeMetadataVersion;
+
+ @GuardedBy("metadataLock")
+ private long pendingMetadataVersion;
+
+ @GuardedBy("this")
+ private boolean started;
+
+ @GuardedBy("this")
+ private @Nullable GetWorkerMetadataStream getWorkerMetadataStream;
- @SuppressWarnings("FutureReturnValueIgnored")
private FanOutStreamingEngineWorkerHarness(
JobHeader jobHeader,
GetWorkBudget totalGetWorkBudget,
@@ -114,7 +122,6 @@ public final class FanOutStreamingEngineWorkerHarness
implements StreamingWorker
ChannelCachingStubFactory channelCachingStubFactory,
GetWorkBudgetDistributor getWorkBudgetDistributor,
GrpcDispatcherClient dispatcherClient,
- long clientId,
Function<WindmillStream.CommitWorkStream, WorkCommitter>
workCommitterFactory,
ThrottlingGetDataMetricTracker getDataMetricTracker) {
this.jobHeader = jobHeader;
@@ -122,42 +129,21 @@ public final class FanOutStreamingEngineWorkerHarness
implements StreamingWorker
this.started = false;
this.streamFactory = streamFactory;
this.workItemScheduler = workItemScheduler;
- this.connections = new
AtomicReference<>(StreamingEngineConnectionState.EMPTY);
+ this.backends = new AtomicReference<>(StreamingEngineBackends.EMPTY);
this.channelCachingStubFactory = channelCachingStubFactory;
this.dispatcherClient = dispatcherClient;
- this.isBudgetRefreshPaused = new AtomicBoolean(false);
this.getWorkerMetadataThrottleTimer = new ThrottleTimer();
- this.newWorkerMetadataPublisher =
- singleThreadedExecutorServiceOf(PUBLISH_NEW_WORKER_METADATA_THREAD);
- this.newWorkerMetadataConsumer =
- singleThreadedExecutorServiceOf(CONSUME_NEW_WORKER_METADATA_THREAD);
- this.clientId = clientId;
- this.lastBudgetRefresh = new AtomicReference<>(Instant.EPOCH);
- this.newWindmillEndpoints =
Queues.synchronizedQueue(EvictingQueue.create(1));
- this.getWorkBudgetRefresher =
- new GetWorkBudgetRefresher(
- isBudgetRefreshPaused::get,
- () -> {
- getWorkBudgetDistributor.distributeBudget(
- connections.get().windmillStreams().values(),
totalGetWorkBudget);
- lastBudgetRefresh.set(Instant.now());
- });
- this.getWorkerMetadataStream =
- Suppliers.memoize(
- () ->
- streamFactory.createGetWorkerMetadataStream(
- dispatcherClient.getWindmillMetadataServiceStubBlocking(),
- getWorkerMetadataThrottleTimer,
- endpoints ->
- // Run this on a separate thread than the grpc stream
thread.
- newWorkerMetadataPublisher.submit(
- () -> newWindmillEndpoints.add(endpoints))));
+ this.windmillStreamManager =
+ Executors.newCachedThreadPool(
+ new
ThreadFactoryBuilder().setNameFormat(STREAM_MANAGER_THREAD_NAME).build());
+ this.workerMetadataConsumer =
+ Executors.newSingleThreadScheduledExecutor(
+ new
ThreadFactoryBuilder().setNameFormat(WORKER_METADATA_CONSUMER_THREAD_NAME).build());
+ this.getWorkBudgetDistributor = getWorkBudgetDistributor;
+ this.totalGetWorkBudget = totalGetWorkBudget;
+ this.activeMetadataVersion = Long.MIN_VALUE;
this.workCommitterFactory = workCommitterFactory;
- }
-
- private static ExecutorService singleThreadedExecutorServiceOf(String
threadName) {
- return Executors.newSingleThreadScheduledExecutor(
- new ThreadFactoryBuilder().setNameFormat(threadName).build());
+ this.getWorkerMetadataStream = null;
}
/**
@@ -183,7 +169,6 @@ public final class FanOutStreamingEngineWorkerHarness
implements StreamingWorker
channelCachingStubFactory,
getWorkBudgetDistributor,
dispatcherClient,
- /* clientId= */ new Random().nextLong(),
workCommitterFactory,
getDataMetricTracker);
}
@@ -197,7 +182,6 @@ public final class FanOutStreamingEngineWorkerHarness
implements StreamingWorker
ChannelCachingStubFactory stubFactory,
GetWorkBudgetDistributor getWorkBudgetDistributor,
GrpcDispatcherClient dispatcherClient,
- long clientId,
Function<WindmillStream.CommitWorkStream, WorkCommitter>
workCommitterFactory,
ThrottlingGetDataMetricTracker getDataMetricTracker) {
FanOutStreamingEngineWorkerHarness fanOutStreamingEngineWorkProvider =
@@ -209,201 +193,218 @@ public final class FanOutStreamingEngineWorkerHarness
implements StreamingWorker
stubFactory,
getWorkBudgetDistributor,
dispatcherClient,
- clientId,
workCommitterFactory,
getDataMetricTracker);
fanOutStreamingEngineWorkProvider.start();
return fanOutStreamingEngineWorkProvider;
}
- @SuppressWarnings("ReturnValueIgnored")
@Override
public synchronized void start() {
- Preconditions.checkState(!started, "StreamingEngineClient cannot start
twice.");
- // Starts the stream, this value is memoized.
- getWorkerMetadataStream.get();
- startWorkerMetadataConsumer();
- getWorkBudgetRefresher.start();
+ Preconditions.checkState(!started, "FanOutStreamingEngineWorkerHarness
cannot start twice.");
+ getWorkerMetadataStream =
+ streamFactory.createGetWorkerMetadataStream(
+ dispatcherClient.getWindmillMetadataServiceStubBlocking(),
+ getWorkerMetadataThrottleTimer,
+ this::consumeWorkerMetadata);
started = true;
}
public ImmutableSet<HostAndPort> currentWindmillEndpoints() {
- return connections.get().windmillConnections().keySet().stream()
+ return backends.get().windmillStreams().keySet().stream()
.map(Endpoint::directEndpoint)
.filter(Optional::isPresent)
.map(Optional::get)
- .filter(
- windmillServiceAddress ->
- windmillServiceAddress.getKind() !=
WindmillServiceAddress.Kind.IPV6)
- .map(
- windmillServiceAddress ->
- windmillServiceAddress.getKind() ==
WindmillServiceAddress.Kind.GCP_SERVICE_ADDRESS
- ? windmillServiceAddress.gcpServiceAddress()
- :
windmillServiceAddress.authenticatedGcpServiceAddress().gcpServiceAddress())
+ .map(WindmillServiceAddress::getServiceAddress)
.collect(toImmutableSet());
}
/**
- * Fetches {@link GetDataStream} mapped to globalDataKey if one exists, or
defaults to {@link
- * GetDataStream} pointing to dispatcher.
+ * Fetches {@link GetDataStream} mapped to globalDataKey, or throws {@link
+ * NoSuchElementException} if one is not found.
*/
private GetDataStream getGlobalDataStream(String globalDataKey) {
- return
Optional.ofNullable(connections.get().globalDataStreams().get(globalDataKey))
- .map(Supplier::get)
- .orElseGet(
- () ->
- streamFactory.createGetDataStream(
- dispatcherClient.getWindmillServiceStub(), new
ThrottleTimer()));
- }
-
- @SuppressWarnings("FutureReturnValueIgnored")
- private void startWorkerMetadataConsumer() {
- newWorkerMetadataConsumer.submit(
- () -> {
- while (true) {
- Optional.ofNullable(newWindmillEndpoints.poll())
- .ifPresent(this::consumeWindmillWorkerEndpoints);
- }
- });
+ return
Optional.ofNullable(backends.get().globalDataStreams().get(globalDataKey))
+ .map(GlobalDataStreamSender::get)
+ .orElseThrow(
+ () -> new NoSuchElementException("No endpoint for global data tag:
" + globalDataKey));
}
@VisibleForTesting
@Override
public synchronized void shutdown() {
- Preconditions.checkState(started, "StreamingEngineClient never started.");
- getWorkerMetadataStream.get().halfClose();
- getWorkBudgetRefresher.stop();
- newWorkerMetadataPublisher.shutdownNow();
- newWorkerMetadataConsumer.shutdownNow();
+ Preconditions.checkState(started, "FanOutStreamingEngineWorkerHarness
never started.");
+ Preconditions.checkNotNull(getWorkerMetadataStream).shutdown();
+ workerMetadataConsumer.shutdownNow();
+ closeStreamsNotIn(WindmillEndpoints.none());
channelCachingStubFactory.shutdown();
+
+ try {
+ Preconditions.checkNotNull(getWorkerMetadataStream).awaitTermination(10,
TimeUnit.SECONDS);
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ LOG.warn("Interrupted waiting for GetWorkerMetadataStream to shutdown.",
e);
+ }
+
+ windmillStreamManager.shutdown();
+ boolean isStreamManagerShutdown = false;
+ try {
+ isStreamManagerShutdown = windmillStreamManager.awaitTermination(30,
TimeUnit.SECONDS);
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ LOG.warn("Interrupted waiting for windmillStreamManager to shutdown.",
e);
+ }
+ if (!isStreamManagerShutdown) {
+ windmillStreamManager.shutdownNow();
+ }
+ }
+
+ private void consumeWorkerMetadata(WindmillEndpoints windmillEndpoints) {
+ synchronized (metadataLock) {
+ // Only process versions greater than what we currently have to prevent
double processing of
+ // metadata. workerMetadataConsumer is single-threaded so we maintain
ordering.
+ if (windmillEndpoints.version() > pendingMetadataVersion) {
+ pendingMetadataVersion = windmillEndpoints.version();
+ workerMetadataConsumer.execute(() ->
consumeWindmillWorkerEndpoints(windmillEndpoints));
+ }
+ }
}
- /**
- * {@link java.util.function.Consumer<WindmillEndpoints>} used to update
{@link #connections} on
- * new backend worker metadata.
- */
private synchronized void consumeWindmillWorkerEndpoints(WindmillEndpoints
newWindmillEndpoints) {
- isBudgetRefreshPaused.set(true);
- LOG.info("Consuming new windmill endpoints: {}", newWindmillEndpoints);
- ImmutableMap<Endpoint, WindmillConnection> newWindmillConnections =
- createNewWindmillConnections(newWindmillEndpoints.windmillEndpoints());
-
- StreamingEngineConnectionState newConnectionsState =
- StreamingEngineConnectionState.builder()
- .setWindmillConnections(newWindmillConnections)
- .setWindmillStreams(
-
closeStaleStreamsAndCreateNewStreams(newWindmillConnections.values()))
+ // Since this is run on a single-threaded executor, multiple versions of
the metadata may be
+ // queued up while a previous version of the windmillEndpoints was being
consumed. Only consume
+ // the endpoints if they are the most current version.
+ synchronized (metadataLock) {
+ if (newWindmillEndpoints.version() < pendingMetadataVersion) {
+ return;
+ }
+ }
+
+ LOG.debug(
+ "Consuming new endpoints: {}. previous metadata version: {}, current
metadata version: {}",
+ newWindmillEndpoints,
+ activeMetadataVersion,
+ newWindmillEndpoints.version());
+ closeStreamsNotIn(newWindmillEndpoints);
+ ImmutableMap<Endpoint, WindmillStreamSender> newStreams =
+
createAndStartNewStreams(newWindmillEndpoints.windmillEndpoints()).join();
+ StreamingEngineBackends newBackends =
+ StreamingEngineBackends.builder()
+ .setWindmillStreams(newStreams)
.setGlobalDataStreams(
createNewGlobalDataStreams(newWindmillEndpoints.globalDataEndpoints()))
.build();
+ backends.set(newBackends);
+ getWorkBudgetDistributor.distributeBudget(newStreams.values(),
totalGetWorkBudget);
+ activeMetadataVersion = newWindmillEndpoints.version();
+ }
+
+ /** Asynchronously closes the streams to endpoints that are no longer valid. */
+ private void closeStreamsNotIn(WindmillEndpoints newWindmillEndpoints) {
+ StreamingEngineBackends currentBackends = backends.get();
+ currentBackends.windmillStreams().entrySet().stream()
+ .filter(
+ connectionAndStream ->
+
!newWindmillEndpoints.windmillEndpoints().contains(connectionAndStream.getKey()))
+ .forEach(
+ entry ->
+ windmillStreamManager.execute(
+ () -> closeStreamSender(entry.getKey(),
entry.getValue())));
- LOG.info(
- "Setting new connections: {}. Previous connections: {}.",
- newConnectionsState,
- connections.get());
- connections.set(newConnectionsState);
- isBudgetRefreshPaused.set(false);
- getWorkBudgetRefresher.requestBudgetRefresh();
+ Set<Endpoint> newGlobalDataEndpoints =
+ new HashSet<>(newWindmillEndpoints.globalDataEndpoints().values());
+ currentBackends.globalDataStreams().values().stream()
+ .filter(sender -> !newGlobalDataEndpoints.contains(sender.endpoint()))
+ .forEach(
+ sender ->
+ windmillStreamManager.execute(() ->
closeStreamSender(sender.endpoint(), sender)));
+ }
+
+ private void closeStreamSender(Endpoint endpoint, Closeable sender) {
+ LOG.debug("Closing streams to endpoint={}, sender={}", endpoint, sender);
+ try {
+ sender.close();
+ endpoint.directEndpoint().ifPresent(channelCachingStubFactory::remove);
+ LOG.debug("Successfully closed streams to {}", endpoint);
+ } catch (Exception e) {
+ LOG.error("Error closing streams to endpoint={}, sender={}", endpoint,
sender);
+ }
+ }
+
+ private synchronized CompletableFuture<ImmutableMap<Endpoint,
WindmillStreamSender>>
+ createAndStartNewStreams(ImmutableSet<Endpoint> newWindmillEndpoints) {
+ ImmutableMap<Endpoint, WindmillStreamSender> currentStreams =
backends.get().windmillStreams();
+ return MoreFutures.allAsList(
+ newWindmillEndpoints.stream()
+ .map(endpoint ->
getOrCreateWindmillStreamSenderFuture(endpoint, currentStreams))
+ .collect(Collectors.toList()))
+ .thenApply(
+ backends ->
backends.stream().collect(toImmutableMap(Pair::getLeft, Pair::getRight)))
+ .toCompletableFuture();
+ }
+
+ private CompletionStage<Pair<Endpoint, WindmillStreamSender>>
+ getOrCreateWindmillStreamSenderFuture(
+ Endpoint endpoint, ImmutableMap<Endpoint, WindmillStreamSender>
currentStreams) {
+ return MoreFutures.supplyAsync(
+ () ->
+ Pair.of(
+ endpoint,
+ Optional.ofNullable(currentStreams.get(endpoint))
+ .orElseGet(() ->
createAndStartWindmillStreamSender(endpoint))),
+ windmillStreamManager);
}
/** Add up all the throttle times of all streams including
GetWorkerMetadataStream. */
- public long getAndResetThrottleTimes() {
- return connections.get().windmillStreams().values().stream()
+ public long getAndResetThrottleTime() {
+ return backends.get().windmillStreams().values().stream()
.map(WindmillStreamSender::getAndResetThrottleTime)
.reduce(0L, Long::sum)
+ getWorkerMetadataThrottleTimer.getAndResetThrottleTime();
}
public long currentActiveCommitBytes() {
- return connections.get().windmillStreams().values().stream()
+ return backends.get().windmillStreams().values().stream()
.map(WindmillStreamSender::getCurrentActiveCommitBytes)
.reduce(0L, Long::sum);
}
@VisibleForTesting
- StreamingEngineConnectionState getCurrentConnections() {
- return connections.get();
- }
-
- private synchronized ImmutableMap<Endpoint, WindmillConnection>
createNewWindmillConnections(
- List<Endpoint> newWindmillEndpoints) {
- ImmutableMap<Endpoint, WindmillConnection> currentConnections =
- connections.get().windmillConnections();
- return newWindmillEndpoints.stream()
- .collect(
- toImmutableMap(
- Function.identity(),
- endpoint ->
- // Reuse existing stubs if they exist. Optional.orElseGet
only calls the
- // supplier if the value is not present, preventing
constructing expensive
- // objects.
- Optional.ofNullable(currentConnections.get(endpoint))
- .orElseGet(
- () -> WindmillConnection.from(endpoint,
this::createWindmillStub))));
+ StreamingEngineBackends currentBackends() {
+ return backends.get();
}
- private synchronized ImmutableMap<WindmillConnection, WindmillStreamSender>
- closeStaleStreamsAndCreateNewStreams(Collection<WindmillConnection>
newWindmillConnections) {
- ImmutableMap<WindmillConnection, WindmillStreamSender> currentStreams =
- connections.get().windmillStreams();
-
- // Close the streams that are no longer valid.
- currentStreams.entrySet().stream()
- .filter(
- connectionAndStream ->
!newWindmillConnections.contains(connectionAndStream.getKey()))
- .forEach(
- entry -> {
- entry.getValue().closeAllStreams();
-
entry.getKey().directEndpoint().ifPresent(channelCachingStubFactory::remove);
- });
-
- return newWindmillConnections.stream()
- .collect(
- toImmutableMap(
- Function.identity(),
- newConnection ->
- Optional.ofNullable(currentStreams.get(newConnection))
- .orElseGet(() ->
createAndStartWindmillStreamSenderFor(newConnection))));
- }
-
- private ImmutableMap<String, Supplier<GetDataStream>>
createNewGlobalDataStreams(
+ private ImmutableMap<String, GlobalDataStreamSender>
createNewGlobalDataStreams(
ImmutableMap<String, Endpoint> newGlobalDataEndpoints) {
- ImmutableMap<String, Supplier<GetDataStream>> currentGlobalDataStreams =
- connections.get().globalDataStreams();
+ ImmutableMap<String, GlobalDataStreamSender> currentGlobalDataStreams =
+ backends.get().globalDataStreams();
return newGlobalDataEndpoints.entrySet().stream()
.collect(
toImmutableMap(
Entry::getKey,
keyedEndpoint ->
- existingOrNewGetDataStreamFor(keyedEndpoint,
currentGlobalDataStreams)));
+ getOrCreateGlobalDataSteam(keyedEndpoint,
currentGlobalDataStreams)));
}
- private Supplier<GetDataStream> existingOrNewGetDataStreamFor(
+ private GlobalDataStreamSender getOrCreateGlobalDataSteam(
Entry<String, Endpoint> keyedEndpoint,
- ImmutableMap<String, Supplier<GetDataStream>> currentGlobalDataStreams) {
- return Preconditions.checkNotNull(
- currentGlobalDataStreams.getOrDefault(
- keyedEndpoint.getKey(),
+ ImmutableMap<String, GlobalDataStreamSender> currentGlobalDataStreams) {
+ return
Optional.ofNullable(currentGlobalDataStreams.get(keyedEndpoint.getKey()))
+ .orElseGet(
() ->
- streamFactory.createGetDataStream(
- newOrExistingStubFor(keyedEndpoint.getValue()), new
ThrottleTimer())));
- }
-
- private CloudWindmillServiceV1Alpha1Stub newOrExistingStubFor(Endpoint
endpoint) {
- return
Optional.ofNullable(connections.get().windmillConnections().get(endpoint))
- .map(WindmillConnection::stub)
- .orElseGet(() -> createWindmillStub(endpoint));
+ new GlobalDataStreamSender(
+ () ->
+ streamFactory.createGetDataStream(
+ createWindmillStub(keyedEndpoint.getValue()), new
ThrottleTimer()),
+ keyedEndpoint.getValue()));
}
- private WindmillStreamSender createAndStartWindmillStreamSenderFor(
- WindmillConnection connection) {
- // Initially create each stream with no budget. The budget will be
eventually assigned by the
- // GetWorkBudgetDistributor.
+ private WindmillStreamSender createAndStartWindmillStreamSender(Endpoint
endpoint) {
WindmillStreamSender windmillStreamSender =
WindmillStreamSender.create(
- connection,
+ WindmillConnection.from(endpoint, this::createWindmillStub),
GetWorkRequest.newBuilder()
- .setClientId(clientId)
+ .setClientId(jobHeader.getClientId())
.setJobId(jobHeader.getJobId())
.setProjectId(jobHeader.getProjectId())
.setWorkerId(jobHeader.getWorkerId())
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java
new file mode 100644
index 00000000000..ce5f3a7b6bf
--- /dev/null
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.streaming.harness;
+
+import java.io.Closeable;
+import java.util.function.Supplier;
+import javax.annotation.concurrent.ThreadSafe;
+import
org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
+import org.apache.beam.sdk.annotations.Internal;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers;
+
+@Internal
+@ThreadSafe
+// TODO (m-trieu): replace Supplier<Stream> with Stream after
github.com/apache/beam/pull/32774/ is
+// merged
+final class GlobalDataStreamSender implements Closeable,
Supplier<GetDataStream> {
+ private final Endpoint endpoint;
+ private final Supplier<GetDataStream> delegate;
+ private volatile boolean started;
+
+ GlobalDataStreamSender(Supplier<GetDataStream> delegate, Endpoint endpoint) {
+ // Ensures that the Supplier is thread-safe
+ this.delegate = Suppliers.memoize(delegate::get);
+ this.started = false;
+ this.endpoint = endpoint;
+ }
+
+ @Override
+ public GetDataStream get() {
+ if (!started) {
+ started = true;
+ }
+
+ return delegate.get();
+ }
+
+ @Override
+ public void close() {
+ if (started) {
+ delegate.get().shutdown();
+ }
+ }
+
+ Endpoint endpoint() {
+ return endpoint;
+ }
+}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineConnectionState.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineBackends.java
similarity index 55%
rename from
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineConnectionState.java
rename to
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineBackends.java
index 3c85ee6abe1..14290b48683 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineConnectionState.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineBackends.java
@@ -18,47 +18,37 @@
package org.apache.beam.runners.dataflow.worker.streaming.harness;
import com.google.auto.value.AutoValue;
-import java.util.function.Supplier;
-import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
import
org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint;
-import
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
/**
- * Represents the current state of connections to Streaming Engine.
Connections are updated when
- * backend workers assigned to the key ranges being processed by this user
worker change during
+ * Represents the current state of connections to the Streaming Engine
backend. Backends are updated
+ * when backend workers assigned to the key ranges being processed by this
user worker change during
* pipeline execution. For example, changes can happen via autoscaling,
load-balancing, or other
* backend updates.
*/
@AutoValue
-abstract class StreamingEngineConnectionState {
- static final StreamingEngineConnectionState EMPTY = builder().build();
+abstract class StreamingEngineBackends {
+ static final StreamingEngineBackends EMPTY = builder().build();
static Builder builder() {
- return new AutoValue_StreamingEngineConnectionState.Builder()
- .setWindmillConnections(ImmutableMap.of())
+ return new AutoValue_StreamingEngineBackends.Builder()
.setWindmillStreams(ImmutableMap.of())
.setGlobalDataStreams(ImmutableMap.of());
}
- abstract ImmutableMap<Endpoint, WindmillConnection> windmillConnections();
-
- abstract ImmutableMap<WindmillConnection, WindmillStreamSender>
windmillStreams();
+ abstract ImmutableMap<Endpoint, WindmillStreamSender> windmillStreams();
/** Mapping of GlobalDataIds and the direct GetDataStreams used fetch them.
*/
- abstract ImmutableMap<String, Supplier<GetDataStream>> globalDataStreams();
+ abstract ImmutableMap<String, GlobalDataStreamSender> globalDataStreams();
@AutoValue.Builder
abstract static class Builder {
- public abstract Builder setWindmillConnections(
- ImmutableMap<Endpoint, WindmillConnection> value);
-
- public abstract Builder setWindmillStreams(
- ImmutableMap<WindmillConnection, WindmillStreamSender> value);
+ public abstract Builder setWindmillStreams(ImmutableMap<Endpoint,
WindmillStreamSender> value);
public abstract Builder setGlobalDataStreams(
- ImmutableMap<String, Supplier<GetDataStream>> value);
+ ImmutableMap<String, GlobalDataStreamSender> value);
- public abstract StreamingEngineConnectionState build();
+ public abstract StreamingEngineBackends build();
}
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java
index 45aa403ee71..744c3d74445 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java
@@ -17,6 +17,7 @@
*/
package org.apache.beam.runners.dataflow.worker.streaming.harness;
+import java.io.Closeable;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
@@ -49,7 +50,7 @@ import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers
* {@link GetWorkBudget} is set.
*
* <p>Once started, the underlying streams are "alive" until they are manually
closed via {@link
- * #closeAllStreams()}.
+ * #close()}.
*
* <p>If closed, it means that the backend endpoint is no longer in the worker
set. Once closed,
* these instances are not reused.
@@ -59,7 +60,7 @@ import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers
*/
@Internal
@ThreadSafe
-final class WindmillStreamSender implements GetWorkBudgetSpender {
+final class WindmillStreamSender implements GetWorkBudgetSpender, Closeable {
private final AtomicBoolean started;
private final AtomicReference<GetWorkBudget> getWorkBudget;
private final Supplier<GetWorkStream> getWorkStream;
@@ -103,9 +104,9 @@ final class WindmillStreamSender implements
GetWorkBudgetSpender {
connection,
withRequestBudget(getWorkRequest, getWorkBudget.get()),
streamingEngineThrottleTimers.getWorkThrottleTimer(),
- () ->
FixedStreamHeartbeatSender.create(getDataStream.get()),
- () -> getDataClientFactory.apply(getDataStream.get()),
- workCommitter,
+ FixedStreamHeartbeatSender.create(getDataStream.get()),
+ getDataClientFactory.apply(getDataStream.get()),
+ workCommitter.get(),
workItemScheduler));
}
@@ -141,7 +142,8 @@ final class WindmillStreamSender implements
GetWorkBudgetSpender {
started.set(true);
}
- void closeAllStreams() {
+ @Override
+ public void close() {
// Supplier<Stream>.get() starts the stream which is an expensive
operation as it initiates the
// streaming RPCs by possibly making calls over the network. Do not close
the streams unless
// they have already been started.
@@ -154,18 +156,13 @@ final class WindmillStreamSender implements
GetWorkBudgetSpender {
}
@Override
- public void adjustBudget(long itemsDelta, long bytesDelta) {
- getWorkBudget.set(getWorkBudget.get().apply(itemsDelta, bytesDelta));
+ public void setBudget(long items, long bytes) {
+ getWorkBudget.set(GetWorkBudget.builder().setItems(items).setBytes(bytes).build()); // absolute: apply() would add deltas
if (started.get()) {
- getWorkStream.get().adjustBudget(itemsDelta, bytesDelta);
+ getWorkStream.get().setBudget(items, bytes);
}
}
- @Override
- public GetWorkBudget remainingBudget() {
- return started.get() ? getWorkStream.get().remainingBudget() :
getWorkBudget.get();
- }
-
long getAndResetThrottleTime() {
return streamingEngineThrottleTimers.getAndResetThrottleTime();
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java
index d7ed83def43..eb269eef848 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java
@@ -17,8 +17,8 @@
*/
package org.apache.beam.runners.dataflow.worker.windmill;
-import static
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList;
import static
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap;
+import static
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet.toImmutableSet;
import com.google.auto.value.AutoValue;
import java.net.Inet6Address;
@@ -27,8 +27,8 @@ import java.net.UnknownHostException;
import java.util.Map;
import java.util.Optional;
import
org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress.AuthenticatedGcpServiceAddress;
-import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -41,6 +41,14 @@ import org.slf4j.LoggerFactory;
public abstract class WindmillEndpoints {
private static final Logger LOG =
LoggerFactory.getLogger(WindmillEndpoints.class);
+ public static WindmillEndpoints none() {
+ return WindmillEndpoints.builder()
+ .setVersion(Long.MAX_VALUE)
+ .setWindmillEndpoints(ImmutableSet.of())
+ .setGlobalDataEndpoints(ImmutableMap.of())
+ .build();
+ }
+
public static WindmillEndpoints from(
Windmill.WorkerMetadataResponse workerMetadataResponseProto) {
ImmutableMap<String, WindmillEndpoints.Endpoint> globalDataServers =
@@ -53,14 +61,15 @@ public abstract class WindmillEndpoints {
endpoint.getValue(),
workerMetadataResponseProto.getExternalEndpoint())));
- ImmutableList<WindmillEndpoints.Endpoint> windmillServers =
+ ImmutableSet<WindmillEndpoints.Endpoint> windmillServers =
workerMetadataResponseProto.getWorkEndpointsList().stream()
.map(
endpointProto ->
Endpoint.from(endpointProto,
workerMetadataResponseProto.getExternalEndpoint()))
- .collect(toImmutableList());
+ .collect(toImmutableSet());
return WindmillEndpoints.builder()
+ .setVersion(workerMetadataResponseProto.getMetadataVersion())
.setGlobalDataEndpoints(globalDataServers)
.setWindmillEndpoints(windmillServers)
.build();
@@ -123,6 +132,9 @@ public abstract class WindmillEndpoints {
directEndpointAddress.getHostAddress(), (int)
endpointProto.getPort()));
}
+ /** Version of the endpoints which increases with every modification. */
+ public abstract long version();
+
/**
* Used by GetData GlobalDataRequest(s) to support Beam side inputs. Returns
a map where the key
* is a global data tag and the value is the endpoint where the data
associated with the global
@@ -138,7 +150,7 @@ public abstract class WindmillEndpoints {
* Windmill servers. Returns a list of endpoints used to communicate with
the corresponding
* Windmill servers.
*/
- public abstract ImmutableList<Endpoint> windmillEndpoints();
+ public abstract ImmutableSet<Endpoint> windmillEndpoints();
/**
* Representation of an endpoint in {@link
Windmill.WorkerMetadataResponse.Endpoint} proto with
@@ -204,13 +216,15 @@ public abstract class WindmillEndpoints {
@AutoValue.Builder
public abstract static class Builder {
+ public abstract Builder setVersion(long version);
+
public abstract Builder setGlobalDataEndpoints(
ImmutableMap<String, WindmillEndpoints.Endpoint> globalDataServers);
public abstract Builder setWindmillEndpoints(
- ImmutableList<WindmillEndpoints.Endpoint> windmillServers);
+ ImmutableSet<WindmillEndpoints.Endpoint> windmillServers);
- abstract ImmutableList.Builder<WindmillEndpoints.Endpoint>
windmillEndpointsBuilder();
+ abstract ImmutableSet.Builder<WindmillEndpoints.Endpoint>
windmillEndpointsBuilder();
public final Builder addWindmillEndpoint(WindmillEndpoints.Endpoint
endpoint) {
windmillEndpointsBuilder().add(endpoint);
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java
index 90f93b07267..0b895652efe 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java
@@ -19,38 +19,36 @@ package org.apache.beam.runners.dataflow.worker.windmill;
import com.google.auto.value.AutoOneOf;
import com.google.auto.value.AutoValue;
-import java.net.Inet6Address;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort;
/** Used to create channels to communicate with Streaming Engine via gRpc. */
@AutoOneOf(WindmillServiceAddress.Kind.class)
public abstract class WindmillServiceAddress {
- public static WindmillServiceAddress create(Inet6Address ipv6Address) {
- return AutoOneOf_WindmillServiceAddress.ipv6(ipv6Address);
- }
public static WindmillServiceAddress create(HostAndPort gcpServiceAddress) {
return
AutoOneOf_WindmillServiceAddress.gcpServiceAddress(gcpServiceAddress);
}
- public abstract Kind getKind();
+ public static WindmillServiceAddress create(
+ AuthenticatedGcpServiceAddress authenticatedGcpServiceAddress) {
+ return AutoOneOf_WindmillServiceAddress.authenticatedGcpServiceAddress(
+ authenticatedGcpServiceAddress);
+ }
- public abstract Inet6Address ipv6();
+ public abstract Kind getKind();
public abstract HostAndPort gcpServiceAddress();
public abstract AuthenticatedGcpServiceAddress
authenticatedGcpServiceAddress();
- public static WindmillServiceAddress create(
- AuthenticatedGcpServiceAddress authenticatedGcpServiceAddress) {
- return AutoOneOf_WindmillServiceAddress.authenticatedGcpServiceAddress(
- authenticatedGcpServiceAddress);
+ public final HostAndPort getServiceAddress() {
+ return getKind() == WindmillServiceAddress.Kind.GCP_SERVICE_ADDRESS
+ ? gcpServiceAddress()
+ : authenticatedGcpServiceAddress().gcpServiceAddress();
}
public enum Kind {
- IPV6,
GCP_SERVICE_ADDRESS,
- // TODO(m-trieu): Use for direct connections when ALTS is enabled.
AUTHENTICATED_GCP_SERVICE_ADDRESS
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java
index 31bd4e146a7..f26c56b14ec 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java
@@ -56,10 +56,11 @@ public interface WindmillStream {
@ThreadSafe
interface GetWorkStream extends WindmillStream {
/** Adjusts the {@link GetWorkBudget} for the stream. */
- void adjustBudget(long itemsDelta, long bytesDelta);
+ void setBudget(GetWorkBudget newBudget);
- /** Returns the remaining in-flight {@link GetWorkBudget}. */
- GetWorkBudget remainingBudget();
+ default void setBudget(long newItems, long newBytes) {
+
setBudget(GetWorkBudget.builder().setItems(newItems).setBytes(newBytes).build());
+ }
}
/** Interface for streaming GetDataRequests to Windmill. */
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java
index 19de998b1da..b27ebc8e9ee 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java
@@ -21,9 +21,11 @@ import java.io.PrintWriter;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
-import java.util.function.Supplier;
+import javax.annotation.concurrent.GuardedBy;
+import net.jcip.annotations.ThreadSafe;
import org.apache.beam.runners.dataflow.worker.streaming.Watermarks;
import org.apache.beam.runners.dataflow.worker.streaming.Work;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
@@ -44,8 +46,8 @@ import
org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSe
import org.apache.beam.sdk.annotations.Internal;
import org.apache.beam.sdk.util.BackOff;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
-import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
-import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
/**
* Implementation of {@link GetWorkStream} that passes along a specific {@link
@@ -55,9 +57,10 @@ import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers
* these direct streams are used to facilitate these RPC calls to specific
backend workers.
*/
@Internal
-public final class GrpcDirectGetWorkStream
+final class GrpcDirectGetWorkStream
extends AbstractWindmillStream<StreamingGetWorkRequest,
StreamingGetWorkResponseChunk>
implements GetWorkStream {
+ private static final Logger LOG =
LoggerFactory.getLogger(GrpcDirectGetWorkStream.class);
private static final StreamingGetWorkRequest HEALTH_CHECK_REQUEST =
StreamingGetWorkRequest.newBuilder()
.setRequestExtension(
@@ -67,15 +70,14 @@ public final class GrpcDirectGetWorkStream
.build())
.build();
- private final AtomicReference<GetWorkBudget> inFlightBudget;
- private final AtomicReference<GetWorkBudget> nextBudgetAdjustment;
- private final AtomicReference<GetWorkBudget> pendingResponseBudget;
- private final GetWorkRequest request;
+ private final GetWorkBudgetTracker budgetTracker;
+ private final GetWorkRequest requestHeader;
private final WorkItemScheduler workItemScheduler;
private final ThrottleTimer getWorkThrottleTimer;
- private final Supplier<HeartbeatSender> heartbeatSender;
- private final Supplier<WorkCommitter> workCommitter;
- private final Supplier<GetDataClient> getDataClient;
+ private final HeartbeatSender heartbeatSender;
+ private final WorkCommitter workCommitter;
+ private final GetDataClient getDataClient;
+ private final AtomicReference<StreamingGetWorkRequest> lastRequest;
/**
* Map of stream IDs to their buffers. Used to aggregate streaming gRPC
response chunks as they
@@ -92,15 +94,15 @@ public final class GrpcDirectGetWorkStream
StreamObserver<StreamingGetWorkResponseChunk>,
StreamObserver<StreamingGetWorkRequest>>
startGetWorkRpcFn,
- GetWorkRequest request,
+ GetWorkRequest requestHeader,
BackOff backoff,
StreamObserverFactory streamObserverFactory,
Set<AbstractWindmillStream<?, ?>> streamRegistry,
int logEveryNStreamFailures,
ThrottleTimer getWorkThrottleTimer,
- Supplier<HeartbeatSender> heartbeatSender,
- Supplier<GetDataClient> getDataClient,
- Supplier<WorkCommitter> workCommitter,
+ HeartbeatSender heartbeatSender,
+ GetDataClient getDataClient,
+ WorkCommitter workCommitter,
WorkItemScheduler workItemScheduler) {
super(
"GetWorkStream",
@@ -110,19 +112,23 @@ public final class GrpcDirectGetWorkStream
streamRegistry,
logEveryNStreamFailures,
backendWorkerToken);
- this.request = request;
+ this.requestHeader = requestHeader;
this.getWorkThrottleTimer = getWorkThrottleTimer;
this.workItemScheduler = workItemScheduler;
this.workItemAssemblers = new ConcurrentHashMap<>();
- this.heartbeatSender = Suppliers.memoize(heartbeatSender::get);
- this.workCommitter = Suppliers.memoize(workCommitter::get);
- this.getDataClient = Suppliers.memoize(getDataClient::get);
- this.inFlightBudget = new AtomicReference<>(GetWorkBudget.noBudget());
- this.nextBudgetAdjustment = new
AtomicReference<>(GetWorkBudget.noBudget());
- this.pendingResponseBudget = new
AtomicReference<>(GetWorkBudget.noBudget());
+ this.heartbeatSender = heartbeatSender;
+ this.workCommitter = workCommitter;
+ this.getDataClient = getDataClient;
+ this.lastRequest = new AtomicReference<>();
+ this.budgetTracker =
+ new GetWorkBudgetTracker(
+ GetWorkBudget.builder()
+ .setItems(requestHeader.getMaxItems())
+ .setBytes(requestHeader.getMaxBytes())
+ .build());
}
- public static GrpcDirectGetWorkStream create(
+ static GrpcDirectGetWorkStream create(
String backendWorkerToken,
Function<
StreamObserver<StreamingGetWorkResponseChunk>,
@@ -134,9 +140,9 @@ public final class GrpcDirectGetWorkStream
Set<AbstractWindmillStream<?, ?>> streamRegistry,
int logEveryNStreamFailures,
ThrottleTimer getWorkThrottleTimer,
- Supplier<HeartbeatSender> heartbeatSender,
- Supplier<GetDataClient> getDataClient,
- Supplier<WorkCommitter> workCommitter,
+ HeartbeatSender heartbeatSender,
+ GetDataClient getDataClient,
+ WorkCommitter workCommitter,
WorkItemScheduler workItemScheduler) {
GrpcDirectGetWorkStream getWorkStream =
new GrpcDirectGetWorkStream(
@@ -165,46 +171,52 @@ public final class GrpcDirectGetWorkStream
.build();
}
- private void sendRequestExtension(GetWorkBudget adjustment) {
- inFlightBudget.getAndUpdate(budget -> budget.apply(adjustment));
- StreamingGetWorkRequest extension =
- StreamingGetWorkRequest.newBuilder()
- .setRequestExtension(
- Windmill.StreamingGetWorkRequestExtension.newBuilder()
- .setMaxItems(adjustment.items())
- .setMaxBytes(adjustment.bytes()))
- .build();
-
- executor()
- .execute(
- () -> {
- try {
- send(extension);
- } catch (IllegalStateException e) {
- // Stream was closed.
- }
- });
+ /**
+ * @implNote Do not lock/synchronize here due to this running on grpc serial
executor for message
+ * which can deadlock since we send on the stream beneath the
synchronization. {@link
+ * AbstractWindmillStream#send(Object)} is synchronized so the sends are
already guarded.
+ */
+ private void maybeSendRequestExtension(GetWorkBudget extension) {
+ if (extension.items() > 0 || extension.bytes() > 0) {
+ executeSafely(
+ () -> {
+ StreamingGetWorkRequest request =
+ StreamingGetWorkRequest.newBuilder()
+ .setRequestExtension(
+ Windmill.StreamingGetWorkRequestExtension.newBuilder()
+ .setMaxItems(extension.items())
+ .setMaxBytes(extension.bytes()))
+ .build();
+ lastRequest.set(request);
+ budgetTracker.recordBudgetRequested(extension);
+ try {
+ send(request);
+ } catch (IllegalStateException e) {
+ // Stream was closed.
+ }
+ });
+ }
}
@Override
protected synchronized void onNewStream() {
workItemAssemblers.clear();
- // Add the current in-flight budget to the next adjustment. Only positive
values are allowed
- // here
- // with negatives defaulting to 0, since GetWorkBudgets cannot be created
with negative values.
- GetWorkBudget budgetAdjustment =
nextBudgetAdjustment.get().apply(inFlightBudget.get());
- inFlightBudget.set(budgetAdjustment);
- send(
- StreamingGetWorkRequest.newBuilder()
- .setRequest(
- request
- .toBuilder()
- .setMaxBytes(budgetAdjustment.bytes())
- .setMaxItems(budgetAdjustment.items()))
- .build());
-
- // We just sent the budget, reset it.
- nextBudgetAdjustment.set(GetWorkBudget.noBudget());
+ if (!isShutdown()) {
+ budgetTracker.reset();
+ GetWorkBudget initialGetWorkBudget =
budgetTracker.computeBudgetExtension();
+ StreamingGetWorkRequest request =
+ StreamingGetWorkRequest.newBuilder()
+ .setRequest(
+ requestHeader
+ .toBuilder()
+ .setMaxItems(initialGetWorkBudget.items())
+ .setMaxBytes(initialGetWorkBudget.bytes())
+ .build())
+ .build();
+ lastRequest.set(request);
+ budgetTracker.recordBudgetRequested(initialGetWorkBudget);
+ send(request);
+ }
}
@Override
@@ -216,8 +228,9 @@ public final class GrpcDirectGetWorkStream
public void appendSpecificHtml(PrintWriter writer) {
// Number of buffers is same as distinct workers that sent work on this
stream.
writer.format(
- "GetWorkStream: %d buffers, %s inflight budget allowed.",
- workItemAssemblers.size(), inFlightBudget.get());
+ "GetWorkStream: %d buffers, " + "last sent request: %s; ",
+ workItemAssemblers.size(), lastRequest.get());
+ writer.print(budgetTracker.debugString());
}
@Override
@@ -235,30 +248,22 @@ public final class GrpcDirectGetWorkStream
}
private void consumeAssembledWorkItem(AssembledWorkItem assembledWorkItem) {
- // Record the fact that there are now fewer outstanding messages and bytes
on the stream.
- inFlightBudget.updateAndGet(budget -> budget.subtract(1,
assembledWorkItem.bufferedSize()));
WorkItem workItem = assembledWorkItem.workItem();
GetWorkResponseChunkAssembler.ComputationMetadata metadata =
assembledWorkItem.computationMetadata();
- pendingResponseBudget.getAndUpdate(budget -> budget.apply(1,
workItem.getSerializedSize()));
- try {
- workItemScheduler.scheduleWork(
- workItem,
- createWatermarks(workItem, Preconditions.checkNotNull(metadata)),
-
createProcessingContext(Preconditions.checkNotNull(metadata.computationId())),
- assembledWorkItem.latencyAttributions());
- } finally {
- pendingResponseBudget.getAndUpdate(budget -> budget.apply(-1,
-workItem.getSerializedSize()));
- }
+ workItemScheduler.scheduleWork(
+ workItem,
+ createWatermarks(workItem, metadata),
+ createProcessingContext(metadata.computationId()),
+ assembledWorkItem.latencyAttributions());
+ budgetTracker.recordBudgetReceived(assembledWorkItem.bufferedSize());
+ GetWorkBudget extension = budgetTracker.computeBudgetExtension();
+ maybeSendRequestExtension(extension);
}
private Work.ProcessingContext createProcessingContext(String computationId)
{
return Work.createProcessingContext(
- computationId,
- getDataClient.get(),
- workCommitter.get()::commit,
- heartbeatSender.get(),
- backendWorkerToken());
+ computationId, getDataClient, workCommitter::commit, heartbeatSender,
backendWorkerToken());
}
@Override
@@ -267,25 +272,110 @@ public final class GrpcDirectGetWorkStream
}
@Override
- public void adjustBudget(long itemsDelta, long bytesDelta) {
- GetWorkBudget adjustment =
- nextBudgetAdjustment
- // Get the current value, and reset the nextBudgetAdjustment. This
will be set again
- // when adjustBudget is called.
- .getAndUpdate(unused -> GetWorkBudget.noBudget())
- .apply(itemsDelta, bytesDelta);
- sendRequestExtension(adjustment);
+ public void setBudget(GetWorkBudget newBudget) {
+ GetWorkBudget extension =
budgetTracker.consumeAndComputeBudgetUpdate(newBudget);
+ maybeSendRequestExtension(extension);
}
- @Override
- public GetWorkBudget remainingBudget() {
- // Snapshot the current budgets.
- GetWorkBudget currentPendingResponseBudget = pendingResponseBudget.get();
- GetWorkBudget currentNextBudgetAdjustment = nextBudgetAdjustment.get();
- GetWorkBudget currentInflightBudget = inFlightBudget.get();
-
- return currentPendingResponseBudget
- .apply(currentNextBudgetAdjustment)
- .apply(currentInflightBudget);
+ private void executeSafely(Runnable runnable) {
+ try {
+ executor().execute(runnable);
+ } catch (RejectedExecutionException e) {
+ LOG.debug("{} has been shutdown.", getClass());
+ }
+ }
+
+ /**
+ * Tracks sent, received, max {@link GetWorkBudget} and uses this
information to generate request
+ * extensions.
+ */
+ @ThreadSafe
+ private static final class GetWorkBudgetTracker {
+
+ @GuardedBy("GetWorkBudgetTracker.this")
+ private GetWorkBudget maxGetWorkBudget;
+
+ @GuardedBy("GetWorkBudgetTracker.this")
+ private long itemsRequested = 0;
+
+ @GuardedBy("GetWorkBudgetTracker.this")
+ private long bytesRequested = 0;
+
+ @GuardedBy("GetWorkBudgetTracker.this")
+ private long itemsReceived = 0;
+
+ @GuardedBy("GetWorkBudgetTracker.this")
+ private long bytesReceived = 0;
+
+ private GetWorkBudgetTracker(GetWorkBudget maxGetWorkBudget) {
+ this.maxGetWorkBudget = maxGetWorkBudget;
+ }
+
+ private synchronized void reset() {
+ itemsRequested = 0;
+ bytesRequested = 0;
+ itemsReceived = 0;
+ bytesReceived = 0;
+ }
+
+ private synchronized String debugString() {
+ return String.format(
+ "max budget: %s; "
+ + "in-flight budget: %s; "
+ + "total budget requested: %s; "
+ + "total budget received: %s.",
+ maxGetWorkBudget, inFlightBudget(), totalRequestedBudget(),
totalReceivedBudget());
+ }
+
+ /** Consumes the new budget and computes an extension based on the new
budget. */
+ private synchronized GetWorkBudget
consumeAndComputeBudgetUpdate(GetWorkBudget newBudget) {
+ maxGetWorkBudget = newBudget;
+ return computeBudgetExtension();
+ }
+
+ private synchronized void recordBudgetRequested(GetWorkBudget
budgetRequested) {
+ itemsRequested += budgetRequested.items();
+ bytesRequested += budgetRequested.bytes();
+ }
+
+ private synchronized void recordBudgetReceived(long returnedBudget) {
+ itemsReceived++;
+ bytesReceived += returnedBudget;
+ }
+
+ /**
+ * If the outstanding items or bytes limit has gotten too low, top both
off with a
+ * GetWorkExtension. The goal is to keep the limits relatively close to
their maximum values
+ * without sending too many extension requests.
+ */
+ private synchronized GetWorkBudget computeBudgetExtension() {
+ // Expected items and bytes can go negative here, since WorkItems
returned might be larger
+ // than the initially requested budget.
+ long inFlightItems = itemsRequested - itemsReceived;
+ long inFlightBytes = bytesRequested - bytesReceived;
+
+ // Don't send negative budget extensions.
+ long requestBytes = Math.max(0, maxGetWorkBudget.bytes() -
inFlightBytes);
+ long requestItems = Math.max(0, maxGetWorkBudget.items() -
inFlightItems);
+
+ return (inFlightItems > requestItems / 2 && inFlightBytes > requestBytes
/ 2)
+ ? GetWorkBudget.noBudget()
+ :
GetWorkBudget.builder().setItems(requestItems).setBytes(requestBytes).build();
+ }
+
+ private synchronized GetWorkBudget inFlightBudget() {
+ return GetWorkBudget.builder()
+ .setItems(itemsRequested - itemsReceived)
+ .setBytes(bytesRequested - bytesReceived)
+ .build();
+ }
+
+ private synchronized GetWorkBudget totalRequestedBudget() {
+ return
GetWorkBudget.builder().setItems(itemsRequested).setBytes(bytesRequested).build();
+ }
+
+ private synchronized GetWorkBudget totalReceivedBudget() {
+ return
GetWorkBudget.builder().setItems(itemsReceived).setBytes(bytesReceived).build();
+ }
}
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java
index 0e9a0c6316e..c99e05a7707 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java
@@ -59,7 +59,7 @@ import org.joda.time.Instant;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-public final class GrpcGetDataStream
+final class GrpcGetDataStream
extends AbstractWindmillStream<StreamingGetDataRequest,
StreamingGetDataResponse>
implements GetDataStream {
private static final Logger LOG =
LoggerFactory.getLogger(GrpcGetDataStream.class);
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java
index 09ecbf3f305..a368f3fec23 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java
@@ -194,15 +194,7 @@ final class GrpcGetWorkStream
}
@Override
- public void adjustBudget(long itemsDelta, long bytesDelta) {
+ public void setBudget(GetWorkBudget newBudget) {
// no-op
}
-
- @Override
- public GetWorkBudget remainingBudget() {
- return GetWorkBudget.builder()
- .setBytes(request.getMaxBytes() - inflightBytes.get())
- .setItems(request.getMaxItems() - inflightMessages.get())
- .build();
- }
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java
index 92f031db997..9e6a02d135e 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java
@@ -198,9 +198,9 @@ public class GrpcWindmillStreamFactory implements
StatusDataProvider {
WindmillConnection connection,
GetWorkRequest request,
ThrottleTimer getWorkThrottleTimer,
- Supplier<HeartbeatSender> heartbeatSender,
- Supplier<GetDataClient> getDataClient,
- Supplier<WorkCommitter> workCommitter,
+ HeartbeatSender heartbeatSender,
+ GetDataClient getDataClient,
+ WorkCommitter workCommitter,
WorkItemScheduler workItemScheduler) {
return GrpcDirectGetWorkStream.create(
connection.backendWorkerToken(),
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java
index 9aec29a3ba4..f0ea2f550a7 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java
@@ -36,7 +36,6 @@ import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPor
/** Utility class used to create different RPC Channels. */
public final class WindmillChannelFactory {
public static final String LOCALHOST = "localhost";
- private static final int DEFAULT_GRPC_PORT = 443;
private static final int MAX_REMOTE_TRACE_EVENTS = 100;
private WindmillChannelFactory() {}
@@ -55,8 +54,6 @@ public final class WindmillChannelFactory {
public static ManagedChannel remoteChannel(
WindmillServiceAddress windmillServiceAddress, int
windmillServiceRpcChannelTimeoutSec) {
switch (windmillServiceAddress.getKind()) {
- case IPV6:
- return remoteChannel(windmillServiceAddress.ipv6(),
windmillServiceRpcChannelTimeoutSec);
case GCP_SERVICE_ADDRESS:
return remoteChannel(
windmillServiceAddress.gcpServiceAddress(),
windmillServiceRpcChannelTimeoutSec);
@@ -67,7 +64,8 @@ public final class WindmillChannelFactory {
windmillServiceRpcChannelTimeoutSec);
default:
throw new UnsupportedOperationException(
- "Only IPV6, GCP_SERVICE_ADDRESS, AUTHENTICATED_GCP_SERVICE_ADDRESS
are supported WindmillServiceAddresses.");
+ "Only GCP_SERVICE_ADDRESS and AUTHENTICATED_GCP_SERVICE_ADDRESS
are supported"
+ + " WindmillServiceAddresses.");
}
}
@@ -105,17 +103,6 @@ public final class WindmillChannelFactory {
}
}
- public static ManagedChannel remoteChannel(
- Inet6Address directEndpoint, int windmillServiceRpcChannelTimeoutSec) {
- try {
- return createRemoteChannel(
- NettyChannelBuilder.forAddress(new InetSocketAddress(directEndpoint,
DEFAULT_GRPC_PORT)),
- windmillServiceRpcChannelTimeoutSec);
- } catch (SSLException sslException) {
- throw new WindmillChannelCreationException(directEndpoint.toString(),
sslException);
- }
- }
-
@SuppressWarnings("nullness")
private static ManagedChannel createRemoteChannel(
NettyChannelBuilder channelBuilder, int
windmillServiceRpcChannelTimeoutSec)
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java
index 403bb99efb4..8a1ba2556cf 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java
@@ -17,18 +17,11 @@
*/
package org.apache.beam.runners.dataflow.worker.windmill.work.budget;
-import static
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap;
-import static
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.DoubleMath.roundToLong;
import static
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.LongMath.divide;
import java.math.RoundingMode;
-import java.util.Map;
-import java.util.Map.Entry;
-import java.util.function.Function;
-import java.util.function.Supplier;
import org.apache.beam.sdk.annotations.Internal;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableCollection;
-import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -36,22 +29,11 @@ import org.slf4j.LoggerFactory;
@Internal
final class EvenGetWorkBudgetDistributor implements GetWorkBudgetDistributor {
private static final Logger LOG =
LoggerFactory.getLogger(EvenGetWorkBudgetDistributor.class);
- private final Supplier<GetWorkBudget> activeWorkBudgetSupplier;
-
- EvenGetWorkBudgetDistributor(Supplier<GetWorkBudget>
activeWorkBudgetSupplier) {
- this.activeWorkBudgetSupplier = activeWorkBudgetSupplier;
- }
-
- private static boolean isBelowFiftyPercentOfTarget(
- GetWorkBudget remaining, GetWorkBudget target) {
- return remaining.items() < roundToLong(target.items() * 0.5,
RoundingMode.CEILING)
- || remaining.bytes() < roundToLong(target.bytes() * 0.5,
RoundingMode.CEILING);
- }
@Override
public <T extends GetWorkBudgetSpender> void distributeBudget(
- ImmutableCollection<T> budgetOwners, GetWorkBudget getWorkBudget) {
- if (budgetOwners.isEmpty()) {
+ ImmutableCollection<T> budgetSpenders, GetWorkBudget getWorkBudget) {
+ if (budgetSpenders.isEmpty()) {
LOG.debug("Cannot distribute budget to no owners.");
return;
}
@@ -61,38 +43,15 @@ final class EvenGetWorkBudgetDistributor implements
GetWorkBudgetDistributor {
return;
}
- Map<T, GetWorkBudget> desiredBudgets = computeDesiredBudgets(budgetOwners,
getWorkBudget);
-
- for (Entry<T, GetWorkBudget> streamAndDesiredBudget :
desiredBudgets.entrySet()) {
- GetWorkBudgetSpender getWorkBudgetSpender =
streamAndDesiredBudget.getKey();
- GetWorkBudget desired = streamAndDesiredBudget.getValue();
- GetWorkBudget remaining = getWorkBudgetSpender.remainingBudget();
- if (isBelowFiftyPercentOfTarget(remaining, desired)) {
- GetWorkBudget adjustment = desired.subtract(remaining);
- getWorkBudgetSpender.adjustBudget(adjustment);
- }
- }
+ GetWorkBudget budgetPerStream =
computeDesiredPerStreamBudget(budgetSpenders, getWorkBudget);
+ budgetSpenders.forEach(getWorkBudgetSpender ->
getWorkBudgetSpender.setBudget(budgetPerStream));
}
- private <T extends GetWorkBudgetSpender> ImmutableMap<T, GetWorkBudget>
computeDesiredBudgets(
+ private <T extends GetWorkBudgetSpender> GetWorkBudget
computeDesiredPerStreamBudget(
ImmutableCollection<T> streams, GetWorkBudget totalGetWorkBudget) {
- GetWorkBudget activeWorkBudget = activeWorkBudgetSupplier.get();
- LOG.info("Current active work budget: {}", activeWorkBudget);
- // TODO: Fix possibly non-deterministic handing out of budgets.
- // Rounding up here will drift upwards over the lifetime of the streams.
- GetWorkBudget budgetPerStream =
- GetWorkBudget.builder()
- .setItems(
- divide(
- totalGetWorkBudget.items() - activeWorkBudget.items(),
- streams.size(),
- RoundingMode.CEILING))
- .setBytes(
- divide(
- totalGetWorkBudget.bytes() - activeWorkBudget.bytes(),
- streams.size(),
- RoundingMode.CEILING))
- .build();
- return streams.stream().collect(toImmutableMap(Function.identity(), unused
-> budgetPerStream));
+ return GetWorkBudget.builder()
+ .setItems(divide(totalGetWorkBudget.items(), streams.size(),
RoundingMode.CEILING))
+ .setBytes(divide(totalGetWorkBudget.bytes(), streams.size(),
RoundingMode.CEILING))
+ .build();
}
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributors.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributors.java
index 43c0d46139d..2013c9ff1cb 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributors.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributors.java
@@ -17,13 +17,11 @@
*/
package org.apache.beam.runners.dataflow.worker.windmill.work.budget;
-import java.util.function.Supplier;
import org.apache.beam.sdk.annotations.Internal;
@Internal
public final class GetWorkBudgetDistributors {
- public static GetWorkBudgetDistributor distributeEvenly(
- Supplier<GetWorkBudget> activeWorkBudgetSupplier) {
- return new EvenGetWorkBudgetDistributor(activeWorkBudgetSupplier);
+ public static GetWorkBudgetDistributor distributeEvenly() {
+ return new EvenGetWorkBudgetDistributor();
}
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetSpender.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetSpender.java
index 254b2589062..decf101a641 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetSpender.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetSpender.java
@@ -22,11 +22,9 @@ package
org.apache.beam.runners.dataflow.worker.windmill.work.budget;
* org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget}
*/
public interface GetWorkBudgetSpender {
- void adjustBudget(long itemsDelta, long bytesDelta);
+ void setBudget(long items, long bytes);
- default void adjustBudget(GetWorkBudget adjustment) {
- adjustBudget(adjustment.items(), adjustment.bytes());
+ default void setBudget(GetWorkBudget budget) {
+ setBudget(budget.items(), budget.bytes());
}
-
- GetWorkBudget remainingBudget();
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
index b3f7467cdbd..90ffb3d3fbc 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
@@ -245,18 +245,10 @@ public final class FakeWindmillServer extends
WindmillServerStub {
}
@Override
- public void adjustBudget(long itemsDelta, long bytesDelta) {
+ public void setBudget(GetWorkBudget newBudget) {
// no-op.
}
- @Override
- public GetWorkBudget remainingBudget() {
- return GetWorkBudget.builder()
- .setItems(request.getMaxItems())
- .setBytes(request.getMaxBytes())
- .build();
- }
-
@Override
public boolean awaitTermination(int time, TimeUnit unit) throws
InterruptedException {
while (done.getCount() > 0) {
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java
index ed8815c48e7..0092fcc7bcd 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java
@@ -30,9 +30,7 @@ import static org.mockito.Mockito.verify;
import java.io.IOException;
import java.util.ArrayList;
-import java.util.Comparator;
import java.util.HashSet;
-import java.util.List;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
@@ -46,7 +44,6 @@ import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataRequest;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataResponse;
-import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress;
import
org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter;
import
org.apache.beam.runners.dataflow.worker.windmill.client.getdata.ThrottlingGetDataMetricTracker;
@@ -71,7 +68,6 @@ import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Precondit
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableCollection;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
-import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort;
import org.junit.After;
import org.junit.Before;
@@ -92,7 +88,6 @@ public class FanOutStreamingEngineWorkerHarnessTest {
.setDirectEndpoint(DEFAULT_WINDMILL_SERVICE_ADDRESS.gcpServiceAddress().toString())
.build());
- private static final long CLIENT_ID = 1L;
private static final String JOB_ID = "jobId";
private static final String PROJECT_ID = "projectId";
private static final String WORKER_ID = "workerId";
@@ -101,6 +96,7 @@ public class FanOutStreamingEngineWorkerHarnessTest {
.setJobId(JOB_ID)
.setProjectId(PROJECT_ID)
.setWorkerId(WORKER_ID)
+ .setClientId(1L)
.build();
@Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
@@ -134,7 +130,7 @@ public class FanOutStreamingEngineWorkerHarnessTest {
.setJobId(JOB_ID)
.setProjectId(PROJECT_ID)
.setWorkerId(WORKER_ID)
- .setClientId(CLIENT_ID)
+ .setClientId(JOB_HEADER.getClientId())
.setMaxItems(items)
.setMaxBytes(bytes)
.build();
@@ -174,7 +170,7 @@ public class FanOutStreamingEngineWorkerHarnessTest {
stubFactory.shutdown();
}
- private FanOutStreamingEngineWorkerHarness newStreamingEngineClient(
+ private FanOutStreamingEngineWorkerHarness
newFanOutStreamingEngineWorkerHarness(
GetWorkBudget getWorkBudget,
GetWorkBudgetDistributor getWorkBudgetDistributor,
WorkItemScheduler workItemScheduler) {
@@ -186,7 +182,6 @@ public class FanOutStreamingEngineWorkerHarnessTest {
stubFactory,
getWorkBudgetDistributor,
dispatcherClient,
- CLIENT_ID,
ignored -> mock(WorkCommitter.class),
new ThrottlingGetDataMetricTracker(mock(MemoryMonitor.class)));
}
@@ -201,7 +196,7 @@ public class FanOutStreamingEngineWorkerHarnessTest {
spy(new TestGetWorkBudgetDistributor(numBudgetDistributionsExpected));
fanOutStreamingEngineWorkProvider =
- newStreamingEngineClient(
+ newFanOutStreamingEngineWorkerHarness(
GetWorkBudget.builder().setItems(items).setBytes(bytes).build(),
getWorkBudgetDistributor,
noOpProcessWorkItemFn());
@@ -219,16 +214,14 @@ public class FanOutStreamingEngineWorkerHarnessTest {
getWorkerMetadataReady.await();
fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata);
- waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor);
+ assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution());
- StreamingEngineConnectionState currentConnections =
- fanOutStreamingEngineWorkProvider.getCurrentConnections();
+ StreamingEngineBackends currentBackends =
fanOutStreamingEngineWorkProvider.currentBackends();
- assertEquals(2, currentConnections.windmillConnections().size());
- assertEquals(2, currentConnections.windmillStreams().size());
+ assertEquals(2, currentBackends.windmillStreams().size());
Set<String> workerTokens =
- currentConnections.windmillConnections().values().stream()
- .map(WindmillConnection::backendWorkerToken)
+ currentBackends.windmillStreams().keySet().stream()
+ .map(endpoint ->
endpoint.workerToken().orElseThrow(IllegalStateException::new))
.collect(Collectors.toSet());
assertTrue(workerTokens.contains(workerToken));
@@ -252,27 +245,6 @@ public class FanOutStreamingEngineWorkerHarnessTest {
verify(streamFactory, times(2)).createCommitWorkStream(any(), any());
}
- @Test
- public void testScheduledBudgetRefresh() throws InterruptedException {
- TestGetWorkBudgetDistributor getWorkBudgetDistributor =
- spy(new TestGetWorkBudgetDistributor(2));
- fanOutStreamingEngineWorkProvider =
- newStreamingEngineClient(
- GetWorkBudget.builder().setItems(1L).setBytes(1L).build(),
- getWorkBudgetDistributor,
- noOpProcessWorkItemFn());
-
- getWorkerMetadataReady.await();
- fakeGetWorkerMetadataStub.injectWorkerMetadata(
- WorkerMetadataResponse.newBuilder()
- .setMetadataVersion(1)
- .addWorkEndpoints(metadataResponseEndpoint("workerToken"))
- .putAllGlobalDataEndpoints(DEFAULT)
- .build());
- waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor);
- verify(getWorkBudgetDistributor, atLeast(2)).distributeBudget(any(),
any());
- }
-
@Test
public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers()
throws InterruptedException {
@@ -280,7 +252,7 @@ public class FanOutStreamingEngineWorkerHarnessTest {
TestGetWorkBudgetDistributor getWorkBudgetDistributor =
spy(new TestGetWorkBudgetDistributor(metadataCount));
fanOutStreamingEngineWorkProvider =
- newStreamingEngineClient(
+ newFanOutStreamingEngineWorkerHarness(
GetWorkBudget.builder().setItems(1).setBytes(1).build(),
getWorkBudgetDistributor,
noOpProcessWorkItemFn());
@@ -309,32 +281,28 @@ public class FanOutStreamingEngineWorkerHarnessTest {
WorkerMetadataResponse.Endpoint.newBuilder()
.setBackendWorkerToken(workerToken3)
.build())
- .putAllGlobalDataEndpoints(DEFAULT)
.build();
getWorkerMetadataReady.await();
fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata);
fakeGetWorkerMetadataStub.injectWorkerMetadata(secondWorkerMetadata);
- waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor);
- StreamingEngineConnectionState currentConnections =
- fanOutStreamingEngineWorkProvider.getCurrentConnections();
- assertEquals(1, currentConnections.windmillConnections().size());
- assertEquals(1, currentConnections.windmillStreams().size());
+ assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution());
+ StreamingEngineBackends currentBackends =
fanOutStreamingEngineWorkProvider.currentBackends();
+ assertEquals(1, currentBackends.windmillStreams().size());
Set<String> workerTokens =
-
fanOutStreamingEngineWorkProvider.getCurrentConnections().windmillConnections().values()
- .stream()
- .map(WindmillConnection::backendWorkerToken)
+
fanOutStreamingEngineWorkProvider.currentBackends().windmillStreams().keySet().stream()
+ .map(endpoint ->
endpoint.workerToken().orElseThrow(IllegalStateException::new))
.collect(Collectors.toSet());
assertFalse(workerTokens.contains(workerToken));
assertFalse(workerTokens.contains(workerToken2));
+ assertTrue(currentBackends.globalDataStreams().isEmpty());
}
@Test
public void testOnNewWorkerMetadata_redistributesBudget() throws
InterruptedException {
String workerToken = "workerToken1";
String workerToken2 = "workerToken2";
- String workerToken3 = "workerToken3";
WorkerMetadataResponse firstWorkerMetadata =
WorkerMetadataResponse.newBuilder()
@@ -354,42 +322,24 @@ public class FanOutStreamingEngineWorkerHarnessTest {
.build())
.putAllGlobalDataEndpoints(DEFAULT)
.build();
- WorkerMetadataResponse thirdWorkerMetadata =
- WorkerMetadataResponse.newBuilder()
- .setMetadataVersion(3)
- .addWorkEndpoints(
- WorkerMetadataResponse.Endpoint.newBuilder()
- .setBackendWorkerToken(workerToken3)
- .build())
- .putAllGlobalDataEndpoints(DEFAULT)
- .build();
-
- List<WorkerMetadataResponse> workerMetadataResponses =
- Lists.newArrayList(firstWorkerMetadata, secondWorkerMetadata,
thirdWorkerMetadata);
TestGetWorkBudgetDistributor getWorkBudgetDistributor =
- spy(new TestGetWorkBudgetDistributor(workerMetadataResponses.size()));
+ spy(new TestGetWorkBudgetDistributor(1));
fanOutStreamingEngineWorkProvider =
- newStreamingEngineClient(
+ newFanOutStreamingEngineWorkerHarness(
GetWorkBudget.builder().setItems(1).setBytes(1).build(),
getWorkBudgetDistributor,
noOpProcessWorkItemFn());
getWorkerMetadataReady.await();
- // Make sure we are injecting the metadata from smallest to largest.
- workerMetadataResponses.stream()
-
.sorted(Comparator.comparingLong(WorkerMetadataResponse::getMetadataVersion))
- .forEach(fakeGetWorkerMetadataStub::injectWorkerMetadata);
-
- waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor);
- verify(getWorkBudgetDistributor, atLeast(workerMetadataResponses.size()))
- .distributeBudget(any(), any());
- }
+ fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata);
+ assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution());
+ getWorkBudgetDistributor.expectNumDistributions(1);
+ fakeGetWorkerMetadataStub.injectWorkerMetadata(secondWorkerMetadata);
+ assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution());
- private void waitForWorkerMetadataToBeConsumed(
- TestGetWorkBudgetDistributor getWorkBudgetDistributor) throws
InterruptedException {
- getWorkBudgetDistributor.waitForBudgetDistribution();
+ verify(getWorkBudgetDistributor, times(2)).distributeBudget(any(), any());
}
private static class GetWorkerMetadataTestStub
@@ -434,21 +384,24 @@ public class FanOutStreamingEngineWorkerHarnessTest {
}
private static class TestGetWorkBudgetDistributor implements
GetWorkBudgetDistributor {
- private final CountDownLatch getWorkBudgetDistributorTriggered;
+ private CountDownLatch getWorkBudgetDistributorTriggered;
private TestGetWorkBudgetDistributor(int numBudgetDistributionsExpected) {
this.getWorkBudgetDistributorTriggered = new
CountDownLatch(numBudgetDistributionsExpected);
}
- @SuppressWarnings("ReturnValueIgnored")
- private void waitForBudgetDistribution() throws InterruptedException {
- getWorkBudgetDistributorTriggered.await(5, TimeUnit.SECONDS);
+ private boolean waitForBudgetDistribution() throws InterruptedException {
+ return getWorkBudgetDistributorTriggered.await(5, TimeUnit.SECONDS);
+ }
+
+ private void expectNumDistributions(int numBudgetDistributionsExpected) {
+ this.getWorkBudgetDistributorTriggered = new
CountDownLatch(numBudgetDistributionsExpected);
}
@Override
public <T extends GetWorkBudgetSpender> void distributeBudget(
ImmutableCollection<T> streams, GetWorkBudget getWorkBudget) {
- streams.forEach(stream -> stream.adjustBudget(getWorkBudget.items(),
getWorkBudget.bytes()));
+ streams.forEach(stream -> stream.setBudget(getWorkBudget.items(),
getWorkBudget.bytes()));
getWorkBudgetDistributorTriggered.countDown();
}
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java
index dc6cc564105..32d1f573808 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java
@@ -193,7 +193,7 @@ public class WindmillStreamSenderTest {
WindmillStreamSender windmillStreamSender =
newWindmillStreamSender(GetWorkBudget.builder().setBytes(1L).setItems(1L).build());
- windmillStreamSender.closeAllStreams();
+ windmillStreamSender.close();
verifyNoInteractions(streamFactory);
}
@@ -230,7 +230,7 @@ public class WindmillStreamSenderTest {
mockStreamFactory);
windmillStreamSender.startStreams();
- windmillStreamSender.closeAllStreams();
+ windmillStreamSender.close();
verify(mockGetWorkStream).shutdown();
verify(mockGetDataStream).shutdown();
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java
new file mode 100644
index 00000000000..fd2b3023883
--- /dev/null
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java
@@ -0,0 +1,405 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import static com.google.common.truth.Truth.assertThat;
+import static
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import javax.annotation.Nullable;
+import
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
+import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler;
+import
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import
org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender;
+import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server;
+import
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder;
+import
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.Timeout;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class GrpcDirectGetWorkStreamTest {
+ private static final WorkItemScheduler NO_OP_WORK_ITEM_SCHEDULER =
+ (workItem, watermarks, processingContext, getWorkStreamLatencies) -> {};
+ private static final Windmill.JobHeader TEST_JOB_HEADER =
+ Windmill.JobHeader.newBuilder()
+ .setClientId(1L)
+ .setJobId("test_job")
+ .setWorkerId("test_worker")
+ .setProjectId("test_project")
+ .build();
+ private static final String FAKE_SERVER_NAME = "Fake server for
GrpcDirectGetWorkStreamTest";
+ @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
+ private final MutableHandlerRegistry serviceRegistry = new
MutableHandlerRegistry();
+ @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
+ private ManagedChannel inProcessChannel;
+ private GrpcDirectGetWorkStream stream;
+
+ private static Windmill.StreamingGetWorkRequestExtension
extension(GetWorkBudget budget) {
+ return Windmill.StreamingGetWorkRequestExtension.newBuilder()
+ .setMaxItems(budget.items())
+ .setMaxBytes(budget.bytes())
+ .build();
+ }
+
+ private static void assertHeader(
+ Windmill.StreamingGetWorkRequest getWorkRequest, GetWorkBudget
expectedInitialBudget) {
+ assertTrue(getWorkRequest.hasRequest());
+ assertFalse(getWorkRequest.hasRequestExtension());
+ assertThat(getWorkRequest.getRequest())
+ .isEqualTo(
+ Windmill.GetWorkRequest.newBuilder()
+ .setClientId(TEST_JOB_HEADER.getClientId())
+ .setJobId(TEST_JOB_HEADER.getJobId())
+ .setProjectId(TEST_JOB_HEADER.getProjectId())
+ .setWorkerId(TEST_JOB_HEADER.getWorkerId())
+ .setMaxItems(expectedInitialBudget.items())
+ .setMaxBytes(expectedInitialBudget.bytes())
+ .build());
+ }
+
+ @Before
+ public void setUp() throws IOException {
+ Server server =
+ InProcessServerBuilder.forName(FAKE_SERVER_NAME)
+ .fallbackHandlerRegistry(serviceRegistry)
+ .directExecutor()
+ .build()
+ .start();
+
+ inProcessChannel =
+ grpcCleanup.register(
+
InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build());
+ grpcCleanup.register(server);
+ grpcCleanup.register(inProcessChannel);
+ }
+
+ @After
+ public void cleanUp() {
+ inProcessChannel.shutdownNow();
+ checkNotNull(stream).shutdown();
+ }
+
+ private GrpcDirectGetWorkStream createGetWorkStream(
+ GetWorkStreamTestStub testStub,
+ GetWorkBudget initialGetWorkBudget,
+ ThrottleTimer throttleTimer,
+ WorkItemScheduler workItemScheduler) {
+ serviceRegistry.addService(testStub);
+ return (GrpcDirectGetWorkStream)
+ GrpcWindmillStreamFactory.of(TEST_JOB_HEADER)
+ .build()
+ .createDirectGetWorkStream(
+ WindmillConnection.builder()
+
.setStub(CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel))
+ .build(),
+ Windmill.GetWorkRequest.newBuilder()
+ .setClientId(TEST_JOB_HEADER.getClientId())
+ .setJobId(TEST_JOB_HEADER.getJobId())
+ .setProjectId(TEST_JOB_HEADER.getProjectId())
+ .setWorkerId(TEST_JOB_HEADER.getWorkerId())
+ .setMaxItems(initialGetWorkBudget.items())
+ .setMaxBytes(initialGetWorkBudget.bytes())
+ .build(),
+ throttleTimer,
+ mock(HeartbeatSender.class),
+ mock(GetDataClient.class),
+ mock(WorkCommitter.class),
+ workItemScheduler);
+ }
+
+ private Windmill.StreamingGetWorkResponseChunk
createResponse(Windmill.WorkItem workItem) {
+ return Windmill.StreamingGetWorkResponseChunk.newBuilder()
+ .setStreamId(1L)
+ .setComputationMetadata(
+ Windmill.ComputationWorkItemMetadata.newBuilder()
+ .setComputationId("compId")
+ .setInputDataWatermark(1L)
+ .setDependentRealtimeInputWatermark(1L)
+ .build())
+ .setSerializedWorkItem(workItem.toByteString())
+ .setRemainingBytesForWorkItem(0)
+ .build();
+ }
+
+ @Test
+ public void testSetBudget_computesAndSendsCorrectExtension_noExistingBudget()
+ throws InterruptedException {
+ int expectedRequests = 2;
+ CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+ TestGetWorkRequestObserver requestObserver = new
TestGetWorkRequestObserver(waitForRequests);
+ GetWorkStreamTestStub testStub = new
GetWorkStreamTestStub(requestObserver);
+ stream =
+ createGetWorkStream(
+ testStub, GetWorkBudget.noBudget(), new ThrottleTimer(),
NO_OP_WORK_ITEM_SCHEDULER);
+ GetWorkBudget newBudget =
GetWorkBudget.builder().setItems(10).setBytes(10).build();
+ stream.setBudget(newBudget);
+
+ assertTrue(waitForRequests.await(5, TimeUnit.SECONDS));
+
+ // Header and extension.
+ assertThat(requestObserver.sent()).hasSize(expectedRequests);
+ assertHeader(requestObserver.sent().get(0), GetWorkBudget.noBudget());
+ assertThat(Iterables.getLast(requestObserver.sent()).getRequestExtension())
+ .isEqualTo(extension(newBudget));
+ }
+
+ @Test
+ public void testSetBudget_computesAndSendsCorrectExtension_existingBudget()
+ throws InterruptedException {
+ int expectedRequests = 2;
+ CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+ TestGetWorkRequestObserver requestObserver = new
TestGetWorkRequestObserver(waitForRequests);
+ GetWorkStreamTestStub testStub = new
GetWorkStreamTestStub(requestObserver);
+ GetWorkBudget initialBudget =
GetWorkBudget.builder().setItems(10).setBytes(10).build();
+ stream =
+ createGetWorkStream(
+ testStub, initialBudget, new ThrottleTimer(),
NO_OP_WORK_ITEM_SCHEDULER);
+ GetWorkBudget newBudget =
GetWorkBudget.builder().setItems(100).setBytes(100).build();
+ stream.setBudget(newBudget);
+ GetWorkBudget diff = newBudget.subtract(initialBudget);
+
+ assertTrue(waitForRequests.await(5, TimeUnit.SECONDS));
+
+ List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+ // Header and extension.
+ assertThat(requests).hasSize(expectedRequests);
+ assertHeader(requests.get(0), initialBudget);
+
assertThat(Iterables.getLast(requests).getRequestExtension()).isEqualTo(extension(diff));
+ }
+
+ @Test
+ public void testSetBudget_doesNotSendExtensionIfOutstandingBudgetHigh()
+ throws InterruptedException {
+ int expectedRequests = 1;
+ CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+ TestGetWorkRequestObserver requestObserver = new
TestGetWorkRequestObserver(waitForRequests);
+ GetWorkStreamTestStub testStub = new
GetWorkStreamTestStub(requestObserver);
+ GetWorkBudget initialBudget =
+
GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build();
+ stream =
+ createGetWorkStream(
+ testStub, initialBudget, new ThrottleTimer(),
NO_OP_WORK_ITEM_SCHEDULER);
+
stream.setBudget(GetWorkBudget.builder().setItems(10).setBytes(10).build());
+
+ assertTrue(waitForRequests.await(5, TimeUnit.SECONDS));
+
+ List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+ // Assert that the extension was never sent, only the header.
+ assertThat(requests).hasSize(expectedRequests);
+ assertHeader(Iterables.getOnlyElement(requests), initialBudget);
+ }
+
+ @Test
+ public void testSetBudget_doesNothingIfStreamShutdown() throws
InterruptedException {
+ int expectedRequests = 1;
+ CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+ TestGetWorkRequestObserver requestObserver = new
TestGetWorkRequestObserver(waitForRequests);
+ GetWorkStreamTestStub testStub = new
GetWorkStreamTestStub(requestObserver);
+ stream =
+ createGetWorkStream(
+ testStub, GetWorkBudget.noBudget(), new ThrottleTimer(),
NO_OP_WORK_ITEM_SCHEDULER);
+ stream.shutdown();
+ stream.setBudget(
+
GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build());
+
+ assertTrue(waitForRequests.await(5, TimeUnit.SECONDS));
+
+ List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+ // Assert that the extension was never sent, only the header.
+ assertThat(requests).hasSize(1);
+ assertHeader(Iterables.getOnlyElement(requests), GetWorkBudget.noBudget());
+ }
+
+  @Test
+  public void testConsumedWorkItem_computesAndSendsCorrectExtension() throws InterruptedException {
+    int expectedRequests = 2;
+    CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+    TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests);
+    GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver);
+    GetWorkBudget initialBudget = GetWorkBudget.builder().setItems(1).setBytes(100).build();
+    // The scheduler may be invoked off the test thread, so collect into a thread-safe set.
+    Set<Windmill.WorkItem> scheduledWorkItems = Collections.synchronizedSet(new HashSet<>());
+    stream =
+        createGetWorkStream(
+            testStub,
+            initialBudget,
+            new ThrottleTimer(),
+            (work, watermarks, processingContext, getWorkStreamLatencies) ->
+                scheduledWorkItems.add(work));
+    Windmill.WorkItem workItem =
+        Windmill.WorkItem.newBuilder()
+            .setKey(ByteString.copyFromUtf8("somewhat_long_key"))
+            .setWorkToken(1L)
+            .setShardingKey(1L)
+            .setCacheToken(1L)
+            .build();
+
+    testStub.injectResponse(createResponse(workItem));
+
+    assertTrue(waitForRequests.await(5, TimeUnit.SECONDS));
+
+    assertThat(scheduledWorkItems).containsExactly(workItem);
+    List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+    // Consuming the work item uses 1 item and its serialized size in bytes; the extension is
+    // expected to replenish exactly that amount.
+    long consumedBytes = workItem.getSerializedSize();
+
+    assertThat(requests).hasSize(expectedRequests);
+    assertHeader(requests.get(0), initialBudget);
+    assertThat(Iterables.getLast(requests).getRequestExtension())
+        .isEqualTo(extension(GetWorkBudget.builder().setItems(1).setBytes(consumedBytes).build()));
+  }
+
+  @Test
+  public void testConsumedWorkItem_doesNotSendExtensionIfOutstandingBudgetHigh()
+      throws InterruptedException {
+    int expectedRequests = 1;
+    CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+    TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests);
+    GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver);
+    // waitForRequests is released by the header alone, so it does not order anything after the
+    // injected response. Signal scheduling separately, and use a thread-safe set because the
+    // scheduler may run off the test thread.
+    Set<Windmill.WorkItem> scheduledWorkItems = Collections.synchronizedSet(new HashSet<>());
+    CountDownLatch waitForScheduledWork = new CountDownLatch(1);
+    GetWorkBudget initialBudget =
+        GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build();
+    stream =
+        createGetWorkStream(
+            testStub,
+            initialBudget,
+            new ThrottleTimer(),
+            (work, watermarks, processingContext, getWorkStreamLatencies) -> {
+              scheduledWorkItems.add(work);
+              waitForScheduledWork.countDown();
+            });
+    Windmill.WorkItem workItem =
+        Windmill.WorkItem.newBuilder()
+            .setKey(ByteString.copyFromUtf8("somewhat_long_key"))
+            .setWorkToken(1L)
+            .setShardingKey(1L)
+            .setCacheToken(1L)
+            .build();
+
+    testStub.injectResponse(createResponse(workItem));
+
+    assertTrue(waitForRequests.await(5, TimeUnit.SECONDS));
+    assertTrue(waitForScheduledWork.await(5, TimeUnit.SECONDS));
+
+    assertThat(scheduledWorkItems).containsExactly(workItem);
+    List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+
+    // Assert that the extension was never sent, only the header.
+    assertThat(requests).hasSize(expectedRequests);
+    assertHeader(Iterables.getOnlyElement(requests), initialBudget);
+  }
+
+  @Test
+  public void testOnResponse_stopsThrottling() {
+    ThrottleTimer throttleTimer = new ThrottleTimer();
+    GetWorkStreamTestStub testStub =
+        new GetWorkStreamTestStub(new TestGetWorkRequestObserver(new CountDownLatch(1)));
+    stream =
+        createGetWorkStream(
+            testStub, GetWorkBudget.noBudget(), throttleTimer, NO_OP_WORK_ITEM_SCHEDULER);
+
+    // Throttling is started explicitly and should be lifted by any incoming response chunk.
+    stream.startThrottleTimer();
+    assertTrue(throttleTimer.throttled());
+
+    testStub.injectResponse(Windmill.StreamingGetWorkResponseChunk.getDefaultInstance());
+    assertFalse(throttleTimer.throttled());
+  }
+
+  /**
+   * Fake Windmill service that serves the GetWork stream. The response observer of the first
+   * stream opened through {@link #getWorkStream} is captured so tests can push response chunks to
+   * the client with {@link #injectResponse}.
+   */
+  private static class GetWorkStreamTestStub
+      extends CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase {
+    private final TestGetWorkRequestObserver requestObserver;
+    // Written by the RPC handler and read by the test in injectResponse, which may run on
+    // different threads; volatile for visibility, matching
+    // TestGetWorkRequestObserver#responseObserver.
+    private volatile @Nullable StreamObserver<Windmill.StreamingGetWorkResponseChunk>
+        responseObserver;
+
+    private GetWorkStreamTestStub(TestGetWorkRequestObserver requestObserver) {
+      this.requestObserver = requestObserver;
+    }
+
+    @Override
+    public StreamObserver<Windmill.StreamingGetWorkRequest> getWorkStream(
+        StreamObserver<Windmill.StreamingGetWorkResponseChunk> responseObserver) {
+      // Only the first stream's response observer is kept. Later streams still get the same
+      // requestObserver, but responses cannot be injected into them.
+      if (this.responseObserver == null) {
+        this.responseObserver = responseObserver;
+        requestObserver.responseObserver = responseObserver;
+      }
+
+      return requestObserver;
+    }
+
+    /**
+     * Delivers {@code responseChunk} to the client over the captured stream.
+     *
+     * @throws NullPointerException if no stream has been opened against this stub yet
+     */
+    private void injectResponse(Windmill.StreamingGetWorkResponseChunk responseChunk) {
+      checkNotNull(responseObserver).onNext(responseChunk);
+    }
+  }
+
+  /**
+   * Records every {@link Windmill.StreamingGetWorkRequest} sent by the client. It counts down
+   * {@code waitForRequests} once per request so tests can wait for an expected number of requests.
+   */
+  private static class TestGetWorkRequestObserver
+      implements StreamObserver<Windmill.StreamingGetWorkRequest> {
+    private final List<Windmill.StreamingGetWorkRequest> requests =
+        Collections.synchronizedList(new ArrayList<>());
+    private final CountDownLatch waitForRequests;
+    // Assigned by GetWorkStreamTestStub when a stream is opened; null until then.
+    private @Nullable volatile StreamObserver<Windmill.StreamingGetWorkResponseChunk>
+        responseObserver;
+
+    public TestGetWorkRequestObserver(CountDownLatch waitForRequests) {
+      this.waitForRequests = waitForRequests;
+    }
+
+    @Override
+    public void onNext(Windmill.StreamingGetWorkRequest request) {
+      requests.add(request);
+      waitForRequests.countDown();
+    }
+
+    @Override
+    public void onError(Throwable throwable) {
+      // Errors are ignored by this fake.
+    }
+
+    @Override
+    public void onCompleted() {
+      // Read the volatile field once. A half-close before any stream was registered would
+      // otherwise throw a NullPointerException.
+      StreamObserver<Windmill.StreamingGetWorkResponseChunk> observer = responseObserver;
+      if (observer != null) {
+        observer.onCompleted();
+      }
+    }
+
+    /** Returns a point-in-time snapshot of the requests received so far. */
+    List<Windmill.StreamingGetWorkRequest> sent() {
+      // ArrayList's copy constructor uses the synchronized list's toArray(), so the snapshot is
+      // consistent even while requests are still arriving.
+      return new ArrayList<>(requests);
+    }
+  }
+}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java
index 3cda4559c10..c76d5a58418 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java
@@ -17,9 +17,7 @@
*/
package org.apache.beam.runners.dataflow.worker.windmill.work.budget;
-import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.eq;
-import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@@ -40,169 +38,79 @@ public class EvenGetWorkBudgetDistributorTest {
@Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
@Rule public transient Timeout globalTimeout = Timeout.seconds(600);
- private static GetWorkBudgetDistributor
createBudgetDistributor(GetWorkBudget activeWorkBudget) {
- return GetWorkBudgetDistributors.distributeEvenly(() -> activeWorkBudget);
- }
+ /**
+ * Creates a Mockito spy over a no-op {@link GetWorkBudgetSpender}, so tests can verify which
+ * budget each stream was given through {@code setBudget}.
+ */
+ private static GetWorkBudgetSpender createGetWorkBudgetOwner() {
+ // Use an anonymous class: lambdas are final and cannot be spied.
+ return spy(
+ new GetWorkBudgetSpender() {
- private static GetWorkBudgetDistributor createBudgetDistributor(long
activeWorkItemsAndBytes) {
- return createBudgetDistributor(
- GetWorkBudget.builder()
- .setItems(activeWorkItemsAndBytes)
- .setBytes(activeWorkItemsAndBytes)
- .build());
+ @Override
+ public void setBudget(long items, long bytes) {}
+ });
}
@Test
public void testDistributeBudget_doesNothingWhenPassedInStreamsEmpty() {
+ // There are no streams to distribute to; the call is expected to complete without throwing.
- createBudgetDistributor(1L)
+ GetWorkBudgetDistributors.distributeEvenly()
.distributeBudget(
ImmutableList.of(),
GetWorkBudget.builder().setItems(10L).setBytes(10L).build());
}
@Test
public void testDistributeBudget_doesNothingWithNoBudget() {
- GetWorkBudgetSpender getWorkBudgetSpender =
-
spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(GetWorkBudget.noBudget()));
- createBudgetDistributor(1L)
+ GetWorkBudgetSpender getWorkBudgetSpender = createGetWorkBudgetOwner();
+ GetWorkBudgetDistributors.distributeEvenly()
.distributeBudget(ImmutableList.of(getWorkBudgetSpender),
GetWorkBudget.noBudget());
+ // With an empty total budget, the spender must not be called at all (not even setBudget).
verifyNoInteractions(getWorkBudgetSpender);
}
@Test
- public void
testDistributeBudget_doesNotAdjustStreamBudgetWhenRemainingBudgetHighNoActiveWork()
{
- GetWorkBudgetSpender getWorkBudgetSpender =
- spy(
- createGetWorkBudgetOwnerWithRemainingBudgetOf(
- GetWorkBudget.builder().setItems(10L).setBytes(10L).build()));
- createBudgetDistributor(0L)
- .distributeBudget(
- ImmutableList.of(getWorkBudgetSpender),
- GetWorkBudget.builder().setItems(10L).setBytes(10L).build());
-
- verify(getWorkBudgetSpender, never()).adjustBudget(anyLong(), anyLong());
- }
-
- @Test
- public void
-
testDistributeBudget_doesNotAdjustStreamBudgetWhenRemainingBudgetHighWithActiveWork()
{
- GetWorkBudgetSpender getWorkBudgetSpender =
- spy(
- createGetWorkBudgetOwnerWithRemainingBudgetOf(
- GetWorkBudget.builder().setItems(5L).setBytes(5L).build()));
- createBudgetDistributor(10L)
+ public void testDistributeBudget_distributesBudgetEvenlyIfPossible() {
+ // 10 items and 100 bytes across 10 streams divide exactly: 1 item and 10 bytes per stream.
+ int totalStreams = 10;
+ long totalItems = 10L;
+ long totalBytes = 100L;
+ List<GetWorkBudgetSpender> streams = new ArrayList<>();
+ for (int i = 0; i < totalStreams; i++) {
+ streams.add(createGetWorkBudgetOwner());
+ }
+ GetWorkBudgetDistributors.distributeEvenly()
.distributeBudget(
- ImmutableList.of(getWorkBudgetSpender),
- GetWorkBudget.builder().setItems(20L).setBytes(20L).build());
-
- verify(getWorkBudgetSpender, never()).adjustBudget(anyLong(), anyLong());
- }
-
- @Test
- public void
-
testDistributeBudget_adjustsStreamBudgetWhenRemainingItemBudgetTooLowWithNoActiveWork()
{
- GetWorkBudget streamRemainingBudget =
- GetWorkBudget.builder().setItems(1L).setBytes(10L).build();
- GetWorkBudget totalGetWorkBudget =
GetWorkBudget.builder().setItems(10L).setBytes(10L).build();
- GetWorkBudgetSpender getWorkBudgetSpender =
-
spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget));
- createBudgetDistributor(0L)
- .distributeBudget(ImmutableList.of(getWorkBudgetSpender),
totalGetWorkBudget);
-
- verify(getWorkBudgetSpender, times(1))
- .adjustBudget(
- eq(totalGetWorkBudget.items() - streamRemainingBudget.items()),
- eq(totalGetWorkBudget.bytes() - streamRemainingBudget.bytes()));
- }
-
- @Test
- public void
-
testDistributeBudget_adjustsStreamBudgetWhenRemainingItemBudgetTooLowWithActiveWork()
{
- GetWorkBudget streamRemainingBudget =
- GetWorkBudget.builder().setItems(1L).setBytes(10L).build();
- GetWorkBudget totalGetWorkBudget =
GetWorkBudget.builder().setItems(10L).setBytes(10L).build();
- long activeWorkItemsAndBytes = 2L;
- GetWorkBudgetSpender getWorkBudgetSpender =
-
spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget));
- createBudgetDistributor(activeWorkItemsAndBytes)
- .distributeBudget(ImmutableList.of(getWorkBudgetSpender),
totalGetWorkBudget);
-
- verify(getWorkBudgetSpender, times(1))
- .adjustBudget(
- eq(
- totalGetWorkBudget.items()
- - streamRemainingBudget.items()
- - activeWorkItemsAndBytes),
- eq(totalGetWorkBudget.bytes() - streamRemainingBudget.bytes()));
- }
-
- @Test
- public void
testDistributeBudget_adjustsStreamBudgetWhenRemainingByteBudgetTooLowNoActiveWork()
{
- GetWorkBudget streamRemainingBudget =
- GetWorkBudget.builder().setItems(10L).setBytes(1L).build();
- GetWorkBudget totalGetWorkBudget =
GetWorkBudget.builder().setItems(10L).setBytes(10L).build();
- GetWorkBudgetSpender getWorkBudgetSpender =
-
spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget));
- createBudgetDistributor(0L)
- .distributeBudget(ImmutableList.of(getWorkBudgetSpender),
totalGetWorkBudget);
-
- verify(getWorkBudgetSpender, times(1))
- .adjustBudget(
- eq(totalGetWorkBudget.items() - streamRemainingBudget.items()),
- eq(totalGetWorkBudget.bytes() - streamRemainingBudget.bytes()));
- }
-
- @Test
- public void
-
testDistributeBudget_adjustsStreamBudgetWhenRemainingByteBudgetTooLowWithActiveWork()
{
- GetWorkBudget streamRemainingBudget =
- GetWorkBudget.builder().setItems(10L).setBytes(1L).build();
- GetWorkBudget totalGetWorkBudget =
GetWorkBudget.builder().setItems(10L).setBytes(10L).build();
- long activeWorkItemsAndBytes = 2L;
-
- GetWorkBudgetSpender getWorkBudgetSpender =
-
spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget));
- createBudgetDistributor(activeWorkItemsAndBytes)
- .distributeBudget(ImmutableList.of(getWorkBudgetSpender),
totalGetWorkBudget);
+ ImmutableList.copyOf(streams),
+
GetWorkBudget.builder().setItems(totalItems).setBytes(totalBytes).build());
- verify(getWorkBudgetSpender, times(1))
- .adjustBudget(
- eq(totalGetWorkBudget.items() - streamRemainingBudget.items()),
- eq(
- totalGetWorkBudget.bytes()
- - streamRemainingBudget.bytes()
- - activeWorkItemsAndBytes));
+ streams.forEach(
+ stream ->
+ verify(stream, times(1))
+
.setBudget(eq(GetWorkBudget.builder().setItems(1L).setBytes(10L).build())));
}
@Test
- public void testDistributeBudget_distributesBudgetEvenlyIfPossible() {
- long totalItemsAndBytes = 10L;
+ public void testDistributeBudget_distributesFairlyWhenNotEven() {
+ long totalItems = 10L;
+ long totalBytes = 19L;
+ // 10 items and 19 bytes across 3 streams do not divide evenly; each stream is expected to
+ // receive the per-stream share rounded up (4 items, 7 bytes).
List<GetWorkBudgetSpender> streams = new ArrayList<>();
- for (int i = 0; i < totalItemsAndBytes; i++) {
-
streams.add(spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(GetWorkBudget.noBudget())));
+ for (int i = 0; i < 3; i++) {
+ streams.add(createGetWorkBudgetOwner());
}
- createBudgetDistributor(0L)
+ GetWorkBudgetDistributors.distributeEvenly()
.distributeBudget(
ImmutableList.copyOf(streams),
- GetWorkBudget.builder()
- .setItems(totalItemsAndBytes)
- .setBytes(totalItemsAndBytes)
- .build());
+
GetWorkBudget.builder().setItems(totalItems).setBytes(totalBytes).build());
- long itemsAndBytesPerStream = totalItemsAndBytes / streams.size();
streams.forEach(
stream ->
verify(stream, times(1))
- .adjustBudget(eq(itemsAndBytesPerStream),
eq(itemsAndBytesPerStream)));
+
.setBudget(eq(GetWorkBudget.builder().setItems(4L).setBytes(7L).build())));
}
@Test
- public void testDistributeBudget_distributesFairlyWhenNotEven() {
+ public void testDistributeBudget_distributesBudgetEvenly() {
+ // 10 items and 10 bytes across 10 streams: verifies that the (items, bytes) overload of
+ // setBudget is called with 1 item and 1 byte for every stream.
long totalItemsAndBytes = 10L;
List<GetWorkBudgetSpender> streams = new ArrayList<>();
- for (int i = 0; i < 3; i++) {
-
streams.add(spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(GetWorkBudget.noBudget())));
+ for (int i = 0; i < totalItemsAndBytes; i++) {
+ streams.add(createGetWorkBudgetOwner());
}
- createBudgetDistributor(0L)
+
+ GetWorkBudgetDistributors.distributeEvenly()
.distributeBudget(
ImmutableList.copyOf(streams),
GetWorkBudget.builder()
@@ -210,24 +118,10 @@ public class EvenGetWorkBudgetDistributorTest {
.setBytes(totalItemsAndBytes)
.build());
- long itemsAndBytesPerStream = (long) Math.ceil(totalItemsAndBytes /
(streams.size() * 1.0));
+ long itemsAndBytesPerStream = totalItemsAndBytes / streams.size();
streams.forEach(
stream ->
verify(stream, times(1))
- .adjustBudget(eq(itemsAndBytesPerStream),
eq(itemsAndBytesPerStream)));
- }
-
- private GetWorkBudgetSpender createGetWorkBudgetOwnerWithRemainingBudgetOf(
- GetWorkBudget getWorkBudget) {
- return spy(
- new GetWorkBudgetSpender() {
- @Override
- public void adjustBudget(long itemsDelta, long bytesDelta) {}
-
- @Override
- public GetWorkBudget remainingBudget() {
- return getWorkBudget;
- }
- });
+ .setBudget(eq(itemsAndBytesPerStream),
eq(itemsAndBytesPerStream)));
}
}