This is an automated email from the ASF dual-hosted git repository.

scwhittle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 68f1543b6bf Simplify budget distribution logic and new worker metadata 
consumption (#32775)
68f1543b6bf is described below

commit 68f1543b6bfe4ceaa752c7f16fc2cae7393211fd
Author: martin trieu <[email protected]>
AuthorDate: Mon Oct 21 03:05:43 2024 -0600

    Simplify budget distribution logic and new worker metadata consumption 
(#32775)
---
 .../FanOutStreamingEngineWorkerHarness.java        | 379 +++++++++----------
 .../streaming/harness/GlobalDataStreamSender.java  |  63 ++++
 ...tionState.java => StreamingEngineBackends.java} |  30 +-
 .../streaming/harness/WindmillStreamSender.java    |  25 +-
 .../worker/windmill/WindmillEndpoints.java         |  28 +-
 .../worker/windmill/WindmillServiceAddress.java    |  22 +-
 .../worker/windmill/client/WindmillStream.java     |   7 +-
 .../client/grpc/GrpcDirectGetWorkStream.java       | 286 ++++++++++-----
 .../windmill/client/grpc/GrpcGetDataStream.java    |   2 +-
 .../windmill/client/grpc/GrpcGetWorkStream.java    |  10 +-
 .../client/grpc/GrpcWindmillStreamFactory.java     |   6 +-
 .../client/grpc/stubs/WindmillChannelFactory.java  |  17 +-
 .../work/budget/EvenGetWorkBudgetDistributor.java  |  59 +--
 .../work/budget/GetWorkBudgetDistributors.java     |   6 +-
 .../windmill/work/budget/GetWorkBudgetSpender.java |   8 +-
 .../dataflow/worker/FakeWindmillServer.java        |  10 +-
 .../FanOutStreamingEngineWorkerHarnessTest.java    | 111 ++----
 .../harness/WindmillStreamSenderTest.java          |   4 +-
 .../client/grpc/GrpcDirectGetWorkStreamTest.java   | 405 +++++++++++++++++++++
 .../budget/EvenGetWorkBudgetDistributorTest.java   | 186 ++--------
 20 files changed, 998 insertions(+), 666 deletions(-)

diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java
index 3556b7ce291..458cf57ca8e 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java
@@ -20,20 +20,25 @@ package 
org.apache.beam.runners.dataflow.worker.streaming.harness;
 import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap;
 import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet.toImmutableSet;
 
-import java.util.Collection;
-import java.util.List;
+import java.io.Closeable;
+import java.util.HashSet;
 import java.util.Map.Entry;
+import java.util.NoSuchElementException;
 import java.util.Optional;
-import java.util.Queue;
-import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CompletionStage;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
-import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Function;
-import java.util.function.Supplier;
+import java.util.stream.Collectors;
 import javax.annotation.CheckReturnValue;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.GuardedBy;
 import javax.annotation.concurrent.ThreadSafe;
+import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair;
 import 
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader;
@@ -54,18 +59,14 @@ import 
org.apache.beam.runners.dataflow.worker.windmill.client.throttling.Thrott
 import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler;
 import 
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
 import 
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetDistributor;
-import 
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetRefresher;
 import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.sdk.util.MoreFutures;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
-import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers;
-import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.EvictingQueue;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
-import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Queues;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
-import org.joda.time.Instant;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -80,32 +81,39 @@ import org.slf4j.LoggerFactory;
 public final class FanOutStreamingEngineWorkerHarness implements 
StreamingWorkerHarness {
   private static final Logger LOG =
       LoggerFactory.getLogger(FanOutStreamingEngineWorkerHarness.class);
-  private static final String PUBLISH_NEW_WORKER_METADATA_THREAD = 
"PublishNewWorkerMetadataThread";
-  private static final String CONSUME_NEW_WORKER_METADATA_THREAD = 
"ConsumeNewWorkerMetadataThread";
+  private static final String WORKER_METADATA_CONSUMER_THREAD_NAME =
+      "WindmillWorkerMetadataConsumerThread";
+  private static final String STREAM_MANAGER_THREAD_NAME = 
"WindmillStreamManager-%d";
 
   private final JobHeader jobHeader;
   private final GrpcWindmillStreamFactory streamFactory;
   private final WorkItemScheduler workItemScheduler;
   private final ChannelCachingStubFactory channelCachingStubFactory;
   private final GrpcDispatcherClient dispatcherClient;
-  private final AtomicBoolean isBudgetRefreshPaused;
-  private final GetWorkBudgetRefresher getWorkBudgetRefresher;
-  private final AtomicReference<Instant> lastBudgetRefresh;
+  private final GetWorkBudgetDistributor getWorkBudgetDistributor;
+  private final GetWorkBudget totalGetWorkBudget;
   private final ThrottleTimer getWorkerMetadataThrottleTimer;
-  private final ExecutorService newWorkerMetadataPublisher;
-  private final ExecutorService newWorkerMetadataConsumer;
-  private final long clientId;
-  private final Supplier<GetWorkerMetadataStream> getWorkerMetadataStream;
-  private final Queue<WindmillEndpoints> newWindmillEndpoints;
   private final Function<WindmillStream.CommitWorkStream, WorkCommitter> 
workCommitterFactory;
   private final ThrottlingGetDataMetricTracker getDataMetricTracker;
+  private final ExecutorService windmillStreamManager;
+  private final ExecutorService workerMetadataConsumer;
+  private final Object metadataLock = new Object();
 
   /** Writes are guarded by synchronization, reads are lock free. */
-  private final AtomicReference<StreamingEngineConnectionState> connections;
+  private final AtomicReference<StreamingEngineBackends> backends;
 
-  private volatile boolean started;
+  @GuardedBy("this")
+  private long activeMetadataVersion;
+
+  @GuardedBy("metadataLock")
+  private long pendingMetadataVersion;
+
+  @GuardedBy("this")
+  private boolean started;
+
+  @GuardedBy("this")
+  private @Nullable GetWorkerMetadataStream getWorkerMetadataStream;
 
-  @SuppressWarnings("FutureReturnValueIgnored")
   private FanOutStreamingEngineWorkerHarness(
       JobHeader jobHeader,
       GetWorkBudget totalGetWorkBudget,
@@ -114,7 +122,6 @@ public final class FanOutStreamingEngineWorkerHarness 
implements StreamingWorker
       ChannelCachingStubFactory channelCachingStubFactory,
       GetWorkBudgetDistributor getWorkBudgetDistributor,
       GrpcDispatcherClient dispatcherClient,
-      long clientId,
       Function<WindmillStream.CommitWorkStream, WorkCommitter> 
workCommitterFactory,
       ThrottlingGetDataMetricTracker getDataMetricTracker) {
     this.jobHeader = jobHeader;
@@ -122,42 +129,21 @@ public final class FanOutStreamingEngineWorkerHarness 
implements StreamingWorker
     this.started = false;
     this.streamFactory = streamFactory;
     this.workItemScheduler = workItemScheduler;
-    this.connections = new 
AtomicReference<>(StreamingEngineConnectionState.EMPTY);
+    this.backends = new AtomicReference<>(StreamingEngineBackends.EMPTY);
     this.channelCachingStubFactory = channelCachingStubFactory;
     this.dispatcherClient = dispatcherClient;
-    this.isBudgetRefreshPaused = new AtomicBoolean(false);
     this.getWorkerMetadataThrottleTimer = new ThrottleTimer();
-    this.newWorkerMetadataPublisher =
-        singleThreadedExecutorServiceOf(PUBLISH_NEW_WORKER_METADATA_THREAD);
-    this.newWorkerMetadataConsumer =
-        singleThreadedExecutorServiceOf(CONSUME_NEW_WORKER_METADATA_THREAD);
-    this.clientId = clientId;
-    this.lastBudgetRefresh = new AtomicReference<>(Instant.EPOCH);
-    this.newWindmillEndpoints = 
Queues.synchronizedQueue(EvictingQueue.create(1));
-    this.getWorkBudgetRefresher =
-        new GetWorkBudgetRefresher(
-            isBudgetRefreshPaused::get,
-            () -> {
-              getWorkBudgetDistributor.distributeBudget(
-                  connections.get().windmillStreams().values(), 
totalGetWorkBudget);
-              lastBudgetRefresh.set(Instant.now());
-            });
-    this.getWorkerMetadataStream =
-        Suppliers.memoize(
-            () ->
-                streamFactory.createGetWorkerMetadataStream(
-                    dispatcherClient.getWindmillMetadataServiceStubBlocking(),
-                    getWorkerMetadataThrottleTimer,
-                    endpoints ->
-                        // Run this on a separate thread than the grpc stream 
thread.
-                        newWorkerMetadataPublisher.submit(
-                            () -> newWindmillEndpoints.add(endpoints))));
+    this.windmillStreamManager =
+        Executors.newCachedThreadPool(
+            new 
ThreadFactoryBuilder().setNameFormat(STREAM_MANAGER_THREAD_NAME).build());
+    this.workerMetadataConsumer =
+        Executors.newSingleThreadScheduledExecutor(
+            new 
ThreadFactoryBuilder().setNameFormat(WORKER_METADATA_CONSUMER_THREAD_NAME).build());
+    this.getWorkBudgetDistributor = getWorkBudgetDistributor;
+    this.totalGetWorkBudget = totalGetWorkBudget;
+    this.activeMetadataVersion = Long.MIN_VALUE;
     this.workCommitterFactory = workCommitterFactory;
-  }
-
-  private static ExecutorService singleThreadedExecutorServiceOf(String 
threadName) {
-    return Executors.newSingleThreadScheduledExecutor(
-        new ThreadFactoryBuilder().setNameFormat(threadName).build());
+    this.getWorkerMetadataStream = null;
   }
 
   /**
@@ -183,7 +169,6 @@ public final class FanOutStreamingEngineWorkerHarness 
implements StreamingWorker
         channelCachingStubFactory,
         getWorkBudgetDistributor,
         dispatcherClient,
-        /* clientId= */ new Random().nextLong(),
         workCommitterFactory,
         getDataMetricTracker);
   }
@@ -197,7 +182,6 @@ public final class FanOutStreamingEngineWorkerHarness 
implements StreamingWorker
       ChannelCachingStubFactory stubFactory,
       GetWorkBudgetDistributor getWorkBudgetDistributor,
       GrpcDispatcherClient dispatcherClient,
-      long clientId,
       Function<WindmillStream.CommitWorkStream, WorkCommitter> 
workCommitterFactory,
       ThrottlingGetDataMetricTracker getDataMetricTracker) {
     FanOutStreamingEngineWorkerHarness fanOutStreamingEngineWorkProvider =
@@ -209,201 +193,218 @@ public final class FanOutStreamingEngineWorkerHarness 
implements StreamingWorker
             stubFactory,
             getWorkBudgetDistributor,
             dispatcherClient,
-            clientId,
             workCommitterFactory,
             getDataMetricTracker);
     fanOutStreamingEngineWorkProvider.start();
     return fanOutStreamingEngineWorkProvider;
   }
 
-  @SuppressWarnings("ReturnValueIgnored")
   @Override
   public synchronized void start() {
-    Preconditions.checkState(!started, "StreamingEngineClient cannot start 
twice.");
-    // Starts the stream, this value is memoized.
-    getWorkerMetadataStream.get();
-    startWorkerMetadataConsumer();
-    getWorkBudgetRefresher.start();
+    Preconditions.checkState(!started, "FanOutStreamingEngineWorkerHarness 
cannot start twice.");
+    getWorkerMetadataStream =
+        streamFactory.createGetWorkerMetadataStream(
+            dispatcherClient.getWindmillMetadataServiceStubBlocking(),
+            getWorkerMetadataThrottleTimer,
+            this::consumeWorkerMetadata);
     started = true;
   }
 
   public ImmutableSet<HostAndPort> currentWindmillEndpoints() {
-    return connections.get().windmillConnections().keySet().stream()
+    return backends.get().windmillStreams().keySet().stream()
         .map(Endpoint::directEndpoint)
         .filter(Optional::isPresent)
         .map(Optional::get)
-        .filter(
-            windmillServiceAddress ->
-                windmillServiceAddress.getKind() != 
WindmillServiceAddress.Kind.IPV6)
-        .map(
-            windmillServiceAddress ->
-                windmillServiceAddress.getKind() == 
WindmillServiceAddress.Kind.GCP_SERVICE_ADDRESS
-                    ? windmillServiceAddress.gcpServiceAddress()
-                    : 
windmillServiceAddress.authenticatedGcpServiceAddress().gcpServiceAddress())
+        .map(WindmillServiceAddress::getServiceAddress)
         .collect(toImmutableSet());
   }
 
   /**
-   * Fetches {@link GetDataStream} mapped to globalDataKey if one exists, or 
defaults to {@link
-   * GetDataStream} pointing to dispatcher.
+   * Fetches the {@link GetDataStream} mapped to globalDataKey, or throws {@link
+   * NoSuchElementException} if one is not found.
    */
   private GetDataStream getGlobalDataStream(String globalDataKey) {
-    return 
Optional.ofNullable(connections.get().globalDataStreams().get(globalDataKey))
-        .map(Supplier::get)
-        .orElseGet(
-            () ->
-                streamFactory.createGetDataStream(
-                    dispatcherClient.getWindmillServiceStub(), new 
ThrottleTimer()));
-  }
-
-  @SuppressWarnings("FutureReturnValueIgnored")
-  private void startWorkerMetadataConsumer() {
-    newWorkerMetadataConsumer.submit(
-        () -> {
-          while (true) {
-            Optional.ofNullable(newWindmillEndpoints.poll())
-                .ifPresent(this::consumeWindmillWorkerEndpoints);
-          }
-        });
+    return 
Optional.ofNullable(backends.get().globalDataStreams().get(globalDataKey))
+        .map(GlobalDataStreamSender::get)
+        .orElseThrow(
+            () -> new NoSuchElementException("No endpoint for global data tag: 
" + globalDataKey));
   }
 
   @VisibleForTesting
   @Override
   public synchronized void shutdown() {
-    Preconditions.checkState(started, "StreamingEngineClient never started.");
-    getWorkerMetadataStream.get().halfClose();
-    getWorkBudgetRefresher.stop();
-    newWorkerMetadataPublisher.shutdownNow();
-    newWorkerMetadataConsumer.shutdownNow();
+    Preconditions.checkState(started, "FanOutStreamingEngineWorkerHarness 
never started.");
+    Preconditions.checkNotNull(getWorkerMetadataStream).shutdown();
+    workerMetadataConsumer.shutdownNow();
+    closeStreamsNotIn(WindmillEndpoints.none());
     channelCachingStubFactory.shutdown();
+
+    try {
+      Preconditions.checkNotNull(getWorkerMetadataStream).awaitTermination(10, 
TimeUnit.SECONDS);
+    } catch (InterruptedException e) {
+      Thread.currentThread().interrupt();
+      LOG.warn("Interrupted waiting for GetWorkerMetadataStream to shutdown.", 
e);
+    }
+
+    windmillStreamManager.shutdown();
+    boolean isStreamManagerShutdown = false;
+    try {
+      isStreamManagerShutdown = windmillStreamManager.awaitTermination(30, 
TimeUnit.SECONDS);
+    } catch (InterruptedException e) {
+      Thread.currentThread().interrupt();
+      LOG.warn("Interrupted waiting for windmillStreamManager to shutdown.", 
e);
+    }
+    if (!isStreamManagerShutdown) {
+      windmillStreamManager.shutdownNow();
+    }
+  }
+
+  private void consumeWorkerMetadata(WindmillEndpoints windmillEndpoints) {
+    synchronized (metadataLock) {
+      // Only process versions greater than what we currently have to prevent 
double processing of
+      // metadata. workerMetadataConsumer is single-threaded so we maintain 
ordering.
+      if (windmillEndpoints.version() > pendingMetadataVersion) {
+        pendingMetadataVersion = windmillEndpoints.version();
+        workerMetadataConsumer.execute(() -> 
consumeWindmillWorkerEndpoints(windmillEndpoints));
+      }
+    }
   }
 
-  /**
-   * {@link java.util.function.Consumer<WindmillEndpoints>} used to update 
{@link #connections} on
-   * new backend worker metadata.
-   */
   private synchronized void consumeWindmillWorkerEndpoints(WindmillEndpoints 
newWindmillEndpoints) {
-    isBudgetRefreshPaused.set(true);
-    LOG.info("Consuming new windmill endpoints: {}", newWindmillEndpoints);
-    ImmutableMap<Endpoint, WindmillConnection> newWindmillConnections =
-        createNewWindmillConnections(newWindmillEndpoints.windmillEndpoints());
-
-    StreamingEngineConnectionState newConnectionsState =
-        StreamingEngineConnectionState.builder()
-            .setWindmillConnections(newWindmillConnections)
-            .setWindmillStreams(
-                
closeStaleStreamsAndCreateNewStreams(newWindmillConnections.values()))
+    // Since this is run on a single threaded executor, multiple versions of 
the metadata may be
+    // queued up while a previous version of the windmillEndpoints was being 
consumed. Only consume
+    // the endpoints if they are the most current version.
+    synchronized (metadataLock) {
+      if (newWindmillEndpoints.version() < pendingMetadataVersion) {
+        return;
+      }
+    }
+
+    LOG.debug(
+        "Consuming new endpoints: {}. previous metadata version: {}, current 
metadata version: {}",
+        newWindmillEndpoints,
+        activeMetadataVersion,
+        newWindmillEndpoints.version());
+    closeStreamsNotIn(newWindmillEndpoints);
+    ImmutableMap<Endpoint, WindmillStreamSender> newStreams =
+        
createAndStartNewStreams(newWindmillEndpoints.windmillEndpoints()).join();
+    StreamingEngineBackends newBackends =
+        StreamingEngineBackends.builder()
+            .setWindmillStreams(newStreams)
             .setGlobalDataStreams(
                 
createNewGlobalDataStreams(newWindmillEndpoints.globalDataEndpoints()))
             .build();
+    backends.set(newBackends);
+    getWorkBudgetDistributor.distributeBudget(newStreams.values(), 
totalGetWorkBudget);
+    activeMetadataVersion = newWindmillEndpoints.version();
+  }
+
+  /** Close the streams that are no longer valid asynchronously. */
+  private void closeStreamsNotIn(WindmillEndpoints newWindmillEndpoints) {
+    StreamingEngineBackends currentBackends = backends.get();
+    currentBackends.windmillStreams().entrySet().stream()
+        .filter(
+            connectionAndStream ->
+                
!newWindmillEndpoints.windmillEndpoints().contains(connectionAndStream.getKey()))
+        .forEach(
+            entry ->
+                windmillStreamManager.execute(
+                    () -> closeStreamSender(entry.getKey(), 
entry.getValue())));
 
-    LOG.info(
-        "Setting new connections: {}. Previous connections: {}.",
-        newConnectionsState,
-        connections.get());
-    connections.set(newConnectionsState);
-    isBudgetRefreshPaused.set(false);
-    getWorkBudgetRefresher.requestBudgetRefresh();
+    Set<Endpoint> newGlobalDataEndpoints =
+        new HashSet<>(newWindmillEndpoints.globalDataEndpoints().values());
+    currentBackends.globalDataStreams().values().stream()
+        .filter(sender -> !newGlobalDataEndpoints.contains(sender.endpoint()))
+        .forEach(
+            sender ->
+                windmillStreamManager.execute(() -> 
closeStreamSender(sender.endpoint(), sender)));
+  }
+
+  private void closeStreamSender(Endpoint endpoint, Closeable sender) {
+    LOG.debug("Closing streams to endpoint={}, sender={}", endpoint, sender);
+    try {
+      sender.close();
+      endpoint.directEndpoint().ifPresent(channelCachingStubFactory::remove);
+      LOG.debug("Successfully closed streams to {}", endpoint);
+    } catch (Exception e) {
+      LOG.error("Error closing streams to endpoint={}, sender={}", endpoint, 
sender);
+    }
+  }
+
+  private synchronized CompletableFuture<ImmutableMap<Endpoint, 
WindmillStreamSender>>
+      createAndStartNewStreams(ImmutableSet<Endpoint> newWindmillEndpoints) {
+    ImmutableMap<Endpoint, WindmillStreamSender> currentStreams = 
backends.get().windmillStreams();
+    return MoreFutures.allAsList(
+            newWindmillEndpoints.stream()
+                .map(endpoint -> 
getOrCreateWindmillStreamSenderFuture(endpoint, currentStreams))
+                .collect(Collectors.toList()))
+        .thenApply(
+            backends -> 
backends.stream().collect(toImmutableMap(Pair::getLeft, Pair::getRight)))
+        .toCompletableFuture();
+  }
+
+  private CompletionStage<Pair<Endpoint, WindmillStreamSender>>
+      getOrCreateWindmillStreamSenderFuture(
+          Endpoint endpoint, ImmutableMap<Endpoint, WindmillStreamSender> 
currentStreams) {
+    return MoreFutures.supplyAsync(
+        () ->
+            Pair.of(
+                endpoint,
+                Optional.ofNullable(currentStreams.get(endpoint))
+                    .orElseGet(() -> 
createAndStartWindmillStreamSender(endpoint))),
+        windmillStreamManager);
   }
 
   /** Add up all the throttle times of all streams including 
GetWorkerMetadataStream. */
-  public long getAndResetThrottleTimes() {
-    return connections.get().windmillStreams().values().stream()
+  public long getAndResetThrottleTime() {
+    return backends.get().windmillStreams().values().stream()
             .map(WindmillStreamSender::getAndResetThrottleTime)
             .reduce(0L, Long::sum)
         + getWorkerMetadataThrottleTimer.getAndResetThrottleTime();
   }
 
   public long currentActiveCommitBytes() {
-    return connections.get().windmillStreams().values().stream()
+    return backends.get().windmillStreams().values().stream()
         .map(WindmillStreamSender::getCurrentActiveCommitBytes)
         .reduce(0L, Long::sum);
   }
 
   @VisibleForTesting
-  StreamingEngineConnectionState getCurrentConnections() {
-    return connections.get();
-  }
-
-  private synchronized ImmutableMap<Endpoint, WindmillConnection> 
createNewWindmillConnections(
-      List<Endpoint> newWindmillEndpoints) {
-    ImmutableMap<Endpoint, WindmillConnection> currentConnections =
-        connections.get().windmillConnections();
-    return newWindmillEndpoints.stream()
-        .collect(
-            toImmutableMap(
-                Function.identity(),
-                endpoint ->
-                    // Reuse existing stubs if they exist. Optional.orElseGet 
only calls the
-                    // supplier if the value is not present, preventing 
constructing expensive
-                    // objects.
-                    Optional.ofNullable(currentConnections.get(endpoint))
-                        .orElseGet(
-                            () -> WindmillConnection.from(endpoint, 
this::createWindmillStub))));
+  StreamingEngineBackends currentBackends() {
+    return backends.get();
   }
 
-  private synchronized ImmutableMap<WindmillConnection, WindmillStreamSender>
-      closeStaleStreamsAndCreateNewStreams(Collection<WindmillConnection> 
newWindmillConnections) {
-    ImmutableMap<WindmillConnection, WindmillStreamSender> currentStreams =
-        connections.get().windmillStreams();
-
-    // Close the streams that are no longer valid.
-    currentStreams.entrySet().stream()
-        .filter(
-            connectionAndStream -> 
!newWindmillConnections.contains(connectionAndStream.getKey()))
-        .forEach(
-            entry -> {
-              entry.getValue().closeAllStreams();
-              
entry.getKey().directEndpoint().ifPresent(channelCachingStubFactory::remove);
-            });
-
-    return newWindmillConnections.stream()
-        .collect(
-            toImmutableMap(
-                Function.identity(),
-                newConnection ->
-                    Optional.ofNullable(currentStreams.get(newConnection))
-                        .orElseGet(() -> 
createAndStartWindmillStreamSenderFor(newConnection))));
-  }
-
-  private ImmutableMap<String, Supplier<GetDataStream>> 
createNewGlobalDataStreams(
+  private ImmutableMap<String, GlobalDataStreamSender> 
createNewGlobalDataStreams(
       ImmutableMap<String, Endpoint> newGlobalDataEndpoints) {
-    ImmutableMap<String, Supplier<GetDataStream>> currentGlobalDataStreams =
-        connections.get().globalDataStreams();
+    ImmutableMap<String, GlobalDataStreamSender> currentGlobalDataStreams =
+        backends.get().globalDataStreams();
     return newGlobalDataEndpoints.entrySet().stream()
         .collect(
             toImmutableMap(
                 Entry::getKey,
                 keyedEndpoint ->
-                    existingOrNewGetDataStreamFor(keyedEndpoint, 
currentGlobalDataStreams)));
+                    getOrCreateGlobalDataSteam(keyedEndpoint, 
currentGlobalDataStreams)));
   }
 
-  private Supplier<GetDataStream> existingOrNewGetDataStreamFor(
+  private GlobalDataStreamSender getOrCreateGlobalDataSteam(
       Entry<String, Endpoint> keyedEndpoint,
-      ImmutableMap<String, Supplier<GetDataStream>> currentGlobalDataStreams) {
-    return Preconditions.checkNotNull(
-        currentGlobalDataStreams.getOrDefault(
-            keyedEndpoint.getKey(),
+      ImmutableMap<String, GlobalDataStreamSender> currentGlobalDataStreams) {
+    return 
Optional.ofNullable(currentGlobalDataStreams.get(keyedEndpoint.getKey()))
+        .orElseGet(
             () ->
-                streamFactory.createGetDataStream(
-                    newOrExistingStubFor(keyedEndpoint.getValue()), new 
ThrottleTimer())));
-  }
-
-  private CloudWindmillServiceV1Alpha1Stub newOrExistingStubFor(Endpoint 
endpoint) {
-    return 
Optional.ofNullable(connections.get().windmillConnections().get(endpoint))
-        .map(WindmillConnection::stub)
-        .orElseGet(() -> createWindmillStub(endpoint));
+                new GlobalDataStreamSender(
+                    () ->
+                        streamFactory.createGetDataStream(
+                            createWindmillStub(keyedEndpoint.getValue()), new 
ThrottleTimer()),
+                    keyedEndpoint.getValue()));
   }
 
-  private WindmillStreamSender createAndStartWindmillStreamSenderFor(
-      WindmillConnection connection) {
-    // Initially create each stream with no budget. The budget will be 
eventually assigned by the
-    // GetWorkBudgetDistributor.
+  private WindmillStreamSender createAndStartWindmillStreamSender(Endpoint 
endpoint) {
     WindmillStreamSender windmillStreamSender =
         WindmillStreamSender.create(
-            connection,
+            WindmillConnection.from(endpoint, this::createWindmillStub),
             GetWorkRequest.newBuilder()
-                .setClientId(clientId)
+                .setClientId(jobHeader.getClientId())
                 .setJobId(jobHeader.getJobId())
                 .setProjectId(jobHeader.getProjectId())
                 .setWorkerId(jobHeader.getWorkerId())
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java
new file mode 100644
index 00000000000..ce5f3a7b6bf
--- /dev/null
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.streaming.harness;
+
+import java.io.Closeable;
+import java.util.function.Supplier;
+import javax.annotation.concurrent.ThreadSafe;
+import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
+import org.apache.beam.sdk.annotations.Internal;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers;
+
+@Internal
+@ThreadSafe
+// TODO (m-trieu): replace Supplier<Stream> with Stream after 
github.com/apache/beam/pull/32774/ is
+// merged
+final class GlobalDataStreamSender implements Closeable, 
Supplier<GetDataStream> {
+  private final Endpoint endpoint;
+  private final Supplier<GetDataStream> delegate;
+  private volatile boolean started;
+
+  GlobalDataStreamSender(Supplier<GetDataStream> delegate, Endpoint endpoint) {
+    // Ensures that the Supplier is thread-safe
+    this.delegate = Suppliers.memoize(delegate::get);
+    this.started = false;
+    this.endpoint = endpoint;
+  }
+
+  @Override
+  public GetDataStream get() {
+    if (!started) {
+      started = true;
+    }
+
+    return delegate.get();
+  }
+
+  @Override
+  public void close() {
+    if (started) {
+      delegate.get().shutdown();
+    }
+  }
+
+  Endpoint endpoint() {
+    return endpoint;
+  }
+}
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineConnectionState.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineBackends.java
similarity index 55%
rename from 
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineConnectionState.java
rename to 
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineBackends.java
index 3c85ee6abe1..14290b48683 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineConnectionState.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineBackends.java
@@ -18,47 +18,37 @@
 package org.apache.beam.runners.dataflow.worker.streaming.harness;
 
 import com.google.auto.value.AutoValue;
-import java.util.function.Supplier;
-import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
 import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint;
-import 
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
 
 /**
- * Represents the current state of connections to Streaming Engine. 
Connections are updated when
- * backend workers assigned to the key ranges being processed by this user 
worker change during
+ * Represents the current state of connections to the Streaming Engine 
backend. Backends are updated
+ * when backend workers assigned to the key ranges being processed by this 
user worker change during
  * pipeline execution. For example, changes can happen via autoscaling, 
load-balancing, or other
  * backend updates.
  */
 @AutoValue
-abstract class StreamingEngineConnectionState {
-  static final StreamingEngineConnectionState EMPTY = builder().build();
+abstract class StreamingEngineBackends {
+  static final StreamingEngineBackends EMPTY = builder().build();
 
   static Builder builder() {
-    return new AutoValue_StreamingEngineConnectionState.Builder()
-        .setWindmillConnections(ImmutableMap.of())
+    return new AutoValue_StreamingEngineBackends.Builder()
         .setWindmillStreams(ImmutableMap.of())
         .setGlobalDataStreams(ImmutableMap.of());
   }
 
-  abstract ImmutableMap<Endpoint, WindmillConnection> windmillConnections();
-
-  abstract ImmutableMap<WindmillConnection, WindmillStreamSender> 
windmillStreams();
+  abstract ImmutableMap<Endpoint, WindmillStreamSender> windmillStreams();
 
   /** Mapping of GlobalDataIds and the direct GetDataStreams used fetch them. 
*/
-  abstract ImmutableMap<String, Supplier<GetDataStream>> globalDataStreams();
+  abstract ImmutableMap<String, GlobalDataStreamSender> globalDataStreams();
 
   @AutoValue.Builder
   abstract static class Builder {
-    public abstract Builder setWindmillConnections(
-        ImmutableMap<Endpoint, WindmillConnection> value);
-
-    public abstract Builder setWindmillStreams(
-        ImmutableMap<WindmillConnection, WindmillStreamSender> value);
+    public abstract Builder setWindmillStreams(ImmutableMap<Endpoint, 
WindmillStreamSender> value);
 
     public abstract Builder setGlobalDataStreams(
-        ImmutableMap<String, Supplier<GetDataStream>> value);
+        ImmutableMap<String, GlobalDataStreamSender> value);
 
-    public abstract StreamingEngineConnectionState build();
+    public abstract StreamingEngineBackends build();
   }
 }
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java
index 45aa403ee71..744c3d74445 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java
@@ -17,6 +17,7 @@
  */
 package org.apache.beam.runners.dataflow.worker.streaming.harness;
 
+import java.io.Closeable;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Function;
@@ -49,7 +50,7 @@ import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers
  * {@link GetWorkBudget} is set.
  *
  * <p>Once started, the underlying streams are "alive" until they are manually 
closed via {@link
- * #closeAllStreams()}.
+ * #close()}.
  *
  * <p>If closed, it means that the backend endpoint is no longer in the worker 
set. Once closed,
  * these instances are not reused.
@@ -59,7 +60,7 @@ import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers
  */
 @Internal
 @ThreadSafe
-final class WindmillStreamSender implements GetWorkBudgetSpender {
+final class WindmillStreamSender implements GetWorkBudgetSpender, Closeable {
   private final AtomicBoolean started;
   private final AtomicReference<GetWorkBudget> getWorkBudget;
   private final Supplier<GetWorkStream> getWorkStream;
@@ -103,9 +104,9 @@ final class WindmillStreamSender implements 
GetWorkBudgetSpender {
                     connection,
                     withRequestBudget(getWorkRequest, getWorkBudget.get()),
                     streamingEngineThrottleTimers.getWorkThrottleTimer(),
-                    () -> 
FixedStreamHeartbeatSender.create(getDataStream.get()),
-                    () -> getDataClientFactory.apply(getDataStream.get()),
-                    workCommitter,
+                    FixedStreamHeartbeatSender.create(getDataStream.get()),
+                    getDataClientFactory.apply(getDataStream.get()),
+                    workCommitter.get(),
                     workItemScheduler));
   }
 
@@ -141,7 +142,8 @@ final class WindmillStreamSender implements 
GetWorkBudgetSpender {
     started.set(true);
   }
 
-  void closeAllStreams() {
+  @Override
+  public void close() {
     // Supplier<Stream>.get() starts the stream which is an expensive 
operation as it initiates the
     // streaming RPCs by possibly making calls over the network. Do not close 
the streams unless
     // they have already been started.
@@ -154,18 +156,13 @@ final class WindmillStreamSender implements 
GetWorkBudgetSpender {
   }
 
   @Override
-  public void adjustBudget(long itemsDelta, long bytesDelta) {
-    getWorkBudget.set(getWorkBudget.get().apply(itemsDelta, bytesDelta));
+  public void setBudget(long items, long bytes) {
+    getWorkBudget.set(GetWorkBudget.builder().setItems(items).setBytes(bytes).build());
     if (started.get()) {
-      getWorkStream.get().adjustBudget(itemsDelta, bytesDelta);
+      getWorkStream.get().setBudget(items, bytes);
     }
   }
 
-  @Override
-  public GetWorkBudget remainingBudget() {
-    return started.get() ? getWorkStream.get().remainingBudget() : 
getWorkBudget.get();
-  }
-
   long getAndResetThrottleTime() {
     return streamingEngineThrottleTimers.getAndResetThrottleTime();
   }
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java
index d7ed83def43..eb269eef848 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java
@@ -17,8 +17,8 @@
  */
 package org.apache.beam.runners.dataflow.worker.windmill;
 
-import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList;
 import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap;
+import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet.toImmutableSet;
 
 import com.google.auto.value.AutoValue;
 import java.net.Inet6Address;
@@ -27,8 +27,8 @@ import java.net.UnknownHostException;
 import java.util.Map;
 import java.util.Optional;
 import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress.AuthenticatedGcpServiceAddress;
-import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -41,6 +41,14 @@ import org.slf4j.LoggerFactory;
 public abstract class WindmillEndpoints {
   private static final Logger LOG = 
LoggerFactory.getLogger(WindmillEndpoints.class);
 
+  public static WindmillEndpoints none() {
+    return WindmillEndpoints.builder()
+        .setVersion(Long.MAX_VALUE)
+        .setWindmillEndpoints(ImmutableSet.of())
+        .setGlobalDataEndpoints(ImmutableMap.of())
+        .build();
+  }
+
   public static WindmillEndpoints from(
       Windmill.WorkerMetadataResponse workerMetadataResponseProto) {
     ImmutableMap<String, WindmillEndpoints.Endpoint> globalDataServers =
@@ -53,14 +61,15 @@ public abstract class WindmillEndpoints {
                             endpoint.getValue(),
                             
workerMetadataResponseProto.getExternalEndpoint())));
 
-    ImmutableList<WindmillEndpoints.Endpoint> windmillServers =
+    ImmutableSet<WindmillEndpoints.Endpoint> windmillServers =
         workerMetadataResponseProto.getWorkEndpointsList().stream()
             .map(
                 endpointProto ->
                     Endpoint.from(endpointProto, 
workerMetadataResponseProto.getExternalEndpoint()))
-            .collect(toImmutableList());
+            .collect(toImmutableSet());
 
     return WindmillEndpoints.builder()
+        .setVersion(workerMetadataResponseProto.getMetadataVersion())
         .setGlobalDataEndpoints(globalDataServers)
         .setWindmillEndpoints(windmillServers)
         .build();
@@ -123,6 +132,9 @@ public abstract class WindmillEndpoints {
             directEndpointAddress.getHostAddress(), (int) 
endpointProto.getPort()));
   }
 
+  /** Version of the endpoints which increases with every modification. */
+  public abstract long version();
+
   /**
    * Used by GetData GlobalDataRequest(s) to support Beam side inputs. Returns 
a map where the key
    * is a global data tag and the value is the endpoint where the data 
associated with the global
@@ -138,7 +150,7 @@ public abstract class WindmillEndpoints {
    * Windmill servers. Returns a list of endpoints used to communicate with 
the corresponding
    * Windmill servers.
    */
-  public abstract ImmutableList<Endpoint> windmillEndpoints();
+  public abstract ImmutableSet<Endpoint> windmillEndpoints();
 
   /**
    * Representation of an endpoint in {@link 
Windmill.WorkerMetadataResponse.Endpoint} proto with
@@ -204,13 +216,15 @@ public abstract class WindmillEndpoints {
 
   @AutoValue.Builder
   public abstract static class Builder {
+    public abstract Builder setVersion(long version);
+
     public abstract Builder setGlobalDataEndpoints(
         ImmutableMap<String, WindmillEndpoints.Endpoint> globalDataServers);
 
     public abstract Builder setWindmillEndpoints(
-        ImmutableList<WindmillEndpoints.Endpoint> windmillServers);
+        ImmutableSet<WindmillEndpoints.Endpoint> windmillServers);
 
-    abstract ImmutableList.Builder<WindmillEndpoints.Endpoint> 
windmillEndpointsBuilder();
+    abstract ImmutableSet.Builder<WindmillEndpoints.Endpoint> 
windmillEndpointsBuilder();
 
     public final Builder addWindmillEndpoint(WindmillEndpoints.Endpoint 
endpoint) {
       windmillEndpointsBuilder().add(endpoint);
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java
index 90f93b07267..0b895652efe 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java
@@ -19,38 +19,36 @@ package org.apache.beam.runners.dataflow.worker.windmill;
 
 import com.google.auto.value.AutoOneOf;
 import com.google.auto.value.AutoValue;
-import java.net.Inet6Address;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort;
 
 /** Used to create channels to communicate with Streaming Engine via gRpc. */
 @AutoOneOf(WindmillServiceAddress.Kind.class)
 public abstract class WindmillServiceAddress {
-  public static WindmillServiceAddress create(Inet6Address ipv6Address) {
-    return AutoOneOf_WindmillServiceAddress.ipv6(ipv6Address);
-  }
 
   public static WindmillServiceAddress create(HostAndPort gcpServiceAddress) {
     return 
AutoOneOf_WindmillServiceAddress.gcpServiceAddress(gcpServiceAddress);
   }
 
-  public abstract Kind getKind();
+  public static WindmillServiceAddress create(
+      AuthenticatedGcpServiceAddress authenticatedGcpServiceAddress) {
+    return AutoOneOf_WindmillServiceAddress.authenticatedGcpServiceAddress(
+        authenticatedGcpServiceAddress);
+  }
 
-  public abstract Inet6Address ipv6();
+  public abstract Kind getKind();
 
   public abstract HostAndPort gcpServiceAddress();
 
   public abstract AuthenticatedGcpServiceAddress 
authenticatedGcpServiceAddress();
 
-  public static WindmillServiceAddress create(
-      AuthenticatedGcpServiceAddress authenticatedGcpServiceAddress) {
-    return AutoOneOf_WindmillServiceAddress.authenticatedGcpServiceAddress(
-        authenticatedGcpServiceAddress);
+  public final HostAndPort getServiceAddress() {
+    return getKind() == WindmillServiceAddress.Kind.GCP_SERVICE_ADDRESS
+        ? gcpServiceAddress()
+        : authenticatedGcpServiceAddress().gcpServiceAddress();
   }
 
   public enum Kind {
-    IPV6,
     GCP_SERVICE_ADDRESS,
-    // TODO(m-trieu): Use for direct connections when ALTS is enabled.
     AUTHENTICATED_GCP_SERVICE_ADDRESS
   }
 
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java
index 31bd4e146a7..f26c56b14ec 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java
@@ -56,10 +56,11 @@ public interface WindmillStream {
   @ThreadSafe
   interface GetWorkStream extends WindmillStream {
     /** Adjusts the {@link GetWorkBudget} for the stream. */
-    void adjustBudget(long itemsDelta, long bytesDelta);
+    void setBudget(GetWorkBudget newBudget);
 
-    /** Returns the remaining in-flight {@link GetWorkBudget}. */
-    GetWorkBudget remainingBudget();
+    default void setBudget(long newItems, long newBytes) {
+      
setBudget(GetWorkBudget.builder().setItems(newItems).setBytes(newBytes).build());
+    }
   }
 
   /** Interface for streaming GetDataRequests to Windmill. */
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java
index 19de998b1da..b27ebc8e9ee 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java
@@ -21,9 +21,11 @@ import java.io.PrintWriter;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.RejectedExecutionException;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Function;
-import java.util.function.Supplier;
+import javax.annotation.concurrent.GuardedBy;
+import net.jcip.annotations.ThreadSafe;
 import org.apache.beam.runners.dataflow.worker.streaming.Watermarks;
 import org.apache.beam.runners.dataflow.worker.streaming.Work;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
@@ -44,8 +46,8 @@ import 
org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSe
 import org.apache.beam.sdk.annotations.Internal;
 import org.apache.beam.sdk.util.BackOff;
 import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
-import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
-import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 /**
  * Implementation of {@link GetWorkStream} that passes along a specific {@link
@@ -55,9 +57,10 @@ import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers
  * these direct streams are used to facilitate these RPC calls to specific 
backend workers.
  */
 @Internal
-public final class GrpcDirectGetWorkStream
+final class GrpcDirectGetWorkStream
     extends AbstractWindmillStream<StreamingGetWorkRequest, 
StreamingGetWorkResponseChunk>
     implements GetWorkStream {
+  private static final Logger LOG = 
LoggerFactory.getLogger(GrpcDirectGetWorkStream.class);
   private static final StreamingGetWorkRequest HEALTH_CHECK_REQUEST =
       StreamingGetWorkRequest.newBuilder()
           .setRequestExtension(
@@ -67,15 +70,14 @@ public final class GrpcDirectGetWorkStream
                   .build())
           .build();
 
-  private final AtomicReference<GetWorkBudget> inFlightBudget;
-  private final AtomicReference<GetWorkBudget> nextBudgetAdjustment;
-  private final AtomicReference<GetWorkBudget> pendingResponseBudget;
-  private final GetWorkRequest request;
+  private final GetWorkBudgetTracker budgetTracker;
+  private final GetWorkRequest requestHeader;
   private final WorkItemScheduler workItemScheduler;
   private final ThrottleTimer getWorkThrottleTimer;
-  private final Supplier<HeartbeatSender> heartbeatSender;
-  private final Supplier<WorkCommitter> workCommitter;
-  private final Supplier<GetDataClient> getDataClient;
+  private final HeartbeatSender heartbeatSender;
+  private final WorkCommitter workCommitter;
+  private final GetDataClient getDataClient;
+  private final AtomicReference<StreamingGetWorkRequest> lastRequest;
 
   /**
    * Map of stream IDs to their buffers. Used to aggregate streaming gRPC 
response chunks as they
@@ -92,15 +94,15 @@ public final class GrpcDirectGetWorkStream
               StreamObserver<StreamingGetWorkResponseChunk>,
               StreamObserver<StreamingGetWorkRequest>>
           startGetWorkRpcFn,
-      GetWorkRequest request,
+      GetWorkRequest requestHeader,
       BackOff backoff,
       StreamObserverFactory streamObserverFactory,
       Set<AbstractWindmillStream<?, ?>> streamRegistry,
       int logEveryNStreamFailures,
       ThrottleTimer getWorkThrottleTimer,
-      Supplier<HeartbeatSender> heartbeatSender,
-      Supplier<GetDataClient> getDataClient,
-      Supplier<WorkCommitter> workCommitter,
+      HeartbeatSender heartbeatSender,
+      GetDataClient getDataClient,
+      WorkCommitter workCommitter,
       WorkItemScheduler workItemScheduler) {
     super(
         "GetWorkStream",
@@ -110,19 +112,23 @@ public final class GrpcDirectGetWorkStream
         streamRegistry,
         logEveryNStreamFailures,
         backendWorkerToken);
-    this.request = request;
+    this.requestHeader = requestHeader;
     this.getWorkThrottleTimer = getWorkThrottleTimer;
     this.workItemScheduler = workItemScheduler;
     this.workItemAssemblers = new ConcurrentHashMap<>();
-    this.heartbeatSender = Suppliers.memoize(heartbeatSender::get);
-    this.workCommitter = Suppliers.memoize(workCommitter::get);
-    this.getDataClient = Suppliers.memoize(getDataClient::get);
-    this.inFlightBudget = new AtomicReference<>(GetWorkBudget.noBudget());
-    this.nextBudgetAdjustment = new 
AtomicReference<>(GetWorkBudget.noBudget());
-    this.pendingResponseBudget = new 
AtomicReference<>(GetWorkBudget.noBudget());
+    this.heartbeatSender = heartbeatSender;
+    this.workCommitter = workCommitter;
+    this.getDataClient = getDataClient;
+    this.lastRequest = new AtomicReference<>();
+    this.budgetTracker =
+        new GetWorkBudgetTracker(
+            GetWorkBudget.builder()
+                .setItems(requestHeader.getMaxItems())
+                .setBytes(requestHeader.getMaxBytes())
+                .build());
   }
 
-  public static GrpcDirectGetWorkStream create(
+  static GrpcDirectGetWorkStream create(
       String backendWorkerToken,
       Function<
               StreamObserver<StreamingGetWorkResponseChunk>,
@@ -134,9 +140,9 @@ public final class GrpcDirectGetWorkStream
       Set<AbstractWindmillStream<?, ?>> streamRegistry,
       int logEveryNStreamFailures,
       ThrottleTimer getWorkThrottleTimer,
-      Supplier<HeartbeatSender> heartbeatSender,
-      Supplier<GetDataClient> getDataClient,
-      Supplier<WorkCommitter> workCommitter,
+      HeartbeatSender heartbeatSender,
+      GetDataClient getDataClient,
+      WorkCommitter workCommitter,
       WorkItemScheduler workItemScheduler) {
     GrpcDirectGetWorkStream getWorkStream =
         new GrpcDirectGetWorkStream(
@@ -165,46 +171,52 @@ public final class GrpcDirectGetWorkStream
         .build();
   }
 
-  private void sendRequestExtension(GetWorkBudget adjustment) {
-    inFlightBudget.getAndUpdate(budget -> budget.apply(adjustment));
-    StreamingGetWorkRequest extension =
-        StreamingGetWorkRequest.newBuilder()
-            .setRequestExtension(
-                Windmill.StreamingGetWorkRequestExtension.newBuilder()
-                    .setMaxItems(adjustment.items())
-                    .setMaxBytes(adjustment.bytes()))
-            .build();
-
-    executor()
-        .execute(
-            () -> {
-              try {
-                send(extension);
-              } catch (IllegalStateException e) {
-                // Stream was closed.
-              }
-            });
+  /**
+   * Records {@code extension} before scheduling the send so subsequent budget checks see it.
+   * @implNote Do not lock here: sending under a lock on the grpc serial executor can deadlock,
+   *     and {@link AbstractWindmillStream#send(Object)} is already synchronized.
+   */
+  private void maybeSendRequestExtension(GetWorkBudget extension) {
+    if (extension.items() > 0 || extension.bytes() > 0) {
+      budgetTracker.recordBudgetRequested(extension);
+      executeSafely(
+          () -> {
+            StreamingGetWorkRequest request =
+                StreamingGetWorkRequest.newBuilder()
+                    .setRequestExtension(
+                        Windmill.StreamingGetWorkRequestExtension.newBuilder()
+                            .setMaxItems(extension.items())
+                            .setMaxBytes(extension.bytes()))
+                    .build();
+            lastRequest.set(request);
+            try {
+              send(request);
+            } catch (IllegalStateException e) {
+              // Stream was closed.
+            }
+          });
+    }
   }
 
   @Override
   protected synchronized void onNewStream() {
     workItemAssemblers.clear();
-    // Add the current in-flight budget to the next adjustment. Only positive 
values are allowed
-    // here
-    // with negatives defaulting to 0, since GetWorkBudgets cannot be created 
with negative values.
-    GetWorkBudget budgetAdjustment = 
nextBudgetAdjustment.get().apply(inFlightBudget.get());
-    inFlightBudget.set(budgetAdjustment);
-    send(
-        StreamingGetWorkRequest.newBuilder()
-            .setRequest(
-                request
-                    .toBuilder()
-                    .setMaxBytes(budgetAdjustment.bytes())
-                    .setMaxItems(budgetAdjustment.items()))
-            .build());
-
-    // We just sent the budget, reset it.
-    nextBudgetAdjustment.set(GetWorkBudget.noBudget());
+    if (!isShutdown()) {
+      budgetTracker.reset();
+      GetWorkBudget initialGetWorkBudget = 
budgetTracker.computeBudgetExtension();
+      StreamingGetWorkRequest request =
+          StreamingGetWorkRequest.newBuilder()
+              .setRequest(
+                  requestHeader
+                      .toBuilder()
+                      .setMaxItems(initialGetWorkBudget.items())
+                      .setMaxBytes(initialGetWorkBudget.bytes())
+                      .build())
+              .build();
+      lastRequest.set(request);
+      budgetTracker.recordBudgetRequested(initialGetWorkBudget);
+      send(request);
+    }
   }
 
   @Override
@@ -216,8 +228,9 @@ public final class GrpcDirectGetWorkStream
   public void appendSpecificHtml(PrintWriter writer) {
     // Number of buffers is same as distinct workers that sent work on this 
stream.
     writer.format(
-        "GetWorkStream: %d buffers, %s inflight budget allowed.",
-        workItemAssemblers.size(), inFlightBudget.get());
+        "GetWorkStream: %d buffers, " + "last sent request: %s; ",
+        workItemAssemblers.size(), lastRequest.get());
+    writer.print(budgetTracker.debugString());
   }
 
   @Override
@@ -235,30 +248,22 @@ public final class GrpcDirectGetWorkStream
   }
 
   private void consumeAssembledWorkItem(AssembledWorkItem assembledWorkItem) {
-    // Record the fact that there are now fewer outstanding messages and bytes 
on the stream.
-    inFlightBudget.updateAndGet(budget -> budget.subtract(1, 
assembledWorkItem.bufferedSize()));
     WorkItem workItem = assembledWorkItem.workItem();
     GetWorkResponseChunkAssembler.ComputationMetadata metadata =
         assembledWorkItem.computationMetadata();
-    pendingResponseBudget.getAndUpdate(budget -> budget.apply(1, 
workItem.getSerializedSize()));
-    try {
-      workItemScheduler.scheduleWork(
-          workItem,
-          createWatermarks(workItem, Preconditions.checkNotNull(metadata)),
-          
createProcessingContext(Preconditions.checkNotNull(metadata.computationId())),
-          assembledWorkItem.latencyAttributions());
-    } finally {
-      pendingResponseBudget.getAndUpdate(budget -> budget.apply(-1, 
-workItem.getSerializedSize()));
-    }
+    workItemScheduler.scheduleWork(
+        workItem,
+        createWatermarks(workItem, metadata),
+        createProcessingContext(metadata.computationId()),
+        assembledWorkItem.latencyAttributions());
+    budgetTracker.recordBudgetReceived(assembledWorkItem.bufferedSize());
+    GetWorkBudget extension = budgetTracker.computeBudgetExtension();
+    maybeSendRequestExtension(extension);
   }
 
   private Work.ProcessingContext createProcessingContext(String computationId) 
{
     return Work.createProcessingContext(
-        computationId,
-        getDataClient.get(),
-        workCommitter.get()::commit,
-        heartbeatSender.get(),
-        backendWorkerToken());
+        computationId, getDataClient, workCommitter::commit, heartbeatSender, 
backendWorkerToken());
   }
 
   @Override
@@ -267,25 +272,110 @@ public final class GrpcDirectGetWorkStream
   }
 
   @Override
-  public void adjustBudget(long itemsDelta, long bytesDelta) {
-    GetWorkBudget adjustment =
-        nextBudgetAdjustment
-            // Get the current value, and reset the nextBudgetAdjustment. This 
will be set again
-            // when adjustBudget is called.
-            .getAndUpdate(unused -> GetWorkBudget.noBudget())
-            .apply(itemsDelta, bytesDelta);
-    sendRequestExtension(adjustment);
+  public void setBudget(GetWorkBudget newBudget) {
+    GetWorkBudget extension = 
budgetTracker.consumeAndComputeBudgetUpdate(newBudget);
+    maybeSendRequestExtension(extension);
   }
 
-  @Override
-  public GetWorkBudget remainingBudget() {
-    // Snapshot the current budgets.
-    GetWorkBudget currentPendingResponseBudget = pendingResponseBudget.get();
-    GetWorkBudget currentNextBudgetAdjustment = nextBudgetAdjustment.get();
-    GetWorkBudget currentInflightBudget = inFlightBudget.get();
-
-    return currentPendingResponseBudget
-        .apply(currentNextBudgetAdjustment)
-        .apply(currentInflightBudget);
+  private void executeSafely(Runnable runnable) {
+    try {
+      executor().execute(runnable);
+    } catch (RejectedExecutionException e) {
+      LOG.debug("{} has been shutdown.", getClass());
+    }
+  }
+
+  /**
+   * Tracks the requested, received, and max {@link GetWorkBudget} of a stream and uses them to
+   * compute request extensions; all methods synchronize on the tracker instance.
+   */
+  @ThreadSafe
+  private static final class GetWorkBudgetTracker {
+
+    @GuardedBy("GetWorkBudgetTracker.this")
+    private GetWorkBudget maxGetWorkBudget;
+
+    @GuardedBy("GetWorkBudgetTracker.this")
+    private long itemsRequested = 0;
+
+    @GuardedBy("GetWorkBudgetTracker.this")
+    private long bytesRequested = 0;
+
+    @GuardedBy("GetWorkBudgetTracker.this")
+    private long itemsReceived = 0;
+
+    @GuardedBy("GetWorkBudgetTracker.this")
+    private long bytesReceived = 0;
+
+    private GetWorkBudgetTracker(GetWorkBudget maxGetWorkBudget) {
+      this.maxGetWorkBudget = maxGetWorkBudget;
+    }
+
+    private synchronized void reset() {
+      itemsRequested = 0;
+      bytesRequested = 0;
+      itemsReceived = 0;
+      bytesReceived = 0;
+    }
+
+    private synchronized String debugString() {
+      return String.format(
+          "max budget: %s; "
+              + "in-flight budget: %s; "
+              + "total budget requested: %s; "
+              + "total budget received: %s.",
+          maxGetWorkBudget, inFlightBudget(), totalRequestedBudget(), totalReceivedBudget());
+    }
+
+    /** Replaces the max budget with {@code newBudget} and returns the extension it implies. */
+    private synchronized GetWorkBudget consumeAndComputeBudgetUpdate(GetWorkBudget newBudget) {
+      maxGetWorkBudget = newBudget;
+      return computeBudgetExtension();
+    }
+
+    private synchronized void recordBudgetRequested(GetWorkBudget budgetRequested) {
+      itemsRequested += budgetRequested.items();
+      bytesRequested += budgetRequested.bytes();
+    }
+
+    private synchronized void recordBudgetReceived(long returnedBudget) {
+      itemsReceived++;
+      bytesReceived += returnedBudget;
+    }
+
+    /**
+     * Returns the budget needed to top in-flight items and bytes back up to the max budget, or
+     * {@link GetWorkBudget#noBudget()} while both in-flight values still exceed half of that
+     * top-up amount, to keep limits near their max without sending too many small extensions.
+     */
+    private synchronized GetWorkBudget computeBudgetExtension() {
+      // In-flight items and bytes can go negative here, since WorkItems returned might be larger
+      // than the initially requested budget.
+      long inFlightItems = itemsRequested - itemsReceived;
+      long inFlightBytes = bytesRequested - bytesReceived;
+
+      // Don't send negative budget extensions.
+      long requestBytes = Math.max(0, maxGetWorkBudget.bytes() - inFlightBytes);
+      long requestItems = Math.max(0, maxGetWorkBudget.items() - inFlightItems);
+
+      return (inFlightItems > requestItems / 2 && inFlightBytes > requestBytes / 2)
+          ? GetWorkBudget.noBudget()
+          : GetWorkBudget.builder().setItems(requestItems).setBytes(requestBytes).build();
+    }
+
+    private synchronized GetWorkBudget inFlightBudget() {
+      return GetWorkBudget.builder()
+          .setItems(itemsRequested - itemsReceived)
+          .setBytes(bytesRequested - bytesReceived)
+          .build();
+    }
+
+    private synchronized GetWorkBudget totalRequestedBudget() {
+      return GetWorkBudget.builder().setItems(itemsRequested).setBytes(bytesRequested).build();
+    }
+
+    private synchronized GetWorkBudget totalReceivedBudget() {
+      return GetWorkBudget.builder().setItems(itemsReceived).setBytes(bytesReceived).build();
+    }
+  }
 }
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java
index 0e9a0c6316e..c99e05a7707 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java
@@ -59,7 +59,7 @@ import org.joda.time.Instant;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-public final class GrpcGetDataStream
+final class GrpcGetDataStream
     extends AbstractWindmillStream<StreamingGetDataRequest, 
StreamingGetDataResponse>
     implements GetDataStream {
   private static final Logger LOG = 
LoggerFactory.getLogger(GrpcGetDataStream.class);
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java
index 09ecbf3f305..a368f3fec23 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java
@@ -194,15 +194,7 @@ final class GrpcGetWorkStream
   }
 
+  /** Intentionally a no-op: budget updates are ignored by this stream. */
   @Override
-  public void adjustBudget(long itemsDelta, long bytesDelta) {
+  public void setBudget(GetWorkBudget newBudget) {
     // no-op
   }
-
-  @Override
-  public GetWorkBudget remainingBudget() {
-    return GetWorkBudget.builder()
-        .setBytes(request.getMaxBytes() - inflightBytes.get())
-        .setItems(request.getMaxItems() - inflightMessages.get())
-        .build();
-  }
 }
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java
index 92f031db997..9e6a02d135e 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java
@@ -198,9 +198,9 @@ public class GrpcWindmillStreamFactory implements StatusDataProvider {
       WindmillConnection connection,
       GetWorkRequest request,
       ThrottleTimer getWorkThrottleTimer,
-      Supplier<HeartbeatSender> heartbeatSender,
-      Supplier<GetDataClient> getDataClient,
-      Supplier<WorkCommitter> workCommitter,
+      HeartbeatSender heartbeatSender,
+      GetDataClient getDataClient,
+      WorkCommitter workCommitter,
       WorkItemScheduler workItemScheduler) {
     return GrpcDirectGetWorkStream.create(
         connection.backendWorkerToken(),
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java
index 9aec29a3ba4..f0ea2f550a7 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java
@@ -36,7 +36,6 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPor
 /** Utility class used to create different RPC Channels. */
 public final class WindmillChannelFactory {
   public static final String LOCALHOST = "localhost";
-  private static final int DEFAULT_GRPC_PORT = 443;
   private static final int MAX_REMOTE_TRACE_EVENTS = 100;
 
   private WindmillChannelFactory() {}
@@ -55,8 +54,6 @@ public final class WindmillChannelFactory {
   public static ManagedChannel remoteChannel(
       WindmillServiceAddress windmillServiceAddress, int 
windmillServiceRpcChannelTimeoutSec) {
     switch (windmillServiceAddress.getKind()) {
-      case IPV6:
-        return remoteChannel(windmillServiceAddress.ipv6(), 
windmillServiceRpcChannelTimeoutSec);
       case GCP_SERVICE_ADDRESS:
         return remoteChannel(
             windmillServiceAddress.gcpServiceAddress(), 
windmillServiceRpcChannelTimeoutSec);
@@ -67,7 +64,8 @@ public final class WindmillChannelFactory {
             windmillServiceRpcChannelTimeoutSec);
       default:
         throw new UnsupportedOperationException(
-            "Only IPV6, GCP_SERVICE_ADDRESS, AUTHENTICATED_GCP_SERVICE_ADDRESS 
are supported WindmillServiceAddresses.");
+            "Only GCP_SERVICE_ADDRESS and AUTHENTICATED_GCP_SERVICE_ADDRESS 
are supported"
+                + " WindmillServiceAddresses.");
     }
   }
 
@@ -105,17 +103,6 @@ public final class WindmillChannelFactory {
     }
   }
 
-  public static ManagedChannel remoteChannel(
-      Inet6Address directEndpoint, int windmillServiceRpcChannelTimeoutSec) {
-    try {
-      return createRemoteChannel(
-          NettyChannelBuilder.forAddress(new InetSocketAddress(directEndpoint, 
DEFAULT_GRPC_PORT)),
-          windmillServiceRpcChannelTimeoutSec);
-    } catch (SSLException sslException) {
-      throw new WindmillChannelCreationException(directEndpoint.toString(), 
sslException);
-    }
-  }
-
   @SuppressWarnings("nullness")
   private static ManagedChannel createRemoteChannel(
       NettyChannelBuilder channelBuilder, int 
windmillServiceRpcChannelTimeoutSec)
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java
index 403bb99efb4..8a1ba2556cf 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java
@@ -17,18 +17,11 @@
  */
 package org.apache.beam.runners.dataflow.worker.windmill.work.budget;
 
-import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap;
-import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.DoubleMath.roundToLong;
 import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.LongMath.divide;
 
 import java.math.RoundingMode;
-import java.util.Map;
-import java.util.Map.Entry;
-import java.util.function.Function;
-import java.util.function.Supplier;
 import org.apache.beam.sdk.annotations.Internal;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableCollection;
-import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -36,22 +29,11 @@ import org.slf4j.LoggerFactory;
 @Internal
 final class EvenGetWorkBudgetDistributor implements GetWorkBudgetDistributor {
   private static final Logger LOG = 
LoggerFactory.getLogger(EvenGetWorkBudgetDistributor.class);
-  private final Supplier<GetWorkBudget> activeWorkBudgetSupplier;
-
-  EvenGetWorkBudgetDistributor(Supplier<GetWorkBudget> 
activeWorkBudgetSupplier) {
-    this.activeWorkBudgetSupplier = activeWorkBudgetSupplier;
-  }
-
-  private static boolean isBelowFiftyPercentOfTarget(
-      GetWorkBudget remaining, GetWorkBudget target) {
-    return remaining.items() < roundToLong(target.items() * 0.5, 
RoundingMode.CEILING)
-        || remaining.bytes() < roundToLong(target.bytes() * 0.5, 
RoundingMode.CEILING);
-  }
 
   @Override
   public <T extends GetWorkBudgetSpender> void distributeBudget(
-      ImmutableCollection<T> budgetOwners, GetWorkBudget getWorkBudget) {
-    if (budgetOwners.isEmpty()) {
+      ImmutableCollection<T> budgetSpenders, GetWorkBudget getWorkBudget) {
+    if (budgetSpenders.isEmpty()) {
       LOG.debug("Cannot distribute budget to no owners.");
       return;
     }
@@ -61,38 +43,15 @@ final class EvenGetWorkBudgetDistributor implements GetWorkBudgetDistributor {
       return;
     }
 
-    Map<T, GetWorkBudget> desiredBudgets = computeDesiredBudgets(budgetOwners, 
getWorkBudget);
-
-    for (Entry<T, GetWorkBudget> streamAndDesiredBudget : 
desiredBudgets.entrySet()) {
-      GetWorkBudgetSpender getWorkBudgetSpender = 
streamAndDesiredBudget.getKey();
-      GetWorkBudget desired = streamAndDesiredBudget.getValue();
-      GetWorkBudget remaining = getWorkBudgetSpender.remainingBudget();
-      if (isBelowFiftyPercentOfTarget(remaining, desired)) {
-        GetWorkBudget adjustment = desired.subtract(remaining);
-        getWorkBudgetSpender.adjustBudget(adjustment);
-      }
-    }
+    GetWorkBudget budgetPerStream = 
computeDesiredPerStreamBudget(budgetSpenders, getWorkBudget);
+    budgetSpenders.forEach(getWorkBudgetSpender -> 
getWorkBudgetSpender.setBudget(budgetPerStream));
   }
 
-  private <T extends GetWorkBudgetSpender> ImmutableMap<T, GetWorkBudget> computeDesiredBudgets(
+  /**
+   * Splits {@code totalGetWorkBudget} evenly across {@code streams}. Items and bytes are each
+   * rounded up, so the per-stream shares can sum to slightly more than the total. {@code streams}
+   * must be non-empty, since its size is used as the divisor.
+   */
+  private <T extends GetWorkBudgetSpender> GetWorkBudget computeDesiredPerStreamBudget(
       ImmutableCollection<T> streams, GetWorkBudget totalGetWorkBudget) {
-    GetWorkBudget activeWorkBudget = activeWorkBudgetSupplier.get();
-    LOG.info("Current active work budget: {}", activeWorkBudget);
-    // TODO: Fix possibly non-deterministic handing out of budgets.
-    // Rounding up here will drift upwards over the lifetime of the streams.
-    GetWorkBudget budgetPerStream =
-        GetWorkBudget.builder()
-            .setItems(
-                divide(
-                    totalGetWorkBudget.items() - activeWorkBudget.items(),
-                    streams.size(),
-                    RoundingMode.CEILING))
-            .setBytes(
-                divide(
-                    totalGetWorkBudget.bytes() - activeWorkBudget.bytes(),
-                    streams.size(),
-                    RoundingMode.CEILING))
-            .build();
-    return streams.stream().collect(toImmutableMap(Function.identity(), unused -> budgetPerStream));
+    return GetWorkBudget.builder()
+        .setItems(divide(totalGetWorkBudget.items(), streams.size(), RoundingMode.CEILING))
+        .setBytes(divide(totalGetWorkBudget.bytes(), streams.size(), RoundingMode.CEILING))
+        .build();
   }
 }
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributors.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributors.java
index 43c0d46139d..2013c9ff1cb 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributors.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributors.java
@@ -17,13 +17,11 @@
  */
 package org.apache.beam.runners.dataflow.worker.windmill.work.budget;
 
-import java.util.function.Supplier;
 import org.apache.beam.sdk.annotations.Internal;
 
+/** Factory methods for {@link GetWorkBudgetDistributor} implementations. */
 @Internal
 public final class GetWorkBudgetDistributors {
-  public static GetWorkBudgetDistributor distributeEvenly(
-      Supplier<GetWorkBudget> activeWorkBudgetSupplier) {
-    return new EvenGetWorkBudgetDistributor(activeWorkBudgetSupplier);
+  /** Returns a distributor that splits a budget evenly across all budget spenders. */
+  public static GetWorkBudgetDistributor distributeEvenly() {
+    return new EvenGetWorkBudgetDistributor();
   }
 }
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetSpender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetSpender.java
index 254b2589062..decf101a641 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetSpender.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetSpender.java
@@ -22,11 +22,9 @@ package org.apache.beam.runners.dataflow.worker.windmill.work.budget;
  * org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget}
  */
 public interface GetWorkBudgetSpender {
-  void adjustBudget(long itemsDelta, long bytesDelta);
+  /** Sets this spender's budget to {@code items} and {@code bytes} (absolute values). */
+  void setBudget(long items, long bytes);
 
-  default void adjustBudget(GetWorkBudget adjustment) {
-    adjustBudget(adjustment.items(), adjustment.bytes());
+  /** Unpacks {@code budget} and delegates to {@link #setBudget(long, long)}. */
+  default void setBudget(GetWorkBudget budget) {
+    setBudget(budget.items(), budget.bytes());
   }
-
-  GetWorkBudget remainingBudget();
 }
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
index b3f7467cdbd..90ffb3d3fbc 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
@@ -245,18 +245,10 @@ public final class FakeWindmillServer extends WindmillServerStub {
       }
 
+      // The fake stream intentionally ignores budget updates.
       @Override
-      public void adjustBudget(long itemsDelta, long bytesDelta) {
+      public void setBudget(GetWorkBudget newBudget) {
         // no-op.
       }
 
-      @Override
-      public GetWorkBudget remainingBudget() {
-        return GetWorkBudget.builder()
-            .setItems(request.getMaxItems())
-            .setBytes(request.getMaxBytes())
-            .build();
-      }
-
       @Override
       public boolean awaitTermination(int time, TimeUnit unit) throws 
InterruptedException {
         while (done.getCount() > 0) {
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java
index ed8815c48e7..0092fcc7bcd 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java
@@ -30,9 +30,7 @@ import static org.mockito.Mockito.verify;
 
 import java.io.IOException;
 import java.util.ArrayList;
-import java.util.Comparator;
 import java.util.HashSet;
-import java.util.List;
 import java.util.Set;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.Executors;
@@ -46,7 +44,6 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataRequest;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataResponse;
-import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
 import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress;
 import 
org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter;
 import 
org.apache.beam.runners.dataflow.worker.windmill.client.getdata.ThrottlingGetDataMetricTracker;
@@ -71,7 +68,6 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Precondit
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableCollection;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
-import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort;
 import org.junit.After;
 import org.junit.Before;
@@ -92,7 +88,6 @@ public class FanOutStreamingEngineWorkerHarnessTest {
               
.setDirectEndpoint(DEFAULT_WINDMILL_SERVICE_ADDRESS.gcpServiceAddress().toString())
               .build());
 
-  private static final long CLIENT_ID = 1L;
   private static final String JOB_ID = "jobId";
   private static final String PROJECT_ID = "projectId";
   private static final String WORKER_ID = "workerId";
@@ -101,6 +96,7 @@ public class FanOutStreamingEngineWorkerHarnessTest {
           .setJobId(JOB_ID)
           .setProjectId(PROJECT_ID)
           .setWorkerId(WORKER_ID)
+          .setClientId(1L)
           .build();
 
   @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
@@ -134,7 +130,7 @@ public class FanOutStreamingEngineWorkerHarnessTest {
         .setJobId(JOB_ID)
         .setProjectId(PROJECT_ID)
         .setWorkerId(WORKER_ID)
-        .setClientId(CLIENT_ID)
+        .setClientId(JOB_HEADER.getClientId())
         .setMaxItems(items)
         .setMaxBytes(bytes)
         .build();
@@ -174,7 +170,7 @@ public class FanOutStreamingEngineWorkerHarnessTest {
     stubFactory.shutdown();
   }
 
-  private FanOutStreamingEngineWorkerHarness newStreamingEngineClient(
+  private FanOutStreamingEngineWorkerHarness 
newFanOutStreamingEngineWorkerHarness(
       GetWorkBudget getWorkBudget,
       GetWorkBudgetDistributor getWorkBudgetDistributor,
       WorkItemScheduler workItemScheduler) {
@@ -186,7 +182,6 @@ public class FanOutStreamingEngineWorkerHarnessTest {
         stubFactory,
         getWorkBudgetDistributor,
         dispatcherClient,
-        CLIENT_ID,
         ignored -> mock(WorkCommitter.class),
         new ThrottlingGetDataMetricTracker(mock(MemoryMonitor.class)));
   }
@@ -201,7 +196,7 @@ public class FanOutStreamingEngineWorkerHarnessTest {
         spy(new TestGetWorkBudgetDistributor(numBudgetDistributionsExpected));
 
     fanOutStreamingEngineWorkProvider =
-        newStreamingEngineClient(
+        newFanOutStreamingEngineWorkerHarness(
             GetWorkBudget.builder().setItems(items).setBytes(bytes).build(),
             getWorkBudgetDistributor,
             noOpProcessWorkItemFn());
@@ -219,16 +214,14 @@ public class FanOutStreamingEngineWorkerHarnessTest {
 
     getWorkerMetadataReady.await();
     fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata);
-    waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor);
+    assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution());
 
-    StreamingEngineConnectionState currentConnections =
-        fanOutStreamingEngineWorkProvider.getCurrentConnections();
+    StreamingEngineBackends currentBackends = 
fanOutStreamingEngineWorkProvider.currentBackends();
 
-    assertEquals(2, currentConnections.windmillConnections().size());
-    assertEquals(2, currentConnections.windmillStreams().size());
+    assertEquals(2, currentBackends.windmillStreams().size());
     Set<String> workerTokens =
-        currentConnections.windmillConnections().values().stream()
-            .map(WindmillConnection::backendWorkerToken)
+        currentBackends.windmillStreams().keySet().stream()
+            .map(endpoint -> 
endpoint.workerToken().orElseThrow(IllegalStateException::new))
             .collect(Collectors.toSet());
 
     assertTrue(workerTokens.contains(workerToken));
@@ -252,27 +245,6 @@ public class FanOutStreamingEngineWorkerHarnessTest {
     verify(streamFactory, times(2)).createCommitWorkStream(any(), any());
   }
 
-  @Test
-  public void testScheduledBudgetRefresh() throws InterruptedException {
-    TestGetWorkBudgetDistributor getWorkBudgetDistributor =
-        spy(new TestGetWorkBudgetDistributor(2));
-    fanOutStreamingEngineWorkProvider =
-        newStreamingEngineClient(
-            GetWorkBudget.builder().setItems(1L).setBytes(1L).build(),
-            getWorkBudgetDistributor,
-            noOpProcessWorkItemFn());
-
-    getWorkerMetadataReady.await();
-    fakeGetWorkerMetadataStub.injectWorkerMetadata(
-        WorkerMetadataResponse.newBuilder()
-            .setMetadataVersion(1)
-            .addWorkEndpoints(metadataResponseEndpoint("workerToken"))
-            .putAllGlobalDataEndpoints(DEFAULT)
-            .build());
-    waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor);
-    verify(getWorkBudgetDistributor, atLeast(2)).distributeBudget(any(), 
any());
-  }
-
   @Test
   public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers()
       throws InterruptedException {
@@ -280,7 +252,7 @@ public class FanOutStreamingEngineWorkerHarnessTest {
     TestGetWorkBudgetDistributor getWorkBudgetDistributor =
         spy(new TestGetWorkBudgetDistributor(metadataCount));
     fanOutStreamingEngineWorkProvider =
-        newStreamingEngineClient(
+        newFanOutStreamingEngineWorkerHarness(
             GetWorkBudget.builder().setItems(1).setBytes(1).build(),
             getWorkBudgetDistributor,
             noOpProcessWorkItemFn());
@@ -309,32 +281,28 @@ public class FanOutStreamingEngineWorkerHarnessTest {
                 WorkerMetadataResponse.Endpoint.newBuilder()
                     .setBackendWorkerToken(workerToken3)
                     .build())
-            .putAllGlobalDataEndpoints(DEFAULT)
             .build();
 
     getWorkerMetadataReady.await();
     fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata);
     fakeGetWorkerMetadataStub.injectWorkerMetadata(secondWorkerMetadata);
-    waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor);
-    StreamingEngineConnectionState currentConnections =
-        fanOutStreamingEngineWorkProvider.getCurrentConnections();
-    assertEquals(1, currentConnections.windmillConnections().size());
-    assertEquals(1, currentConnections.windmillStreams().size());
+    assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution());
+    StreamingEngineBackends currentBackends = 
fanOutStreamingEngineWorkProvider.currentBackends();
+    assertEquals(1, currentBackends.windmillStreams().size());
     Set<String> workerTokens =
-        
fanOutStreamingEngineWorkProvider.getCurrentConnections().windmillConnections().values()
-            .stream()
-            .map(WindmillConnection::backendWorkerToken)
+        
fanOutStreamingEngineWorkProvider.currentBackends().windmillStreams().keySet().stream()
+            .map(endpoint -> 
endpoint.workerToken().orElseThrow(IllegalStateException::new))
             .collect(Collectors.toSet());
 
     assertFalse(workerTokens.contains(workerToken));
     assertFalse(workerTokens.contains(workerToken2));
+    assertTrue(currentBackends.globalDataStreams().isEmpty());
   }
 
   @Test
   public void testOnNewWorkerMetadata_redistributesBudget() throws 
InterruptedException {
     String workerToken = "workerToken1";
     String workerToken2 = "workerToken2";
-    String workerToken3 = "workerToken3";
 
     WorkerMetadataResponse firstWorkerMetadata =
         WorkerMetadataResponse.newBuilder()
@@ -354,42 +322,24 @@ public class FanOutStreamingEngineWorkerHarnessTest {
                     .build())
             .putAllGlobalDataEndpoints(DEFAULT)
             .build();
-    WorkerMetadataResponse thirdWorkerMetadata =
-        WorkerMetadataResponse.newBuilder()
-            .setMetadataVersion(3)
-            .addWorkEndpoints(
-                WorkerMetadataResponse.Endpoint.newBuilder()
-                    .setBackendWorkerToken(workerToken3)
-                    .build())
-            .putAllGlobalDataEndpoints(DEFAULT)
-            .build();
-
-    List<WorkerMetadataResponse> workerMetadataResponses =
-        Lists.newArrayList(firstWorkerMetadata, secondWorkerMetadata, 
thirdWorkerMetadata);
 
     TestGetWorkBudgetDistributor getWorkBudgetDistributor =
-        spy(new TestGetWorkBudgetDistributor(workerMetadataResponses.size()));
+        spy(new TestGetWorkBudgetDistributor(1));
     fanOutStreamingEngineWorkProvider =
-        newStreamingEngineClient(
+        newFanOutStreamingEngineWorkerHarness(
             GetWorkBudget.builder().setItems(1).setBytes(1).build(),
             getWorkBudgetDistributor,
             noOpProcessWorkItemFn());
 
     getWorkerMetadataReady.await();
 
-    // Make sure we are injecting the metadata from smallest to largest.
-    workerMetadataResponses.stream()
-        
.sorted(Comparator.comparingLong(WorkerMetadataResponse::getMetadataVersion))
-        .forEach(fakeGetWorkerMetadataStub::injectWorkerMetadata);
-
-    waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor);
-    verify(getWorkBudgetDistributor, atLeast(workerMetadataResponses.size()))
-        .distributeBudget(any(), any());
-  }
+    fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata);
+    assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution());
+    getWorkBudgetDistributor.expectNumDistributions(1);
+    fakeGetWorkerMetadataStub.injectWorkerMetadata(secondWorkerMetadata);
+    assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution());
 
-  private void waitForWorkerMetadataToBeConsumed(
-      TestGetWorkBudgetDistributor getWorkBudgetDistributor) throws 
InterruptedException {
-    getWorkBudgetDistributor.waitForBudgetDistribution();
+    verify(getWorkBudgetDistributor, times(2)).distributeBudget(any(), any());
   }
 
   private static class GetWorkerMetadataTestStub
@@ -434,21 +384,24 @@ public class FanOutStreamingEngineWorkerHarnessTest {
   }
 
   private static class TestGetWorkBudgetDistributor implements GetWorkBudgetDistributor {
-    private final CountDownLatch getWorkBudgetDistributorTriggered;
+    // volatile: expectNumDistributions() swaps the latch from the test thread, and
+    // distributeBudget() may be called from a harness thread that must see the new latch.
+    private volatile CountDownLatch getWorkBudgetDistributorTriggered;
 
     private TestGetWorkBudgetDistributor(int numBudgetDistributionsExpected) {
       this.getWorkBudgetDistributorTriggered = new CountDownLatch(numBudgetDistributionsExpected);
     }
 
-    @SuppressWarnings("ReturnValueIgnored")
-    private void waitForBudgetDistribution() throws InterruptedException {
-      getWorkBudgetDistributorTriggered.await(5, TimeUnit.SECONDS);
+    /** Waits up to 5 seconds for the expected distributions; returns false on timeout. */
+    private boolean waitForBudgetDistribution() throws InterruptedException {
+      return getWorkBudgetDistributorTriggered.await(5, TimeUnit.SECONDS);
+    }
+
+    /** Replaces the latch so later waits expect {@code numBudgetDistributionsExpected} calls. */
+    private void expectNumDistributions(int numBudgetDistributionsExpected) {
+      this.getWorkBudgetDistributorTriggered = new CountDownLatch(numBudgetDistributionsExpected);
     }
 
     @Override
     public <T extends GetWorkBudgetSpender> void distributeBudget(
         ImmutableCollection<T> streams, GetWorkBudget getWorkBudget) {
-      streams.forEach(stream -> stream.adjustBudget(getWorkBudget.items(), getWorkBudget.bytes()));
+      streams.forEach(stream -> stream.setBudget(getWorkBudget.items(), getWorkBudget.bytes()));
       getWorkBudgetDistributorTriggered.countDown();
     }
   }
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java
index dc6cc564105..32d1f573808 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java
@@ -193,7 +193,7 @@ public class WindmillStreamSenderTest {
     WindmillStreamSender windmillStreamSender =
         
newWindmillStreamSender(GetWorkBudget.builder().setBytes(1L).setItems(1L).build());
 
-    windmillStreamSender.closeAllStreams();
+    windmillStreamSender.close();
 
     verifyNoInteractions(streamFactory);
   }
@@ -230,7 +230,7 @@ public class WindmillStreamSenderTest {
             mockStreamFactory);
 
     windmillStreamSender.startStreams();
-    windmillStreamSender.closeAllStreams();
+    windmillStreamSender.close();
 
     verify(mockGetWorkStream).shutdown();
     verify(mockGetDataStream).shutdown();
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java
new file mode 100644
index 00000000000..fd2b3023883
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java
@@ -0,0 +1,405 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import static com.google.common.truth.Truth.assertThat;
+import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import javax.annotation.Nullable;
+import 
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
+import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler;
+import 
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import 
org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender;
+import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server;
+import 
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder;
+import 
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.Timeout;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests for {@link GrpcDirectGetWorkStream}, run against an in-process gRPC server. The server's
+ * {@code getWorkStream} handler records every {@link Windmill.StreamingGetWorkRequest} the stream
+ * sends and lets the test inject {@link Windmill.StreamingGetWorkResponseChunk}s back to the
+ * client.
+ */
+@RunWith(JUnit4.class)
+public class GrpcDirectGetWorkStreamTest {
+  private static final WorkItemScheduler NO_OP_WORK_ITEM_SCHEDULER =
+      (workItem, watermarks, processingContext, getWorkStreamLatencies) -> {};
+  private static final Windmill.JobHeader TEST_JOB_HEADER =
+      Windmill.JobHeader.newBuilder()
+          .setClientId(1L)
+          .setJobId("test_job")
+          .setWorkerId("test_worker")
+          .setProjectId("test_project")
+          .build();
+  private static final String FAKE_SERVER_NAME = "Fake server for GrpcDirectGetWorkStreamTest";
+  @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
+  private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry();
+  @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
+  private ManagedChannel inProcessChannel;
+  private GrpcDirectGetWorkStream stream;
+
+  /** Builds the request extension the stream is expected to send for {@code budget}. */
+  private static Windmill.StreamingGetWorkRequestExtension extension(GetWorkBudget budget) {
+    return Windmill.StreamingGetWorkRequestExtension.newBuilder()
+        .setMaxItems(budget.items())
+        .setMaxBytes(budget.bytes())
+        .build();
+  }
+
+  /**
+   * Asserts that {@code getWorkRequest} is the initial header request: it carries a {@code
+   * GetWorkRequest} for {@link #TEST_JOB_HEADER} with {@code expectedInitialBudget}, and no
+   * extension.
+   */
+  private static void assertHeader(
+      Windmill.StreamingGetWorkRequest getWorkRequest, GetWorkBudget expectedInitialBudget) {
+    assertTrue(getWorkRequest.hasRequest());
+    assertFalse(getWorkRequest.hasRequestExtension());
+    assertThat(getWorkRequest.getRequest())
+        .isEqualTo(
+            Windmill.GetWorkRequest.newBuilder()
+                .setClientId(TEST_JOB_HEADER.getClientId())
+                .setJobId(TEST_JOB_HEADER.getJobId())
+                .setProjectId(TEST_JOB_HEADER.getProjectId())
+                .setWorkerId(TEST_JOB_HEADER.getWorkerId())
+                .setMaxItems(expectedInitialBudget.items())
+                .setMaxBytes(expectedInitialBudget.bytes())
+                .build());
+  }
+
+  @Before
+  public void setUp() throws IOException {
+    Server server =
+        InProcessServerBuilder.forName(FAKE_SERVER_NAME)
+            .fallbackHandlerRegistry(serviceRegistry)
+            .directExecutor()
+            .build()
+            .start();
+    // register() returns its argument, so each resource is registered exactly once. The channel
+    // is registered after the server so that it is released before the server.
+    grpcCleanup.register(server);
+    inProcessChannel =
+        grpcCleanup.register(
+            InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build());
+  }
+
+  @After
+  public void cleanUp() {
+    inProcessChannel.shutdownNow();
+    // stream is only unset if the test failed before creating it; don't mask that failure with a
+    // NullPointerException thrown from teardown.
+    if (stream != null) {
+      stream.shutdown();
+    }
+  }
+
+  /**
+   * Registers {@code testStub} with the in-process server and opens a direct GetWork stream to it
+   * whose header requests {@code initialGetWorkBudget}.
+   */
+  private GrpcDirectGetWorkStream createGetWorkStream(
+      GetWorkStreamTestStub testStub,
+      GetWorkBudget initialGetWorkBudget,
+      ThrottleTimer throttleTimer,
+      WorkItemScheduler workItemScheduler) {
+    serviceRegistry.addService(testStub);
+    return (GrpcDirectGetWorkStream)
+        GrpcWindmillStreamFactory.of(TEST_JOB_HEADER)
+            .build()
+            .createDirectGetWorkStream(
+                WindmillConnection.builder()
+                    .setStub(CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel))
+                    .build(),
+                Windmill.GetWorkRequest.newBuilder()
+                    .setClientId(TEST_JOB_HEADER.getClientId())
+                    .setJobId(TEST_JOB_HEADER.getJobId())
+                    .setProjectId(TEST_JOB_HEADER.getProjectId())
+                    .setWorkerId(TEST_JOB_HEADER.getWorkerId())
+                    .setMaxItems(initialGetWorkBudget.items())
+                    .setMaxBytes(initialGetWorkBudget.bytes())
+                    .build(),
+                throttleTimer,
+                mock(HeartbeatSender.class),
+                mock(GetDataClient.class),
+                mock(WorkCommitter.class),
+                workItemScheduler);
+  }
+
+  /** Wraps {@code workItem} in a single, complete response chunk (no remaining bytes). */
+  private Windmill.StreamingGetWorkResponseChunk createResponse(Windmill.WorkItem workItem) {
+    return Windmill.StreamingGetWorkResponseChunk.newBuilder()
+        .setStreamId(1L)
+        .setComputationMetadata(
+            Windmill.ComputationWorkItemMetadata.newBuilder()
+                .setComputationId("compId")
+                .setInputDataWatermark(1L)
+                .setDependentRealtimeInputWatermark(1L)
+                .build())
+        .setSerializedWorkItem(workItem.toByteString())
+        .setRemainingBytesForWorkItem(0)
+        .build();
+  }
+
+  @Test
+  public void testSetBudget_computesAndSendsCorrectExtension_noExistingBudget()
+      throws InterruptedException {
+    int expectedRequests = 2;
+    CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+    TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests);
+    GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver);
+    stream =
+        createGetWorkStream(
+            testStub, GetWorkBudget.noBudget(), new ThrottleTimer(), NO_OP_WORK_ITEM_SCHEDULER);
+    GetWorkBudget newBudget = GetWorkBudget.builder().setItems(10).setBytes(10).build();
+    stream.setBudget(newBudget);
+
+    assertTrue(waitForRequests.await(5, TimeUnit.SECONDS));
+
+    // Header and extension.
+    assertThat(requestObserver.sent()).hasSize(expectedRequests);
+    assertHeader(requestObserver.sent().get(0), GetWorkBudget.noBudget());
+    assertThat(Iterables.getLast(requestObserver.sent()).getRequestExtension())
+        .isEqualTo(extension(newBudget));
+  }
+
+  @Test
+  public void testSetBudget_computesAndSendsCorrectExtension_existingBudget()
+      throws InterruptedException {
+    int expectedRequests = 2;
+    CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+    TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests);
+    GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver);
+    GetWorkBudget initialBudget = GetWorkBudget.builder().setItems(10).setBytes(10).build();
+    stream =
+        createGetWorkStream(
+            testStub, initialBudget, new ThrottleTimer(), NO_OP_WORK_ITEM_SCHEDULER);
+    GetWorkBudget newBudget = GetWorkBudget.builder().setItems(100).setBytes(100).build();
+    stream.setBudget(newBudget);
+    GetWorkBudget diff = newBudget.subtract(initialBudget);
+
+    assertTrue(waitForRequests.await(5, TimeUnit.SECONDS));
+
+    List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+    // Header and extension.
+    assertThat(requests).hasSize(expectedRequests);
+    assertHeader(requests.get(0), initialBudget);
+    assertThat(Iterables.getLast(requests).getRequestExtension()).isEqualTo(extension(diff));
+  }
+
+  @Test
+  public void testSetBudget_doesNotSendExtensionIfOutstandingBudgetHigh()
+      throws InterruptedException {
+    int expectedRequests = 1;
+    CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+    TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests);
+    GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver);
+    GetWorkBudget initialBudget =
+        GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build();
+    stream =
+        createGetWorkStream(
+            testStub, initialBudget, new ThrottleTimer(), NO_OP_WORK_ITEM_SCHEDULER);
+    stream.setBudget(GetWorkBudget.builder().setItems(10).setBytes(10).build());
+
+    assertTrue(waitForRequests.await(5, TimeUnit.SECONDS));
+
+    List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+    // Assert that the extension was never sent, only the header.
+    assertThat(requests).hasSize(expectedRequests);
+    assertHeader(Iterables.getOnlyElement(requests), initialBudget);
+  }
+
+  @Test
+  public void testSetBudget_doesNothingIfStreamShutdown() throws InterruptedException {
+    int expectedRequests = 1;
+    CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+    TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests);
+    GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver);
+    stream =
+        createGetWorkStream(
+            testStub, GetWorkBudget.noBudget(), new ThrottleTimer(), NO_OP_WORK_ITEM_SCHEDULER);
+    stream.shutdown();
+    stream.setBudget(
+        GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build());
+
+    assertTrue(waitForRequests.await(5, TimeUnit.SECONDS));
+
+    List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+    // Assert that the extension was never sent, only the header.
+    assertThat(requests).hasSize(expectedRequests);
+    assertHeader(Iterables.getOnlyElement(requests), GetWorkBudget.noBudget());
+  }
+
+  @Test
+  public void testConsumedWorkItem_computesAndSendsCorrectExtension() throws InterruptedException {
+    int expectedRequests = 2;
+    CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+    TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests);
+    GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver);
+    GetWorkBudget initialBudget = GetWorkBudget.builder().setItems(1).setBytes(100).build();
+    Set<Windmill.WorkItem> scheduledWorkItems = new HashSet<>();
+    stream =
+        createGetWorkStream(
+            testStub,
+            initialBudget,
+            new ThrottleTimer(),
+            (work, watermarks, processingContext, getWorkStreamLatencies) -> {
+              scheduledWorkItems.add(work);
+            });
+    Windmill.WorkItem workItem =
+        Windmill.WorkItem.newBuilder()
+            .setKey(ByteString.copyFromUtf8("somewhat_long_key"))
+            .setWorkToken(1L)
+            .setShardingKey(1L)
+            .setCacheToken(1L)
+            .build();
+
+    testStub.injectResponse(createResponse(workItem));
+
+    assertTrue(waitForRequests.await(5, TimeUnit.SECONDS));
+
+    assertThat(scheduledWorkItems).containsExactly(workItem);
+    List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+    long inFlightBytes = initialBudget.bytes() - workItem.getSerializedSize();
+
+    assertThat(requests).hasSize(expectedRequests);
+    assertHeader(requests.get(0), initialBudget);
+    // The expected extension requests back what the consumed item used up: one item, and
+    // initialBudget.bytes() - inFlightBytes (i.e. the item's serialized size) in bytes.
+    assertThat(Iterables.getLast(requests).getRequestExtension())
+        .isEqualTo(
+            extension(
+                GetWorkBudget.builder()
+                    .setItems(1)
+                    .setBytes(initialBudget.bytes() - inFlightBytes)
+                    .build()));
+  }
+
+  @Test
+  public void testConsumedWorkItem_doesNotSendExtensionIfOutstandingBudgetHigh()
+      throws InterruptedException {
+    int expectedRequests = 1;
+    CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+    TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests);
+    GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver);
+    Set<Windmill.WorkItem> scheduledWorkItems = new HashSet<>();
+    GetWorkBudget initialBudget =
+        GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build();
+    stream =
+        createGetWorkStream(
+            testStub,
+            initialBudget,
+            new ThrottleTimer(),
+            (work, watermarks, processingContext, getWorkStreamLatencies) ->
+                scheduledWorkItems.add(work));
+    Windmill.WorkItem workItem =
+        Windmill.WorkItem.newBuilder()
+            .setKey(ByteString.copyFromUtf8("somewhat_long_key"))
+            .setWorkToken(1L)
+            .setShardingKey(1L)
+            .setCacheToken(1L)
+            .build();
+
+    testStub.injectResponse(createResponse(workItem));
+
+    assertTrue(waitForRequests.await(5, TimeUnit.SECONDS));
+
+    assertThat(scheduledWorkItems).containsExactly(workItem);
+    List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+
+    // Assert that the extension was never sent, only the header.
+    assertThat(requests).hasSize(expectedRequests);
+    assertHeader(Iterables.getOnlyElement(requests), initialBudget);
+  }
+
+  @Test
+  public void testOnResponse_stopsThrottling() {
+    ThrottleTimer throttleTimer = new ThrottleTimer();
+    TestGetWorkRequestObserver requestObserver =
+        new TestGetWorkRequestObserver(new CountDownLatch(1));
+    GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver);
+    stream =
+        createGetWorkStream(
+            testStub, GetWorkBudget.noBudget(), throttleTimer, NO_OP_WORK_ITEM_SCHEDULER);
+    stream.startThrottleTimer();
+    assertTrue(throttleTimer.throttled());
+    testStub.injectResponse(Windmill.StreamingGetWorkResponseChunk.getDefaultInstance());
+    assertFalse(throttleTimer.throttled());
+  }
+
+  /**
+   * Fake Windmill service. On the first {@code getWorkStream} call it captures the server-side
+   * response observer, hands it to {@code requestObserver}, and lets the test push responses
+   * through it via {@link #injectResponse}.
+   */
+  private static class GetWorkStreamTestStub
+      extends CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase {
+    private final TestGetWorkRequestObserver requestObserver;
+    private @Nullable StreamObserver<Windmill.StreamingGetWorkResponseChunk> responseObserver;
+
+    private GetWorkStreamTestStub(TestGetWorkRequestObserver requestObserver) {
+      this.requestObserver = requestObserver;
+    }
+
+    @Override
+    public StreamObserver<Windmill.StreamingGetWorkRequest> getWorkStream(
+        StreamObserver<Windmill.StreamingGetWorkResponseChunk> responseObserver) {
+      if (this.responseObserver == null) {
+        this.responseObserver = responseObserver;
+        requestObserver.responseObserver = this.responseObserver;
+      }
+
+      return requestObserver;
+    }
+
+    private void injectResponse(Windmill.StreamingGetWorkResponseChunk responseChunk) {
+      checkNotNull(responseObserver).onNext(responseChunk);
+    }
+  }
+
+  /**
+   * Server-side request observer that records every request the client sends and counts down
+   * {@code waitForRequests} once per request.
+   */
+  private static class TestGetWorkRequestObserver
+      implements StreamObserver<Windmill.StreamingGetWorkRequest> {
+    private final List<Windmill.StreamingGetWorkRequest> requests =
+        Collections.synchronizedList(new ArrayList<>());
+    private final CountDownLatch waitForRequests;
+    private @Nullable volatile StreamObserver<Windmill.StreamingGetWorkResponseChunk>
+        responseObserver;
+
+    public TestGetWorkRequestObserver(CountDownLatch waitForRequests) {
+      this.waitForRequests = waitForRequests;
+    }
+
+    @Override
+    public void onNext(Windmill.StreamingGetWorkRequest request) {
+      requests.add(request);
+      waitForRequests.countDown();
+    }
+
+    @Override
+    public void onError(Throwable throwable) {}
+
+    @Override
+    public void onCompleted() {
+      // responseObserver is @Nullable and assigned by GetWorkStreamTestStub. Read the volatile
+      // field once so the null check and the call see the same value.
+      StreamObserver<Windmill.StreamingGetWorkResponseChunk> observer = responseObserver;
+      if (observer != null) {
+        observer.onCompleted();
+      }
+    }
+
+    List<Windmill.StreamingGetWorkRequest> sent() {
+      return requests;
+    }
+  }
+}
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java
 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java
index 3cda4559c10..c76d5a58418 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java
@@ -17,9 +17,7 @@
  */
 package org.apache.beam.runners.dataflow.worker.windmill.work.budget;
 
-import static org.mockito.ArgumentMatchers.anyLong;
 import static org.mockito.ArgumentMatchers.eq;
-import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -40,169 +38,79 @@ public class EvenGetWorkBudgetDistributorTest {
   @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
   @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
 
-  private static GetWorkBudgetDistributor 
createBudgetDistributor(GetWorkBudget activeWorkBudget) {
-    return GetWorkBudgetDistributors.distributeEvenly(() -> activeWorkBudget);
-  }
+  private static GetWorkBudgetSpender createGetWorkBudgetOwner() {
+    // Lambdas are final and cannot be spied.
+    return spy(
+        new GetWorkBudgetSpender() {
 
-  private static GetWorkBudgetDistributor createBudgetDistributor(long 
activeWorkItemsAndBytes) {
-    return createBudgetDistributor(
-        GetWorkBudget.builder()
-            .setItems(activeWorkItemsAndBytes)
-            .setBytes(activeWorkItemsAndBytes)
-            .build());
+          @Override
+          public void setBudget(long items, long bytes) {}
+        });
   }
 
   @Test
   public void testDistributeBudget_doesNothingWhenPassedInStreamsEmpty() {
-    createBudgetDistributor(1L)
+    GetWorkBudgetDistributors.distributeEvenly()
         .distributeBudget(
             ImmutableList.of(), 
GetWorkBudget.builder().setItems(10L).setBytes(10L).build());
   }
 
   @Test
   public void testDistributeBudget_doesNothingWithNoBudget() {
-    GetWorkBudgetSpender getWorkBudgetSpender =
-        
spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(GetWorkBudget.noBudget()));
-    createBudgetDistributor(1L)
+    GetWorkBudgetSpender getWorkBudgetSpender = createGetWorkBudgetOwner();
+    GetWorkBudgetDistributors.distributeEvenly()
         .distributeBudget(ImmutableList.of(getWorkBudgetSpender), 
GetWorkBudget.noBudget());
     verifyNoInteractions(getWorkBudgetSpender);
   }
 
   @Test
-  public void 
testDistributeBudget_doesNotAdjustStreamBudgetWhenRemainingBudgetHighNoActiveWork()
 {
-    GetWorkBudgetSpender getWorkBudgetSpender =
-        spy(
-            createGetWorkBudgetOwnerWithRemainingBudgetOf(
-                GetWorkBudget.builder().setItems(10L).setBytes(10L).build()));
-    createBudgetDistributor(0L)
-        .distributeBudget(
-            ImmutableList.of(getWorkBudgetSpender),
-            GetWorkBudget.builder().setItems(10L).setBytes(10L).build());
-
-    verify(getWorkBudgetSpender, never()).adjustBudget(anyLong(), anyLong());
-  }
-
-  @Test
-  public void
-      
testDistributeBudget_doesNotAdjustStreamBudgetWhenRemainingBudgetHighWithActiveWork()
 {
-    GetWorkBudgetSpender getWorkBudgetSpender =
-        spy(
-            createGetWorkBudgetOwnerWithRemainingBudgetOf(
-                GetWorkBudget.builder().setItems(5L).setBytes(5L).build()));
-    createBudgetDistributor(10L)
+  public void testDistributeBudget_distributesBudgetEvenlyIfPossible() {
+    int totalStreams = 10;
+    long totalItems = 10L;
+    long totalBytes = 100L;
+    List<GetWorkBudgetSpender> streams = new ArrayList<>();
+    for (int i = 0; i < totalStreams; i++) {
+      streams.add(createGetWorkBudgetOwner());
+    }
+    GetWorkBudgetDistributors.distributeEvenly()
         .distributeBudget(
-            ImmutableList.of(getWorkBudgetSpender),
-            GetWorkBudget.builder().setItems(20L).setBytes(20L).build());
-
-    verify(getWorkBudgetSpender, never()).adjustBudget(anyLong(), anyLong());
-  }
-
-  @Test
-  public void
-      
testDistributeBudget_adjustsStreamBudgetWhenRemainingItemBudgetTooLowWithNoActiveWork()
 {
-    GetWorkBudget streamRemainingBudget =
-        GetWorkBudget.builder().setItems(1L).setBytes(10L).build();
-    GetWorkBudget totalGetWorkBudget = 
GetWorkBudget.builder().setItems(10L).setBytes(10L).build();
-    GetWorkBudgetSpender getWorkBudgetSpender =
-        
spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget));
-    createBudgetDistributor(0L)
-        .distributeBudget(ImmutableList.of(getWorkBudgetSpender), 
totalGetWorkBudget);
-
-    verify(getWorkBudgetSpender, times(1))
-        .adjustBudget(
-            eq(totalGetWorkBudget.items() - streamRemainingBudget.items()),
-            eq(totalGetWorkBudget.bytes() - streamRemainingBudget.bytes()));
-  }
-
-  @Test
-  public void
-      
testDistributeBudget_adjustsStreamBudgetWhenRemainingItemBudgetTooLowWithActiveWork()
 {
-    GetWorkBudget streamRemainingBudget =
-        GetWorkBudget.builder().setItems(1L).setBytes(10L).build();
-    GetWorkBudget totalGetWorkBudget = 
GetWorkBudget.builder().setItems(10L).setBytes(10L).build();
-    long activeWorkItemsAndBytes = 2L;
-    GetWorkBudgetSpender getWorkBudgetSpender =
-        
spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget));
-    createBudgetDistributor(activeWorkItemsAndBytes)
-        .distributeBudget(ImmutableList.of(getWorkBudgetSpender), 
totalGetWorkBudget);
-
-    verify(getWorkBudgetSpender, times(1))
-        .adjustBudget(
-            eq(
-                totalGetWorkBudget.items()
-                    - streamRemainingBudget.items()
-                    - activeWorkItemsAndBytes),
-            eq(totalGetWorkBudget.bytes() - streamRemainingBudget.bytes()));
-  }
-
-  @Test
-  public void 
testDistributeBudget_adjustsStreamBudgetWhenRemainingByteBudgetTooLowNoActiveWork()
 {
-    GetWorkBudget streamRemainingBudget =
-        GetWorkBudget.builder().setItems(10L).setBytes(1L).build();
-    GetWorkBudget totalGetWorkBudget = 
GetWorkBudget.builder().setItems(10L).setBytes(10L).build();
-    GetWorkBudgetSpender getWorkBudgetSpender =
-        
spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget));
-    createBudgetDistributor(0L)
-        .distributeBudget(ImmutableList.of(getWorkBudgetSpender), 
totalGetWorkBudget);
-
-    verify(getWorkBudgetSpender, times(1))
-        .adjustBudget(
-            eq(totalGetWorkBudget.items() - streamRemainingBudget.items()),
-            eq(totalGetWorkBudget.bytes() - streamRemainingBudget.bytes()));
-  }
-
-  @Test
-  public void
-      
testDistributeBudget_adjustsStreamBudgetWhenRemainingByteBudgetTooLowWithActiveWork()
 {
-    GetWorkBudget streamRemainingBudget =
-        GetWorkBudget.builder().setItems(10L).setBytes(1L).build();
-    GetWorkBudget totalGetWorkBudget = 
GetWorkBudget.builder().setItems(10L).setBytes(10L).build();
-    long activeWorkItemsAndBytes = 2L;
-
-    GetWorkBudgetSpender getWorkBudgetSpender =
-        
spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget));
-    createBudgetDistributor(activeWorkItemsAndBytes)
-        .distributeBudget(ImmutableList.of(getWorkBudgetSpender), 
totalGetWorkBudget);
+            ImmutableList.copyOf(streams),
+            
GetWorkBudget.builder().setItems(totalItems).setBytes(totalBytes).build());
 
-    verify(getWorkBudgetSpender, times(1))
-        .adjustBudget(
-            eq(totalGetWorkBudget.items() - streamRemainingBudget.items()),
-            eq(
-                totalGetWorkBudget.bytes()
-                    - streamRemainingBudget.bytes()
-                    - activeWorkItemsAndBytes));
+    streams.forEach(
+        stream ->
+            verify(stream, times(1))
+                
.setBudget(eq(GetWorkBudget.builder().setItems(1L).setBytes(10L).build())));
   }
 
   @Test
-  public void testDistributeBudget_distributesBudgetEvenlyIfPossible() {
-    long totalItemsAndBytes = 10L;
+  public void testDistributeBudget_distributesFairlyWhenNotEven() {
+    long totalItems = 10L;
+    long totalBytes = 19L;
     List<GetWorkBudgetSpender> streams = new ArrayList<>();
-    for (int i = 0; i < totalItemsAndBytes; i++) {
-      
streams.add(spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(GetWorkBudget.noBudget())));
+    for (int i = 0; i < 3; i++) {
+      streams.add(createGetWorkBudgetOwner());
     }
-    createBudgetDistributor(0L)
+    GetWorkBudgetDistributors.distributeEvenly()
         .distributeBudget(
             ImmutableList.copyOf(streams),
-            GetWorkBudget.builder()
-                .setItems(totalItemsAndBytes)
-                .setBytes(totalItemsAndBytes)
-                .build());
+            
GetWorkBudget.builder().setItems(totalItems).setBytes(totalBytes).build());
 
-    long itemsAndBytesPerStream = totalItemsAndBytes / streams.size();
     streams.forEach(
         stream ->
             verify(stream, times(1))
-                .adjustBudget(eq(itemsAndBytesPerStream), 
eq(itemsAndBytesPerStream)));
+                
.setBudget(eq(GetWorkBudget.builder().setItems(4L).setBytes(7L).build())));
   }
 
   @Test
-  public void testDistributeBudget_distributesFairlyWhenNotEven() {
+  public void testDistributeBudget_distributesBudgetEvenly() {
     long totalItemsAndBytes = 10L;
     List<GetWorkBudgetSpender> streams = new ArrayList<>();
-    for (int i = 0; i < 3; i++) {
-      
streams.add(spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(GetWorkBudget.noBudget())));
+    for (int i = 0; i < totalItemsAndBytes; i++) {
+      streams.add(createGetWorkBudgetOwner());
     }
-    createBudgetDistributor(0L)
+
+    GetWorkBudgetDistributors.distributeEvenly()
         .distributeBudget(
             ImmutableList.copyOf(streams),
             GetWorkBudget.builder()
@@ -210,24 +118,10 @@ public class EvenGetWorkBudgetDistributorTest {
                 .setBytes(totalItemsAndBytes)
                 .build());
 
-    long itemsAndBytesPerStream = (long) Math.ceil(totalItemsAndBytes / 
(streams.size() * 1.0));
+    long itemsAndBytesPerStream = totalItemsAndBytes / streams.size();
     streams.forEach(
         stream ->
             verify(stream, times(1))
-                .adjustBudget(eq(itemsAndBytesPerStream), 
eq(itemsAndBytesPerStream)));
-  }
-
-  private GetWorkBudgetSpender createGetWorkBudgetOwnerWithRemainingBudgetOf(
-      GetWorkBudget getWorkBudget) {
-    return spy(
-        new GetWorkBudgetSpender() {
-          @Override
-          public void adjustBudget(long itemsDelta, long bytesDelta) {}
-
-          @Override
-          public GetWorkBudget remainingBudget() {
-            return getWorkBudget;
-          }
-        });
+                .setBudget(eq(itemsAndBytesPerStream), 
eq(itemsAndBytesPerStream)));
   }
 }

Reply via email to