scwhittle commented on code in PR #28835:
URL: https://github.com/apache/beam/pull/28835#discussion_r1399078334


##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java:
##########
@@ -54,7 +55,7 @@
  * RPC streams for health check/heartbeat requests to keep the streams alive.
  */
 @ThreadSafe

Review Comment:
   mark internal



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java:
##########
@@ -0,0 +1,432 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList;
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap;
+
+import com.google.errorprone.annotations.CheckReturnValue;

Review Comment:
   Does this need to come from a vendored lib?



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java:
##########
@@ -211,7 +232,7 @@ public void appendSummaryHtml(PrintWriter writer) {
   }
 
   @AutoBuilder(ofClass = GrpcWindmillStreamFactory.class)
-  interface Builder {
+  public interface Builder {

Review Comment:
   mark Internal
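
   For example, a minimal sketch (assuming the org.apache.beam.sdk.annotations.Internal annotation already used elsewhere in this PR):

   @Internal
   @AutoBuilder(ofClass = GrpcWindmillStreamFactory.class)
   public interface Builder { ... }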



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java:
##########
@@ -0,0 +1,432 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList;
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap;
+
+import com.google.errorprone.annotations.CheckReturnValue;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Random;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Function;
+import java.util.function.Supplier;
+import java.util.stream.Collectors;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.ThreadSafe;
+import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair;
+import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkerMetadataStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillStubFactory;
+import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
+import org.apache.beam.runners.dataflow.worker.windmill.work.ProcessWorkItem;
+import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetDistributor;
+import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetRefresher;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableCollection;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.joda.time.Instant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Client for StreamingEngine. Given a {@link GetWorkBudget}, divides the budget and starts the
+ * {@link WindmillStream.GetWorkStream}(s).
+ */
+@Internal
+@CheckReturnValue
+@ThreadSafe
+public class StreamingEngineClient {
+  private static final Logger LOG = LoggerFactory.getLogger(StreamingEngineClient.class);
+  private static final String CONSUMER_WORKER_METADATA_THREAD = "ConsumeWorkerMetadataThread";
+
+  private final AtomicBoolean started;
+  private final JobHeader jobHeader;
+  private final GetWorkBudget totalGetWorkBudget;
+  private final GrpcWindmillStreamFactory streamFactory;
+  private final ProcessWorkItem processWorkItem;
+  private final WindmillStubFactory stubFactory;
+  private final GetWorkBudgetDistributor getWorkBudgetDistributor;
+  private final GrpcDispatcherClient dispatcherClient;
+  private final AtomicBoolean isBudgetRefreshPaused;
+  /** Writes are guarded by synchronization, reads are lock free. */
+  private final AtomicReference<StreamEngineConnectionState> connections;
+
+  private final GetWorkBudgetRefresher getWorkBudgetRefresher;
+  private final AtomicReference<Instant> lastBudgetRefresh;
+  private final ThrottleTimer getWorkerMetadataThrottleTimer;
+  private final CountDownLatch getWorkerMetadataReady;
+  private final ExecutorService consumeWorkerMetadataExecutor;
+  private final long clientId;
+  /**
+   * Reference to {@link GetWorkerMetadataStream} that is lazily initialized via double check
+   * locking, with its initial value being null.
+   */
+  private volatile @Nullable GetWorkerMetadataStream getWorkerMetadataStream;
+
+  private StreamingEngineClient(
+      JobHeader jobHeader,
+      GetWorkBudget totalGetWorkBudget,
+      AtomicReference<StreamEngineConnectionState> connections,
+      GrpcWindmillStreamFactory streamFactory,
+      ProcessWorkItem processWorkItem,
+      WindmillStubFactory stubFactory,
+      GetWorkBudgetDistributor getWorkBudgetDistributor,
+      GrpcDispatcherClient dispatcherClient,
+      long clientId) {
+    this.jobHeader = jobHeader;
+    this.totalGetWorkBudget = totalGetWorkBudget;
+    this.started = new AtomicBoolean();
+    this.streamFactory = streamFactory;
+    this.processWorkItem = processWorkItem;
+    this.connections = connections;
+    this.stubFactory = stubFactory;
+    this.getWorkBudgetDistributor = getWorkBudgetDistributor;
+    this.dispatcherClient = dispatcherClient;
+    this.isBudgetRefreshPaused = new AtomicBoolean(false);
+    this.getWorkerMetadataThrottleTimer = new ThrottleTimer();
+    this.consumeWorkerMetadataExecutor =
+        Executors.newSingleThreadScheduledExecutor(
+            new ThreadFactoryBuilder()
+                .setNameFormat(StreamingEngineClient.CONSUMER_WORKER_METADATA_THREAD)
+                // JVM will be responsible for shutdown and garbage collect these threads.
+                .setDaemon(true)
+                .setUncaughtExceptionHandler(
+                    (t, e) ->
+                        LOG.error(
+                            "{} failed due to uncaught exception during execution. ",
+                            t.getName(),
+                            e))
+                .build());
+    this.getWorkerMetadataStream = null;
+    this.getWorkerMetadataReady = new CountDownLatch(1);
+    this.clientId = clientId;
+    this.lastBudgetRefresh = new AtomicReference<>(Instant.EPOCH);
+    this.getWorkBudgetRefresher =
+        new GetWorkBudgetRefresher(
+            isBudgetRefreshPaused::get,
+            () -> {
+              getWorkBudgetDistributor.distributeBudget(
+                  connections.get().windmillStreams().values(), totalGetWorkBudget);
+              lastBudgetRefresh.set(Instant.now());
+            });
+  }
+
+  /**
+   * Creates an instance of {@link StreamingEngineClient} and starts the {@link
+   * GetWorkerMetadataStream} with an RPC to the StreamingEngine backend. {@link
+   * GetWorkerMetadataStream} will populate {@link #connections} when a response is received. Calls
+   * to {@link #startAndCacheStreams()} will block until {@link #connections} are populated.
+   */
+  public static StreamingEngineClient create(
+      JobHeader jobHeader,
+      GetWorkBudget totalGetWorkBudget,
+      GrpcWindmillStreamFactory streamingEngineStreamFactory,
+      ProcessWorkItem processWorkItem,
+      WindmillStubFactory windmillGrpcStubFactory,
+      GetWorkBudgetDistributor getWorkBudgetDistributor,
+      GrpcDispatcherClient dispatcherClient) {
+    StreamingEngineClient streamingEngineClient =
+        new StreamingEngineClient(
+            jobHeader,
+            totalGetWorkBudget,
+            new AtomicReference<>(StreamEngineConnectionState.EMPTY),
+            streamingEngineStreamFactory,
+            processWorkItem,
+            windmillGrpcStubFactory,
+            getWorkBudgetDistributor,
+            dispatcherClient,
+            new Random().nextLong());
+    streamingEngineClient.startGetWorkerMetadataStream();
+    return streamingEngineClient;
+  }
+
+  @VisibleForTesting
+  static StreamingEngineClient forTesting(
+      JobHeader jobHeader,
+      GetWorkBudget totalGetWorkBudget,
+      AtomicReference<StreamEngineConnectionState> connections,
+      GrpcWindmillStreamFactory streamFactory,
+      ProcessWorkItem processWorkItem,
+      WindmillStubFactory stubFactory,
+      GetWorkBudgetDistributor getWorkBudgetDistributor,
+      GrpcDispatcherClient dispatcherClient,
+      long clientId) {
+    StreamingEngineClient streamingEngineClient =
+        new StreamingEngineClient(
+            jobHeader,
+            totalGetWorkBudget,
+            connections,
+            streamFactory,
+            processWorkItem,
+            stubFactory,
+            getWorkBudgetDistributor,
+            dispatcherClient,
+            clientId);
+    streamingEngineClient.startGetWorkerMetadataStream();
+    return streamingEngineClient;
+  }
+
+  /**
+   * Starts the streams with the {@link #connections} values. Does nothing if this has already been
+   * called.
+   *
+   * @throws IllegalArgumentException if trying to start before {@link #connections} are set with
+   *     {@link GetWorkerMetadataStream}.
+   */
+  public void startAndCacheStreams() {
+    // Do nothing if we have already initialized the initial streams.
+    if (!started.compareAndSet(false, true)) {
+      return;
+    }
+
+    // This will block if we have not received the first response from GetWorkerMetadata.
+    waitForFirstStreamingEngineEndpoints();
+    StreamEngineConnectionState currentConnectionsState = connections.get();
+    Preconditions.checkState(
+        !StreamEngineConnectionState.EMPTY.equals(currentConnectionsState),
+        "Cannot start streams without connections.");
+
+    LOG.info("Starting initial GetWorkStreams with connections={}", currentConnectionsState);
+    ImmutableCollection<WindmillStreamSender> windmillStreamSenders =
+        currentConnectionsState.windmillStreams().values();
+    getWorkBudgetDistributor.distributeBudget(
+        currentConnectionsState.windmillStreams().values(), totalGetWorkBudget);
+    lastBudgetRefresh.compareAndSet(Instant.EPOCH, Instant.now());
+    windmillStreamSenders.forEach(WindmillStreamSender::startStreams);
+    getWorkBudgetRefresher.start();
+  }
+
+  @VisibleForTesting
+  void finish() {
+    if (started.compareAndSet(true, false)) {
+      return;
+    }
+
+    Optional.ofNullable(getWorkerMetadataStream).ifPresent(GetWorkerMetadataStream::close);
+    getWorkBudgetRefresher.stop();
+    consumeWorkerMetadataExecutor.shutdownNow();
+  }
+
+  private void waitForFirstStreamingEngineEndpoints() {
+    try {
+      getWorkerMetadataReady.await();
+    } catch (InterruptedException e) {
+      throw new StreamingEngineClientException(
+          "Error occurred waiting for StreamingEngine backend endpoints.", e);
+    }
+  }
+
+  /**
+   * {@link java.util.function.Consumer< WindmillEndpoints >} used to update {@link #connections} on
+   * new backend worker metadata.
+   */
+  private synchronized void consumeWindmillWorkerEndpoints(WindmillEndpoints newWindmillEndpoints) {
+    isBudgetRefreshPaused.set(true);
+    LOG.info("Consuming new windmill endpoints: {}", newWindmillEndpoints);
+    ImmutableMap<Endpoint, WindmillConnection> newWindmillConnections =
+        createNewWindmillConnections(ImmutableSet.copyOf(newWindmillEndpoints.windmillEndpoints()));
+    ImmutableMap<WindmillConnection, WindmillStreamSender> newWindmillStreams =
+        closeStaleStreamsAndCreateNewStreams(ImmutableSet.copyOf(newWindmillConnections.values()));
+    ImmutableMap<Endpoint, Supplier<GetDataStream>> newGlobalDataStreams =
+        createNewGlobalDataStreams(
+            ImmutableSet.copyOf(newWindmillEndpoints.globalDataEndpoints().values()));
+
+    StreamEngineConnectionState newConnectionsState =
+        StreamEngineConnectionState.builder()
+            .setWindmillConnections(newWindmillConnections)
+            .setWindmillStreams(newWindmillStreams)
+            .setGlobalDataEndpoints(newWindmillEndpoints.globalDataEndpoints())
+            .setGlobalDataStreams(newGlobalDataStreams)
+            .build();
+
+    LOG.info(
+        "Setting new connections: {}. Previous connections: {}.",
+        newConnectionsState,
+        connections.get());
+    connections.set(newConnectionsState);
+    isBudgetRefreshPaused.set(false);
+
+    // Trigger on first worker metadata. Since startAndCacheStreams will block until first worker
+    // metadata is ready. Afterwards, just trigger a budget refresh.
+    if (getWorkerMetadataReady.getCount() > 0) {
+      getWorkerMetadataReady.countDown();
+    } else {
+      getWorkBudgetRefresher.requestBudgetRefresh();
+    }
+  }
+
+  public final ImmutableList<Long> getAndResetThrottleTimes() {
+    StreamEngineConnectionState currentConnections = connections.get();
+
+    ImmutableList<Long> keyedWorkStreamThrottleTimes =
+        currentConnections.windmillStreams().values().stream()
+            .map(WindmillStreamSender::getAndResetThrottleTime)
+            .collect(toImmutableList());
+
+    return ImmutableList.<Long>builder()
+        .add(getWorkerMetadataThrottleTimer.getAndResetThrottleTime())
+        .addAll(keyedWorkStreamThrottleTimes)
+        .build();
+  }
+
+  /** Starts {@link GetWorkerMetadataStream}. */
+  @SuppressWarnings({
+    "FutureReturnValueIgnored", // ignoring Future returned from Executor.submit()
+    "nullness" // Uninitialized value of getWorkerMetadataStream is null.
+  })
+  private void startGetWorkerMetadataStream() {
+    // We only want to set and start this value once.
+    if (getWorkerMetadataStream == null) {
+      synchronized (this) {
+        if (getWorkerMetadataStream == null) {
+          getWorkerMetadataStream =
+              streamFactory.createGetWorkerMetadataStream(
+                  dispatcherClient.getDispatcherStub(),
+                  getWorkerMetadataThrottleTimer,
+                  endpoints ->
+                      // Run this on a separate thread than the grpc stream thread.

Review Comment:
   If consuming is slower than updates are received, this will do unnecessary processing.
   
   Maybe you could use some queue for the executor that just keeps the last added element, so intermediate updates are ignored?
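
   A rough sketch of that idea (anything not already in the PR, such as the holder field and callback name, is illustrative): the metadata callback publishes into a single-slot holder and only schedules a drain when none is pending, so the executor always consumes the most recent endpoints and skips intermediate updates.

   // Hypothetical coalescing handoff: keep only the latest WindmillEndpoints.
   private final AtomicReference<WindmillEndpoints> latestEndpoints = new AtomicReference<>();

   private void onNewWorkerMetadata(WindmillEndpoints endpoints) {
     // Overwrite any unconsumed value; schedule a drain only if none is already pending.
     if (latestEndpoints.getAndSet(endpoints) == null) {
       consumeWorkerMetadataExecutor.submit(
           () -> {
             WindmillEndpoints newest = latestEndpoints.getAndSet(null);
             if (newest != null) {
               consumeWindmillWorkerEndpoints(newest);
             }
           });
     }
   }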



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamSender.java:
##########
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Supplier;
+import javax.annotation.concurrent.ThreadSafe;
+import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.StreamingEngineThrottleTimers;
+import org.apache.beam.runners.dataflow.worker.windmill.work.ProcessWorkItem;
+import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers;
+
+/**
+ * Owns and maintains a set of streams used to communicate with a specific Windmill worker.
+ * Underlying streams are "cached" in a threadsafe manner so that once {@link Supplier#get} is
+ * called, a stream that is already started is returned.
+ *
+ * <p>Holds references to {@link
+ * Supplier<org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream>} because
+ * initializing the streams automatically start them, and we want to do so lazily here once the
+ * {@link GetWorkBudget} is set.
+ *
+ * <p>Once started, the underlying streams are "alive" until they are manually closed via {@link
+ * #closeAllStreams()}.
+ *
+ * <p>If closed, it means that the backend endpoint is no longer in the worker set. Once closed,
+ * these instances are not reused.
+ *
+ * @implNote Does not manage streams for fetching {@link
+ *     org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalData} for side inputs.
+ */
+@Internal
+@ThreadSafe
+public class WindmillStreamSender {
+  private final AtomicBoolean started;
+  private final AtomicReference<GetWorkBudget> getWorkBudget;
+  private final Supplier<GetWorkStream> getWorkStream;
+  private final Supplier<GetDataStream> getDataStream;
+  private final Supplier<CommitWorkStream> commitWorkStream;
+  private final StreamingEngineThrottleTimers streamingEngineThrottleTimers;
+
+  private WindmillStreamSender(
+      CloudWindmillServiceV1Alpha1Stub stub,
+      GetWorkRequest getWorkRequest,
+      AtomicReference<GetWorkBudget> getWorkBudget,
+      GrpcWindmillStreamFactory streamingEngineStreamFactory,
+      ProcessWorkItem processWorkItem) {
+    this.started = new AtomicBoolean(false);
+    this.getWorkBudget = getWorkBudget;
+    this.streamingEngineThrottleTimers = StreamingEngineThrottleTimers.create();
+
+    // All streams are memoized/cached since they are expensive to create and some implementations
+    // perform side effects on construction (i.e. sending initial requests to the stream server to
+    // initiate the streaming RPC connection). Stream instances connect/reconnect internally so we
+    // can reuse the same instance through the entire lifecycle of WindmillStreamSender.
+    this.getDataStream =
+        Suppliers.memoize(
+            () ->
+                streamingEngineStreamFactory.createGetDataStream(
+                    stub, streamingEngineThrottleTimers.getDataThrottleTimer()));
+    this.commitWorkStream =
+        Suppliers.memoize(
+            () ->
+                streamingEngineStreamFactory.createCommitWorkStream(
+                    stub, streamingEngineThrottleTimers.commitWorkThrottleTimer()));
+    this.getWorkStream =
+        Suppliers.memoize(
+            () ->
+                streamingEngineStreamFactory.createDirectGetWorkStream(
+                    stub,
+                    withRequestBudget(getWorkRequest, getWorkBudget.get()),
+                    streamingEngineThrottleTimers.getWorkThrottleTimer(),
+                    getDataStream,
+                    commitWorkStream,
+                    processWorkItem));
+  }
+
+  public static WindmillStreamSender create(
+      CloudWindmillServiceV1Alpha1Stub stub,
+      GetWorkRequest getWorkRequest,
+      GetWorkBudget getWorkBudget,
+      GrpcWindmillStreamFactory streamingEngineStreamFactory,
+      ProcessWorkItem processWorkItem) {
+    return new WindmillStreamSender(
+        stub,
+        getWorkRequest,
+        new AtomicReference<>(getWorkBudget),
+        streamingEngineStreamFactory,
+        processWorkItem);
+  }
+
+  private static GetWorkRequest withRequestBudget(GetWorkRequest request, GetWorkBudget budget) {
+    return request.toBuilder().setMaxItems(budget.items()).setMaxBytes(budget.bytes()).build();
+  }
+
+  @SuppressWarnings("ReturnValueIgnored")
+  public void startStreams() {
+    Preconditions.checkState(
+        !getWorkBudget.get().equals(GetWorkBudget.noBudget()), "Cannot GetWork with no budget.");
+    getWorkStream.get();
+    getDataStream.get();
+    commitWorkStream.get();
+    // *stream.get() is all memoized in a threadsafe manner.
+    started.compareAndSet(false, true);

Review Comment:
   no need to compare, just unconditionally set true
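
   i.e., a one-line sketch of that change:

   // The suppliers above are memoized, so a plain set is enough here.
   started.set(true);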



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java:
##########
@@ -0,0 +1,432 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList;
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap;
+
+import com.google.errorprone.annotations.CheckReturnValue;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Random;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Function;
+import java.util.function.Supplier;
+import java.util.stream.Collectors;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.ThreadSafe;
+import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair;
+import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkerMetadataStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillStubFactory;
+import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
+import org.apache.beam.runners.dataflow.worker.windmill.work.ProcessWorkItem;
+import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetDistributor;
+import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetRefresher;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableCollection;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.joda.time.Instant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Client for StreamingEngine. Given a {@link GetWorkBudget}, divides the budget and starts the
+ * {@link WindmillStream.GetWorkStream}(s).
+ */
+@Internal
+@CheckReturnValue
+@ThreadSafe
+public class StreamingEngineClient {
+  private static final Logger LOG = LoggerFactory.getLogger(StreamingEngineClient.class);
+  private static final String CONSUMER_WORKER_METADATA_THREAD = "ConsumeWorkerMetadataThread";
+
+  private final AtomicBoolean started;
+  private final JobHeader jobHeader;
+  private final GetWorkBudget totalGetWorkBudget;
+  private final GrpcWindmillStreamFactory streamFactory;
+  private final ProcessWorkItem processWorkItem;
+  private final WindmillStubFactory stubFactory;
+  private final GetWorkBudgetDistributor getWorkBudgetDistributor;
+  private final GrpcDispatcherClient dispatcherClient;
+  private final AtomicBoolean isBudgetRefreshPaused;
+  /** Writes are guarded by synchronization, reads are lock free. */
+  private final AtomicReference<StreamEngineConnectionState> connections;
+
+  private final GetWorkBudgetRefresher getWorkBudgetRefresher;
+  private final AtomicReference<Instant> lastBudgetRefresh;
+  private final ThrottleTimer getWorkerMetadataThrottleTimer;
+  private final CountDownLatch getWorkerMetadataReady;
+  private final ExecutorService consumeWorkerMetadataExecutor;
+  private final long clientId;
+  /**
+   * Reference to {@link GetWorkerMetadataStream} that is lazily initialized via double check
+   * locking, with its initial value being null.
+   */
+  private volatile @Nullable GetWorkerMetadataStream getWorkerMetadataStream;
+
+  private StreamingEngineClient(
+      JobHeader jobHeader,
+      GetWorkBudget totalGetWorkBudget,
+      AtomicReference<StreamEngineConnectionState> connections,
+      GrpcWindmillStreamFactory streamFactory,
+      ProcessWorkItem processWorkItem,
+      WindmillStubFactory stubFactory,
+      GetWorkBudgetDistributor getWorkBudgetDistributor,
+      GrpcDispatcherClient dispatcherClient,
+      long clientId) {
+    this.jobHeader = jobHeader;
+    this.totalGetWorkBudget = totalGetWorkBudget;
+    this.started = new AtomicBoolean();
+    this.streamFactory = streamFactory;
+    this.processWorkItem = processWorkItem;
+    this.connections = connections;
+    this.stubFactory = stubFactory;
+    this.getWorkBudgetDistributor = getWorkBudgetDistributor;
+    this.dispatcherClient = dispatcherClient;
+    this.isBudgetRefreshPaused = new AtomicBoolean(false);
+    this.getWorkerMetadataThrottleTimer = new ThrottleTimer();
+    this.consumeWorkerMetadataExecutor =
+        Executors.newSingleThreadScheduledExecutor(
+            new ThreadFactoryBuilder()
+                .setNameFormat(StreamingEngineClient.CONSUMER_WORKER_METADATA_THREAD)
+                // JVM will be responsible for shutdown and garbage collect these threads.
+                .setDaemon(true)
+                .setUncaughtExceptionHandler(
+                    (t, e) ->
+                        LOG.error(
+                            "{} failed due to uncaught exception during execution. ",
+                            t.getName(),
+                            e))
+                .build());
+    this.getWorkerMetadataStream = null;
+    this.getWorkerMetadataReady = new CountDownLatch(1);
+    this.clientId = clientId;
+    this.lastBudgetRefresh = new AtomicReference<>(Instant.EPOCH);
+    this.getWorkBudgetRefresher =
+        new GetWorkBudgetRefresher(
+            isBudgetRefreshPaused::get,
+            () -> {
+              getWorkBudgetDistributor.distributeBudget(
+                  connections.get().windmillStreams().values(), totalGetWorkBudget);
+              lastBudgetRefresh.set(Instant.now());
+            });
+  }
+
+  /**
+   * Creates an instance of {@link StreamingEngineClient} and starts the {@link
+   * GetWorkerMetadataStream} with an RPC to the StreamingEngine backend. {@link
+   * GetWorkerMetadataStream} will populate {@link #connections} when a response is received. Calls
+   * to {@link #startAndCacheStreams()} will block until {@link #connections} are populated.
+   */
+  public static StreamingEngineClient create(
+      JobHeader jobHeader,
+      GetWorkBudget totalGetWorkBudget,
+      GrpcWindmillStreamFactory streamingEngineStreamFactory,
+      ProcessWorkItem processWorkItem,
+      WindmillStubFactory windmillGrpcStubFactory,
+      GetWorkBudgetDistributor getWorkBudgetDistributor,
+      GrpcDispatcherClient dispatcherClient) {
+    StreamingEngineClient streamingEngineClient =
+        new StreamingEngineClient(
+            jobHeader,
+            totalGetWorkBudget,
+            new AtomicReference<>(StreamEngineConnectionState.EMPTY),
+            streamingEngineStreamFactory,
+            processWorkItem,
+            windmillGrpcStubFactory,
+            getWorkBudgetDistributor,
+            dispatcherClient,
+            new Random().nextLong());
+    streamingEngineClient.startGetWorkerMetadataStream();
+    return streamingEngineClient;
+  }
+
+  @VisibleForTesting
+  static StreamingEngineClient forTesting(
+      JobHeader jobHeader,
+      GetWorkBudget totalGetWorkBudget,
+      AtomicReference<StreamEngineConnectionState> connections,
+      GrpcWindmillStreamFactory streamFactory,
+      ProcessWorkItem processWorkItem,
+      WindmillStubFactory stubFactory,
+      GetWorkBudgetDistributor getWorkBudgetDistributor,
+      GrpcDispatcherClient dispatcherClient,
+      long clientId) {
+    StreamingEngineClient streamingEngineClient =
+        new StreamingEngineClient(
+            jobHeader,
+            totalGetWorkBudget,
+            connections,
+            streamFactory,
+            processWorkItem,
+            stubFactory,
+            getWorkBudgetDistributor,
+            dispatcherClient,
+            clientId);
+    streamingEngineClient.startGetWorkerMetadataStream();
+    return streamingEngineClient;
+  }
+
+  /**
+   * Starts the streams with the {@link #connections} values. Does nothing if this has already been
+   * called.
+   *
+   * @throws IllegalArgumentException if trying to start before {@link #connections} are set with
+   *     {@link GetWorkerMetadataStream}.
+   */
+  public void startAndCacheStreams() {
+    // Do nothing if we have already initialized the initial streams.
+    if (!started.compareAndSet(false, true)) {
+      return;
+    }
+
+    // This will block if we have not received the first response from GetWorkerMetadata.
+    waitForFirstStreamingEngineEndpoints();
+    StreamEngineConnectionState currentConnectionsState = connections.get();
+    Preconditions.checkState(
+        !StreamEngineConnectionState.EMPTY.equals(currentConnectionsState),
+        "Cannot start streams without connections.");
+
+    LOG.info("Starting initial GetWorkStreams with connections={}", currentConnectionsState);
+    ImmutableCollection<WindmillStreamSender> windmillStreamSenders =
+        currentConnectionsState.windmillStreams().values();
+    getWorkBudgetDistributor.distributeBudget(
+        currentConnectionsState.windmillStreams().values(), totalGetWorkBudget);
+    lastBudgetRefresh.compareAndSet(Instant.EPOCH, Instant.now());
+    windmillStreamSenders.forEach(WindmillStreamSender::startStreams);
+    getWorkBudgetRefresher.start();
+  }
+
+  @VisibleForTesting
+  void finish() {
+    if (started.compareAndSet(true, false)) {
+      return;
+    }
+
+    Optional.ofNullable(getWorkerMetadataStream).ifPresent(GetWorkerMetadataStream::close);
+    getWorkBudgetRefresher.stop();
+    consumeWorkerMetadataExecutor.shutdownNow();
+  }
+
+  private void waitForFirstStreamingEngineEndpoints() {
+    try {
+      getWorkerMetadataReady.await();
+    } catch (InterruptedException e) {
+      throw new StreamingEngineClientException(
+          "Error occurred waiting for StreamingEngine backend endpoints.", e);
+    }
+  }
+
+  /**
+   * {@link java.util.function.Consumer< WindmillEndpoints >} used to update {@link #connections} on
+   * new backend worker metadata.
+   */
+  private synchronized void consumeWindmillWorkerEndpoints(WindmillEndpoints newWindmillEndpoints) {
+    isBudgetRefreshPaused.set(true);
+    LOG.info("Consuming new windmill endpoints: {}", newWindmillEndpoints);
+    ImmutableMap<Endpoint, WindmillConnection> newWindmillConnections =
+        createNewWindmillConnections(ImmutableSet.copyOf(newWindmillEndpoints.windmillEndpoints()));
+    ImmutableMap<WindmillConnection, WindmillStreamSender> newWindmillStreams =
+        closeStaleStreamsAndCreateNewStreams(ImmutableSet.copyOf(newWindmillConnections.values()));
+    ImmutableMap<Endpoint, Supplier<GetDataStream>> newGlobalDataStreams =
+        createNewGlobalDataStreams(
+            ImmutableSet.copyOf(newWindmillEndpoints.globalDataEndpoints().values()));
+
+    StreamEngineConnectionState newConnectionsState =
+        StreamEngineConnectionState.builder()
+            .setWindmillConnections(newWindmillConnections)
+            .setWindmillStreams(newWindmillStreams)
+            .setGlobalDataEndpoints(newWindmillEndpoints.globalDataEndpoints())
+            .setGlobalDataStreams(newGlobalDataStreams)
+            .build();
+
+    LOG.info(
+        "Setting new connections: {}. Previous connections: {}.",
+        newConnectionsState,
+        connections.get());
+    connections.set(newConnectionsState);
+    isBudgetRefreshPaused.set(false);
+
+    // Trigger on first worker metadata. Since startAndCacheStreams will block until first worker
+    // metadata is ready. Afterwards, just trigger a budget refresh.
+    if (getWorkerMetadataReady.getCount() > 0) {
+      getWorkerMetadataReady.countDown();
+    } else {
+      getWorkBudgetRefresher.requestBudgetRefresh();
+    }
+  }
+
+  public final ImmutableList<Long> getAndResetThrottleTimes() {
+    StreamEngineConnectionState currentConnections = connections.get();
+
+    ImmutableList<Long> keyedWorkStreamThrottleTimes =
+        currentConnections.windmillStreams().values().stream()
+            .map(WindmillStreamSender::getAndResetThrottleTime)
+            .collect(toImmutableList());
+
+    return ImmutableList.<Long>builder()
+        .add(getWorkerMetadataThrottleTimer.getAndResetThrottleTime())
+        .addAll(keyedWorkStreamThrottleTimes)
+        .build();
+  }
+
+  /** Starts {@link GetWorkerMetadataStream}. */
+  @SuppressWarnings({
+    "FutureReturnValueIgnored", // ignoring Future returned from Executor.submit()
+    "nullness" // Uninitialized value of getWorkerMetadataStream is null.
+  })
+  private void startGetWorkerMetadataStream() {
+    // We only want to set and start this value once.
+    if (getWorkerMetadataStream == null) {
+      synchronized (this) {
+        if (getWorkerMetadataStream == null) {
+          getWorkerMetadataStream =
+              streamFactory.createGetWorkerMetadataStream(
+                  dispatcherClient.getDispatcherStub(),
+                  getWorkerMetadataThrottleTimer,
+                  endpoints ->
+                      // Run this on a separate thread than the grpc stream thread.
+                      consumeWorkerMetadataExecutor.submit(
+                          () -> consumeWindmillWorkerEndpoints(endpoints)));
+        }
+      }
+    }
+  }
+
+  private synchronized ImmutableMap<Endpoint, WindmillConnection> createNewWindmillConnections(
+      ImmutableSet<Endpoint> newWindmillEndpoints) {
+    ImmutableMap<Endpoint, WindmillConnection> currentWindmillConnections =
+        connections.get().windmillConnections();
+    Map<Endpoint, WindmillConnection> newWindmillConnections =
+        currentWindmillConnections.entrySet().stream()
+            .filter(entry -> newWindmillEndpoints.contains(entry.getKey()))
+            .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+
+    // Reuse existing stubs if they exist.
+    newWindmillEndpoints.forEach(
+        newEndpoint ->
+            newWindmillConnections.putIfAbsent(
+                newEndpoint, WindmillConnection.from(newEndpoint, this::createWindmillStub)));
+
+    return ImmutableMap.copyOf(newWindmillConnections);
+  }
+
+  private synchronized ImmutableMap<WindmillConnection, WindmillStreamSender>
+      closeStaleStreamsAndCreateNewStreams(
+          ImmutableSet<WindmillConnection> newWindmillConnections) {
+    ImmutableMap<WindmillConnection, WindmillStreamSender> currentStreams =
+        connections.get().windmillStreams();
+
+    // Close the streams that are no longer valid.
+    currentStreams.entrySet().stream()
+        .filter(
+            connectionAndStream -> !newWindmillConnections.contains(connectionAndStream.getKey()))
+        .map(Map.Entry::getValue)
+        .forEach(WindmillStreamSender::closeAllStreams);
+
+    return newWindmillConnections.stream()
+        .collect(
+            toImmutableMap(
+                Function.identity(),
+                newConnection ->
+                    Optional.ofNullable(currentStreams.get(newConnection))
+                        .orElseGet(() -> createWindmillStreamSenderWithNoBudget(newConnection))));
+  }
+
+  private ImmutableMap<Endpoint, Supplier<GetDataStream>> createNewGlobalDataStreams(
+      ImmutableSet<Endpoint> newGlobalDataEndpoints) {
+    ImmutableMap<Endpoint, Supplier<GetDataStream>> currentGlobalDataStreams =
+        connections.get().globalDataStreams();
+
+    return newGlobalDataEndpoints.stream()
+        .map(endpoint -> existingOrNewGetDataStreamFor(endpoint, currentGlobalDataStreams))
+        .collect(toImmutableMap(Pair::getKey, Pair::getValue));

Review Comment:
   Instead of pairs, how about using Function.identity() like above (with existingOrNewGetDataStreamFor returning just the stream supplier)?
   
   newGlobalDataEndpoints.stream().collect(
     toImmutableMap(
       Function.identity(),
       endpoint -> existingOrNewGetDataStreamFor(endpoint, currentGlobalDataStreams)));



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamSender.java:
##########
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Supplier;
+import javax.annotation.concurrent.ThreadSafe;
+import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.StreamingEngineThrottleTimers;
+import org.apache.beam.runners.dataflow.worker.windmill.work.ProcessWorkItem;
+import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers;
+
+/**
+ * Owns and maintains a set of streams used to communicate with a specific Windmill worker.
+ * Underlying streams are "cached" in a threadsafe manner so that once {@link Supplier#get} is
+ * called, a stream that is already started is returned.
+ *
+ * <p>Holds references to {@link
+ * Supplier<org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream>} because
+ * initializing the streams automatically start them, and we want to do so lazily here once the
+ * {@link GetWorkBudget} is set.
+ *
+ * <p>Once started, the underlying streams are "alive" until they are manually closed via {@link
+ * #closeAllStreams()}.
+ *
+ * <p>If closed, it means that the backend endpoint is no longer in the worker set. Once closed,
+ * these instances are not reused.
+ *
+ * @implNote Does not manage streams for fetching {@link
+ *     org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalData} for side inputs.
+ */
+@Internal
+@ThreadSafe
+public class WindmillStreamSender {
+  private final AtomicBoolean started;
+  private final AtomicReference<GetWorkBudget> getWorkBudget;
+  private final Supplier<GetWorkStream> getWorkStream;
+  private final Supplier<GetDataStream> getDataStream;
+  private final Supplier<CommitWorkStream> commitWorkStream;
+  private final StreamingEngineThrottleTimers streamingEngineThrottleTimers;
+
+  private WindmillStreamSender(
+      CloudWindmillServiceV1Alpha1Stub stub,
+      GetWorkRequest getWorkRequest,
+      AtomicReference<GetWorkBudget> getWorkBudget,
+      GrpcWindmillStreamFactory streamingEngineStreamFactory,
+      ProcessWorkItem processWorkItem) {
+    this.started = new AtomicBoolean(false);
+    this.getWorkBudget = getWorkBudget;
+    this.streamingEngineThrottleTimers = StreamingEngineThrottleTimers.create();
+
+    // All streams are memoized/cached since they are expensive to create and some implementations
+    // perform side effects on construction (i.e. sending initial requests to the stream server to
+    // initiate the streaming RPC connection). Stream instances connect/reconnect internally so we
+    // can reuse the same instance through the entire lifecycle of WindmillStreamSender.
+    this.getDataStream =
+        Suppliers.memoize(
+            () ->
+                streamingEngineStreamFactory.createGetDataStream(
+                    stub, streamingEngineThrottleTimers.getDataThrottleTimer()));
+    this.commitWorkStream =
+        Suppliers.memoize(
+            () ->
+                streamingEngineStreamFactory.createCommitWorkStream(
+                    stub, streamingEngineThrottleTimers.commitWorkThrottleTimer()));
+    this.getWorkStream =
+        Suppliers.memoize(
+            () ->
+                streamingEngineStreamFactory.createDirectGetWorkStream(
+                    stub,
+                    withRequestBudget(getWorkRequest, getWorkBudget.get()),
+                    streamingEngineThrottleTimers.getWorkThrottleTimer(),
+                    getDataStream,
+                    commitWorkStream,
+                    processWorkItem));
+  }
+
+  public static WindmillStreamSender create(
+      CloudWindmillServiceV1Alpha1Stub stub,
+      GetWorkRequest getWorkRequest,
+      GetWorkBudget getWorkBudget,
+      GrpcWindmillStreamFactory streamingEngineStreamFactory,
+      ProcessWorkItem processWorkItem) {
+    return new WindmillStreamSender(
+        stub,
+        getWorkRequest,
+        new AtomicReference<>(getWorkBudget),
+        streamingEngineStreamFactory,
+        processWorkItem);
+  }
+
+  private static GetWorkRequest withRequestBudget(GetWorkRequest request, GetWorkBudget budget) {
+    return request.toBuilder().setMaxItems(budget.items()).setMaxBytes(budget.bytes()).build();
+  }
+
+  @SuppressWarnings("ReturnValueIgnored")
+  public void startStreams() {
+    Preconditions.checkState(
+        !getWorkBudget.get().equals(GetWorkBudget.noBudget()), "Cannot GetWork with no budget.");
+    getWorkStream.get();
+    getDataStream.get();
+    commitWorkStream.get();
+    // *stream.get() is all memoized in a threadsafe manner.
+    started.compareAndSet(false, true);
+  }
+
+  public void closeAllStreams() {
+    // Supplier<Stream>.get() starts the stream which is an expensive operation as it initiates the
+    // streaming RPCs by possibly making calls over the network. Do not close the streams unless
+    // they have already been started.
+    if (started.get()) {
+      getWorkStream.get().close();
+      getDataStream.get().close();
+      commitWorkStream.get().close();
+    }
+  }
+
+  public synchronized void adjustBudget(long itemsDelta, long bytesDelta) {
+    getWorkBudget.set(getWorkBudget.get().apply(itemsDelta, bytesDelta));
+    getWorkStream.get().adjustBudget(itemsDelta, bytesDelta);
+  }
+
+  public synchronized void adjustBudget(GetWorkBudget adjustment) {
+    adjustBudget(adjustment.items(), adjustment.bytes());
+  }
+
+  public synchronized GetWorkBudget remainingGetWorkBudget() {

Review Comment:
   don't think this needs synchronized



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java:
##########
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.work.budget;
+
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap;
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.LongMath.divide;
+
+import java.math.RoundingMode;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.function.Function;
+import java.util.function.Supplier;
+import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.WindmillStreamSender;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableCollection;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** Evenly distributes the provided budget across the available {@link WindmillStreamSender}(s). */
+@Internal
+public final class EvenGetWorkBudgetDistributor implements GetWorkBudgetDistributor {
+  private static final Logger LOG = LoggerFactory.getLogger(EvenGetWorkBudgetDistributor.class);
+  private final Supplier<GetWorkBudget> activeWorkBudgetSupplier;
+
+  public EvenGetWorkBudgetDistributor(Supplier<GetWorkBudget> activeWorkBudgetSupplier) {
+    this.activeWorkBudgetSupplier = activeWorkBudgetSupplier;
+  }
+
+  private static boolean isBelowFiftyPercentOfTarget(
+      GetWorkBudget remaining, GetWorkBudget target) {
+    return remaining.items() < (target.items() * 0.5) || remaining.bytes() < (target.bytes() * 0.5);
+  }
+
+  @Override
+  public void distributeBudget(
+      ImmutableCollection<WindmillStreamSender> streams, GetWorkBudget getWorkBudget) {
+    if (streams.isEmpty()) {
+      LOG.debug("Cannot distribute budget to no streams.");
+      return;
+    }
+
+    if (getWorkBudget.equals(GetWorkBudget.noBudget())) {
+      LOG.debug("Cannot distribute 0 budget.");
+      return;
+    }
+
+    Map<WindmillStreamSender, GetWorkBudget> desiredBudgets =
+        computeDesiredBudgets(streams, getWorkBudget);
+
+    for (Entry<WindmillStreamSender, GetWorkBudget> streamAndDesiredBudget :
+        desiredBudgets.entrySet()) {
+      WindmillStreamSender stream = streamAndDesiredBudget.getKey();
+      GetWorkBudget desired = streamAndDesiredBudget.getValue();
+      GetWorkBudget remaining = stream.remainingGetWorkBudget();
+      if (isBelowFiftyPercentOfTarget(remaining, desired)) {
+        GetWorkBudget adjustment = desired.subtract(remaining);
+        LOG.info("Adjusting budget for stream={} by {}", stream, adjustment);

Review Comment:
   this will be too spammy
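
   One lighter-weight option, as a sketch (same call site, just a lower log level):

   // Demote per-refresh adjustments to debug so periodic budget refreshes don't flood the logs.
   LOG.debug("Adjusting budget for stream={} by {}", stream, adjustment);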



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java:
##########
@@ -0,0 +1,432 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList;
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap;
+
+import com.google.errorprone.annotations.CheckReturnValue;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Random;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Function;
+import java.util.function.Supplier;
+import java.util.stream.Collectors;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.ThreadSafe;
+import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair;
+import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkerMetadataStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillStubFactory;
+import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
+import org.apache.beam.runners.dataflow.worker.windmill.work.ProcessWorkItem;
+import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetDistributor;
+import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetRefresher;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableCollection;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.joda.time.Instant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Client for StreamingEngine. Given a {@link GetWorkBudget}, divides the 
budget and starts the
+ * {@link WindmillStream.GetWorkStream}(s).
+ */
+@Internal
+@CheckReturnValue
+@ThreadSafe
+public class StreamingEngineClient {
+  private static final Logger LOG = 
LoggerFactory.getLogger(StreamingEngineClient.class);
+  private static final String CONSUMER_WORKER_METADATA_THREAD = 
"ConsumeWorkerMetadataThread";
+
+  private final AtomicBoolean started;
+  private final JobHeader jobHeader;
+  private final GetWorkBudget totalGetWorkBudget;
+  private final GrpcWindmillStreamFactory streamFactory;
+  private final ProcessWorkItem processWorkItem;
+  private final WindmillStubFactory stubFactory;
+  private final GetWorkBudgetDistributor getWorkBudgetDistributor;
+  private final GrpcDispatcherClient dispatcherClient;
+  private final AtomicBoolean isBudgetRefreshPaused;
+  /** Writes are guarded by synchronization, reads are lock free. */
+  private final AtomicReference<StreamEngineConnectionState> connections;
+
+  private final GetWorkBudgetRefresher getWorkBudgetRefresher;
+  private final AtomicReference<Instant> lastBudgetRefresh;
+  private final ThrottleTimer getWorkerMetadataThrottleTimer;
+  private final CountDownLatch getWorkerMetadataReady;
+  private final ExecutorService consumeWorkerMetadataExecutor;
+  private final long clientId;
+  /**
+   * Reference to {@link GetWorkerMetadataStream} that is lazily initialized 
via double check
+   * locking, with its initial value being null.
+   */
+  private volatile @Nullable GetWorkerMetadataStream getWorkerMetadataStream;

Review Comment:
   Could you make this field final and remove the locking by having the create
   method construct the stream and then pass it into the constructor? Actually, it
   looks like you would then have to change how the stream is started, since you'd
   want to start it only after this client is constructed.
   
   Another idea: what about using Suppliers.memoize() instead of doing the
   double-checked locking yourself? It's more obviously correct.
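   
   For illustration only, a rough sketch of the memoize idea using the fields
   already in this class (the exact wiring, plus the vendored Suppliers import,
   is an assumption rather than the PR's code):
   
   private final Supplier<GetWorkerMetadataStream> getWorkerMetadataStream =
       Suppliers.memoize(
           () ->
               streamFactory.createGetWorkerMetadataStream(
                   dispatcherClient.getDispatcherStub(),
                   getWorkerMetadataThrottleTimer,
                   endpoints ->
                       // Keep consuming metadata off the gRPC stream thread.
                       consumeWorkerMetadataExecutor.submit(
                           () -> consumeWindmillWorkerEndpoints(endpoints))));
   
   private void startGetWorkerMetadataStream() {
     // The first get() creates and starts the stream; later calls return the
     // same instance, so no double-checked locking is needed.
     getWorkerMetadataStream.get();
   }
   
   That would also mirror how WindmillStreamSender memoizes its streams.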
   
   



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java:
##########
@@ -0,0 +1,432 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList;
+import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap;
+
+import com.google.errorprone.annotations.CheckReturnValue;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Random;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Function;
+import java.util.function.Supplier;
+import java.util.stream.Collectors;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.ThreadSafe;
+import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair;
+import 
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub;
+import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints;
+import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkerMetadataStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillStubFactory;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
+import org.apache.beam.runners.dataflow.worker.windmill.work.ProcessWorkItem;
+import 
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import 
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetDistributor;
+import 
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetRefresher;
+import org.apache.beam.sdk.annotations.Internal;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableCollection;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.joda.time.Instant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Client for StreamingEngine. Given a {@link GetWorkBudget}, divides the 
budget and starts the
+ * {@link WindmillStream.GetWorkStream}(s).
+ */
+@Internal
+@CheckReturnValue
+@ThreadSafe
+public class StreamingEngineClient {
+  private static final Logger LOG = 
LoggerFactory.getLogger(StreamingEngineClient.class);
+  private static final String CONSUMER_WORKER_METADATA_THREAD = 
"ConsumeWorkerMetadataThread";
+
+  private final AtomicBoolean started;
+  private final JobHeader jobHeader;
+  private final GetWorkBudget totalGetWorkBudget;
+  private final GrpcWindmillStreamFactory streamFactory;
+  private final ProcessWorkItem processWorkItem;
+  private final WindmillStubFactory stubFactory;
+  private final GetWorkBudgetDistributor getWorkBudgetDistributor;
+  private final GrpcDispatcherClient dispatcherClient;
+  private final AtomicBoolean isBudgetRefreshPaused;
+  /** Writes are guarded by synchronization, reads are lock free. */
+  private final AtomicReference<StreamEngineConnectionState> connections;
+
+  private final GetWorkBudgetRefresher getWorkBudgetRefresher;
+  private final AtomicReference<Instant> lastBudgetRefresh;
+  private final ThrottleTimer getWorkerMetadataThrottleTimer;
+  private final CountDownLatch getWorkerMetadataReady;
+  private final ExecutorService consumeWorkerMetadataExecutor;
+  private final long clientId;
+  /**
+   * Reference to {@link GetWorkerMetadataStream} that is lazily initialized 
via double check
+   * locking, with its initial value being null.
+   */
+  private volatile @Nullable GetWorkerMetadataStream getWorkerMetadataStream;
+
+  private StreamingEngineClient(
+      JobHeader jobHeader,
+      GetWorkBudget totalGetWorkBudget,
+      AtomicReference<StreamEngineConnectionState> connections,
+      GrpcWindmillStreamFactory streamFactory,
+      ProcessWorkItem processWorkItem,
+      WindmillStubFactory stubFactory,
+      GetWorkBudgetDistributor getWorkBudgetDistributor,
+      GrpcDispatcherClient dispatcherClient,
+      long clientId) {
+    this.jobHeader = jobHeader;
+    this.totalGetWorkBudget = totalGetWorkBudget;
+    this.started = new AtomicBoolean();
+    this.streamFactory = streamFactory;
+    this.processWorkItem = processWorkItem;
+    this.connections = connections;
+    this.stubFactory = stubFactory;
+    this.getWorkBudgetDistributor = getWorkBudgetDistributor;
+    this.dispatcherClient = dispatcherClient;
+    this.isBudgetRefreshPaused = new AtomicBoolean(false);
+    this.getWorkerMetadataThrottleTimer = new ThrottleTimer();
+    this.consumeWorkerMetadataExecutor =
+        Executors.newSingleThreadScheduledExecutor(
+            new ThreadFactoryBuilder()
+                
.setNameFormat(StreamingEngineClient.CONSUMER_WORKER_METADATA_THREAD)
+                // The JVM will be responsible for shutting down and garbage
+                // collecting these threads.
+                .setDaemon(true)
+                .setUncaughtExceptionHandler(
+                    (t, e) ->
+                        LOG.error(

Review Comment:
   Throw it after logging? It's better to crash and restart than to potentially get
   stuck forever.
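   
   For example (a sketch only; halting here is just one way to turn a silent
   failure into a worker restart, and the right escalation is a judgment call):
   
   .setUncaughtExceptionHandler(
       (t, e) -> {
         LOG.error("{} failed due to uncaught exception during execution.", t.getName(), e);
         // Fail fast so the worker restarts instead of silently losing the
         // metadata consumer thread.
         Runtime.getRuntime().halt(1);
       })
   
   Note that because the work is handed to the executor via submit() and the
   returned Future is dropped, exceptions thrown inside the task are captured in
   that Future rather than reaching this handler, so the handler may never fire
   for task failures; surfacing them likely requires checking the Future.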



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java:
##########
@@ -0,0 +1,432 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList;
+import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap;
+
+import com.google.errorprone.annotations.CheckReturnValue;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Random;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Function;
+import java.util.function.Supplier;
+import java.util.stream.Collectors;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.ThreadSafe;
+import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair;
+import 
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub;
+import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints;
+import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkerMetadataStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillStubFactory;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
+import org.apache.beam.runners.dataflow.worker.windmill.work.ProcessWorkItem;
+import 
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import 
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetDistributor;
+import 
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetRefresher;
+import org.apache.beam.sdk.annotations.Internal;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableCollection;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.joda.time.Instant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Client for StreamingEngine. Given a {@link GetWorkBudget}, divides the 
budget and starts the
+ * {@link WindmillStream.GetWorkStream}(s).
+ */
+@Internal
+@CheckReturnValue
+@ThreadSafe
+public class StreamingEngineClient {
+  private static final Logger LOG = 
LoggerFactory.getLogger(StreamingEngineClient.class);
+  private static final String CONSUMER_WORKER_METADATA_THREAD = 
"ConsumeWorkerMetadataThread";
+
+  private final AtomicBoolean started;
+  private final JobHeader jobHeader;
+  private final GetWorkBudget totalGetWorkBudget;
+  private final GrpcWindmillStreamFactory streamFactory;
+  private final ProcessWorkItem processWorkItem;
+  private final WindmillStubFactory stubFactory;
+  private final GetWorkBudgetDistributor getWorkBudgetDistributor;
+  private final GrpcDispatcherClient dispatcherClient;
+  private final AtomicBoolean isBudgetRefreshPaused;
+  /** Writes are guarded by synchronization, reads are lock free. */
+  private final AtomicReference<StreamEngineConnectionState> connections;
+
+  private final GetWorkBudgetRefresher getWorkBudgetRefresher;
+  private final AtomicReference<Instant> lastBudgetRefresh;
+  private final ThrottleTimer getWorkerMetadataThrottleTimer;
+  private final CountDownLatch getWorkerMetadataReady;
+  private final ExecutorService consumeWorkerMetadataExecutor;
+  private final long clientId;
+  /**
+   * Reference to {@link GetWorkerMetadataStream} that is lazily initialized 
via double check
+   * locking, with its initial value being null.
+   */
+  private volatile @Nullable GetWorkerMetadataStream getWorkerMetadataStream;
+
+  private StreamingEngineClient(
+      JobHeader jobHeader,
+      GetWorkBudget totalGetWorkBudget,
+      AtomicReference<StreamEngineConnectionState> connections,
+      GrpcWindmillStreamFactory streamFactory,
+      ProcessWorkItem processWorkItem,
+      WindmillStubFactory stubFactory,
+      GetWorkBudgetDistributor getWorkBudgetDistributor,
+      GrpcDispatcherClient dispatcherClient,
+      long clientId) {
+    this.jobHeader = jobHeader;
+    this.totalGetWorkBudget = totalGetWorkBudget;
+    this.started = new AtomicBoolean();
+    this.streamFactory = streamFactory;
+    this.processWorkItem = processWorkItem;
+    this.connections = connections;
+    this.stubFactory = stubFactory;
+    this.getWorkBudgetDistributor = getWorkBudgetDistributor;
+    this.dispatcherClient = dispatcherClient;
+    this.isBudgetRefreshPaused = new AtomicBoolean(false);
+    this.getWorkerMetadataThrottleTimer = new ThrottleTimer();
+    this.consumeWorkerMetadataExecutor =
+        Executors.newSingleThreadScheduledExecutor(
+            new ThreadFactoryBuilder()
+                
.setNameFormat(StreamingEngineClient.CONSUMER_WORKER_METADATA_THREAD)
+                // The JVM will be responsible for shutting down and garbage
+                // collecting these threads.
+                .setDaemon(true)
+                .setUncaughtExceptionHandler(
+                    (t, e) ->
+                        LOG.error(
+                            "{} failed due to uncaught exception during 
execution. ",
+                            t.getName(),
+                            e))
+                .build());
+    this.getWorkerMetadataStream = null;
+    this.getWorkerMetadataReady = new CountDownLatch(1);
+    this.clientId = clientId;
+    this.lastBudgetRefresh = new AtomicReference<>(Instant.EPOCH);
+    this.getWorkBudgetRefresher =
+        new GetWorkBudgetRefresher(
+            isBudgetRefreshPaused::get,
+            () -> {
+              getWorkBudgetDistributor.distributeBudget(
+                  connections.get().windmillStreams().values(), 
totalGetWorkBudget);
+              lastBudgetRefresh.set(Instant.now());
+            });
+  }
+
+  /**
+   * Creates an instance of {@link StreamingEngineClient} and starts the {@link
+   * GetWorkerMetadataStream} with an RPC to the StreamingEngine backend. 
{@link
+   * GetWorkerMetadataStream} will populate {@link #connections} when a 
response is received. Calls
+   * to {@link #startAndCacheStreams()} will block until {@link #connections} 
are populated.
+   */
+  public static StreamingEngineClient create(
+      JobHeader jobHeader,
+      GetWorkBudget totalGetWorkBudget,
+      GrpcWindmillStreamFactory streamingEngineStreamFactory,
+      ProcessWorkItem processWorkItem,
+      WindmillStubFactory windmillGrpcStubFactory,
+      GetWorkBudgetDistributor getWorkBudgetDistributor,
+      GrpcDispatcherClient dispatcherClient) {
+    StreamingEngineClient streamingEngineClient =
+        new StreamingEngineClient(
+            jobHeader,
+            totalGetWorkBudget,
+            new AtomicReference<>(StreamEngineConnectionState.EMPTY),
+            streamingEngineStreamFactory,
+            processWorkItem,
+            windmillGrpcStubFactory,
+            getWorkBudgetDistributor,
+            dispatcherClient,
+            new Random().nextLong());
+    streamingEngineClient.startGetWorkerMetadataStream();
+    return streamingEngineClient;
+  }
+
+  @VisibleForTesting
+  static StreamingEngineClient forTesting(
+      JobHeader jobHeader,
+      GetWorkBudget totalGetWorkBudget,
+      AtomicReference<StreamEngineConnectionState> connections,
+      GrpcWindmillStreamFactory streamFactory,
+      ProcessWorkItem processWorkItem,
+      WindmillStubFactory stubFactory,
+      GetWorkBudgetDistributor getWorkBudgetDistributor,
+      GrpcDispatcherClient dispatcherClient,
+      long clientId) {
+    StreamingEngineClient streamingEngineClient =
+        new StreamingEngineClient(
+            jobHeader,
+            totalGetWorkBudget,
+            connections,
+            streamFactory,
+            processWorkItem,
+            stubFactory,
+            getWorkBudgetDistributor,
+            dispatcherClient,
+            clientId);
+    streamingEngineClient.startGetWorkerMetadataStream();
+    return streamingEngineClient;
+  }
+
+  /**
+   * Starts the streams with the {@link #connections} values. Does nothing if 
this has already been
+   * called.
+   *
+   * @throws IllegalStateException if trying to start before {@link 
#connections} are set with
+   *     {@link GetWorkerMetadataStream}.
+   */
+  public void startAndCacheStreams() {
+    // Do nothing if we have already initialized the initial streams.
+    if (!started.compareAndSet(false, true)) {
+      return;
+    }
+
+    // This will block if we have not received the first response from 
GetWorkerMetadata.
+    waitForFirstStreamingEngineEndpoints();
+    StreamEngineConnectionState currentConnectionsState = connections.get();
+    Preconditions.checkState(
+        !StreamEngineConnectionState.EMPTY.equals(currentConnectionsState),
+        "Cannot start streams without connections.");
+
+    LOG.info("Starting initial GetWorkStreams with connections={}", 
currentConnectionsState);
+    ImmutableCollection<WindmillStreamSender> windmillStreamSenders =
+        currentConnectionsState.windmillStreams().values();
+    getWorkBudgetDistributor.distributeBudget(
+        currentConnectionsState.windmillStreams().values(), 
totalGetWorkBudget);
+    lastBudgetRefresh.compareAndSet(Instant.EPOCH, Instant.now());
+    windmillStreamSenders.forEach(WindmillStreamSender::startStreams);
+    getWorkBudgetRefresher.start();
+  }
+
+  @VisibleForTesting
+  void finish() {
+    if (!started.compareAndSet(true, false)) {
+      return;
+    }
+
+    
Optional.ofNullable(getWorkerMetadataStream).ifPresent(GetWorkerMetadataStream::close);
+    getWorkBudgetRefresher.stop();
+    consumeWorkerMetadataExecutor.shutdownNow();
+  }
+
+  private void waitForFirstStreamingEngineEndpoints() {
+    try {
+      getWorkerMetadataReady.await();
+    } catch (InterruptedException e) {
+      throw new StreamingEngineClientException(
+          "Error occurred waiting for StreamingEngine backend endpoints.", e);
+    }
+  }
+
+  /**
+   * {@link java.util.function.Consumer< WindmillEndpoints >} used to update 
{@link #connections} on
+   * new backend worker metadata.
+   */
+  private synchronized void consumeWindmillWorkerEndpoints(WindmillEndpoints 
newWindmillEndpoints) {
+    isBudgetRefreshPaused.set(true);
+    LOG.info("Consuming new windmill endpoints: {}", newWindmillEndpoints);
+    ImmutableMap<Endpoint, WindmillConnection> newWindmillConnections =
+        
createNewWindmillConnections(ImmutableSet.copyOf(newWindmillEndpoints.windmillEndpoints()));
+    ImmutableMap<WindmillConnection, WindmillStreamSender> newWindmillStreams =
+        
closeStaleStreamsAndCreateNewStreams(ImmutableSet.copyOf(newWindmillConnections.values()));
+    ImmutableMap<Endpoint, Supplier<GetDataStream>> newGlobalDataStreams =
+        createNewGlobalDataStreams(
+            
ImmutableSet.copyOf(newWindmillEndpoints.globalDataEndpoints().values()));
+
+    StreamEngineConnectionState newConnectionsState =
+        StreamEngineConnectionState.builder()
+            .setWindmillConnections(newWindmillConnections)
+            .setWindmillStreams(newWindmillStreams)
+            .setGlobalDataEndpoints(newWindmillEndpoints.globalDataEndpoints())
+            .setGlobalDataStreams(newGlobalDataStreams)
+            .build();
+
+    LOG.info(
+        "Setting new connections: {}. Previous connections: {}.",
+        newConnectionsState,
+        connections.get());
+    connections.set(newConnectionsState);
+    isBudgetRefreshPaused.set(false);
+
+    // Trigger on the first worker metadata, since startAndCacheStreams blocks until
+    // it is ready. Afterwards, just trigger a budget refresh.
+    if (getWorkerMetadataReady.getCount() > 0) {
+      getWorkerMetadataReady.countDown();
+    } else {
+      getWorkBudgetRefresher.requestBudgetRefresh();
+    }
+  }
+
+  public final ImmutableList<Long> getAndResetThrottleTimes() {
+    StreamEngineConnectionState currentConnections = connections.get();
+
+    ImmutableList<Long> keyedWorkStreamThrottleTimes =
+        currentConnections.windmillStreams().values().stream()
+            .map(WindmillStreamSender::getAndResetThrottleTime)
+            .collect(toImmutableList());
+
+    return ImmutableList.<Long>builder()
+        .add(getWorkerMetadataThrottleTimer.getAndResetThrottleTime())
+        .addAll(keyedWorkStreamThrottleTimes)
+        .build();
+  }
+
+  /** Starts {@link GetWorkerMetadataStream}. */
+  @SuppressWarnings({
+    "FutureReturnValueIgnored", // ignoring Future returned from 
Executor.submit()
+    "nullness" // Uninitialized value of getWorkerMetadataStream is null.
+  })
+  private void startGetWorkerMetadataStream() {
+    // We only want to set and start this value once.
+    if (getWorkerMetadataStream == null) {
+      synchronized (this) {
+        if (getWorkerMetadataStream == null) {
+          getWorkerMetadataStream =
+              streamFactory.createGetWorkerMetadataStream(
+                  dispatcherClient.getDispatcherStub(),
+                  getWorkerMetadataThrottleTimer,
+                  endpoints ->
+                      // Run this on a separate thread than the grpc stream 
thread.
+                      consumeWorkerMetadataExecutor.submit(
+                          () -> consumeWindmillWorkerEndpoints(endpoints)));
+        }
+      }
+    }
+  }
+
+  private synchronized ImmutableMap<Endpoint, WindmillConnection> 
createNewWindmillConnections(
+      ImmutableSet<Endpoint> newWindmillEndpoints) {
+    ImmutableMap<Endpoint, WindmillConnection> currentWindmillConnections =
+        connections.get().windmillConnections();
+    Map<Endpoint, WindmillConnection> newWindmillConnections =
+        currentWindmillConnections.entrySet().stream()
+            .filter(entry -> newWindmillEndpoints.contains(entry.getKey()))
+            .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+
+    // Reuse existing stubs if they exist.
+    newWindmillEndpoints.forEach(
+        newEndpoint ->
+            newWindmillConnections.putIfAbsent(
+                newEndpoint, WindmillConnection.from(newEndpoint, 
this::createWindmillStub)));
+
+    return ImmutableMap.copyOf(newWindmillConnections);
+  }
+
+  private synchronized ImmutableMap<WindmillConnection, WindmillStreamSender>
+      closeStaleStreamsAndCreateNewStreams(
+          ImmutableSet<WindmillConnection> newWindmillConnections) {
+    ImmutableMap<WindmillConnection, WindmillStreamSender> currentStreams =
+        connections.get().windmillStreams();
+
+    // Close the streams that are no longer valid.
+    currentStreams.entrySet().stream()
+        .filter(
+            connectionAndStream -> 
!newWindmillConnections.contains(connectionAndStream.getKey()))
+        .map(Map.Entry::getValue)
+        .forEach(WindmillStreamSender::closeAllStreams);

Review Comment:
   Verify that this is non-blocking, and add a comment noting that it doesn't block.
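   
   e.g., once verified, a comment along these lines would make the contract
   explicit (the wording assumes the verification holds):
   
   // Close the streams that are no longer valid. closeAllStreams() only signals the
   // underlying streams to shut down; it does not block on their termination.
   currentStreams.entrySet().stream()
       .filter(
           connectionAndStream ->
               !newWindmillConnections.contains(connectionAndStream.getKey()))
       .map(Map.Entry::getValue)
       .forEach(WindmillStreamSender::closeAllStreams);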



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamSender.java:
##########
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Supplier;
+import javax.annotation.concurrent.ThreadSafe;
+import 
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub;
+import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.throttling.StreamingEngineThrottleTimers;
+import org.apache.beam.runners.dataflow.worker.windmill.work.ProcessWorkItem;
+import 
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import org.apache.beam.sdk.annotations.Internal;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers;
+
+/**
+ * Owns and maintains a set of streams used to communicate with a specific 
Windmill worker.
+ * Underlying streams are "cached" in a threadsafe manner so that once {@link 
Supplier#get} is
+ * called, a stream that is already started is returned.
+ *
+ * <p>Holds references to {@link
+ * 
Supplier<org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream>}
 because
+ * initializing the streams automatically starts them, and we want to do so 
lazily here once the
+ * {@link GetWorkBudget} is set.
+ *
+ * <p>Once started, the underlying streams are "alive" until they are manually 
closed via {@link
+ * #closeAllStreams()}.
+ *
+ * <p>If closed, it means that the backend endpoint is no longer in the worker 
set. Once closed,
+ * these instances are not reused.
+ *
+ * @implNote Does not manage streams for fetching {@link
+ *     org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalData} 
for side inputs.
+ */
+@Internal
+@ThreadSafe
+public class WindmillStreamSender {
+  private final AtomicBoolean started;
+  private final AtomicReference<GetWorkBudget> getWorkBudget;
+  private final Supplier<GetWorkStream> getWorkStream;
+  private final Supplier<GetDataStream> getDataStream;
+  private final Supplier<CommitWorkStream> commitWorkStream;
+  private final StreamingEngineThrottleTimers streamingEngineThrottleTimers;
+
+  private WindmillStreamSender(
+      CloudWindmillServiceV1Alpha1Stub stub,
+      GetWorkRequest getWorkRequest,
+      AtomicReference<GetWorkBudget> getWorkBudget,
+      GrpcWindmillStreamFactory streamingEngineStreamFactory,
+      ProcessWorkItem processWorkItem) {
+    this.started = new AtomicBoolean(false);
+    this.getWorkBudget = getWorkBudget;
+    this.streamingEngineThrottleTimers = 
StreamingEngineThrottleTimers.create();
+
+    // All streams are memoized/cached since they are expensive to create and 
some implementations
+    // perform side effects on construction (i.e. sending initial requests to 
the stream server to
+    // initiate the streaming RPC connection). Stream instances 
connect/reconnect internally so we
+    // can reuse the same instance through the entire lifecycle of 
WindmillStreamSender.
+    this.getDataStream =
+        Suppliers.memoize(
+            () ->
+                streamingEngineStreamFactory.createGetDataStream(
+                    stub, 
streamingEngineThrottleTimers.getDataThrottleTimer()));
+    this.commitWorkStream =
+        Suppliers.memoize(
+            () ->
+                streamingEngineStreamFactory.createCommitWorkStream(
+                    stub, 
streamingEngineThrottleTimers.commitWorkThrottleTimer()));
+    this.getWorkStream =
+        Suppliers.memoize(
+            () ->
+                streamingEngineStreamFactory.createDirectGetWorkStream(
+                    stub,
+                    withRequestBudget(getWorkRequest, getWorkBudget.get()),
+                    streamingEngineThrottleTimers.getWorkThrottleTimer(),
+                    getDataStream,
+                    commitWorkStream,
+                    processWorkItem));
+  }
+
+  public static WindmillStreamSender create(
+      CloudWindmillServiceV1Alpha1Stub stub,
+      GetWorkRequest getWorkRequest,
+      GetWorkBudget getWorkBudget,
+      GrpcWindmillStreamFactory streamingEngineStreamFactory,
+      ProcessWorkItem processWorkItem) {
+    return new WindmillStreamSender(
+        stub,
+        getWorkRequest,
+        new AtomicReference<>(getWorkBudget),
+        streamingEngineStreamFactory,
+        processWorkItem);
+  }
+
+  private static GetWorkRequest withRequestBudget(GetWorkRequest request, 
GetWorkBudget budget) {
+    return 
request.toBuilder().setMaxItems(budget.items()).setMaxBytes(budget.bytes()).build();
+  }
+
+  @SuppressWarnings("ReturnValueIgnored")
+  public void startStreams() {
+    Preconditions.checkState(
+        !getWorkBudget.get().equals(GetWorkBudget.noBudget()), "Cannot GetWork 
with no budget.");
+    getWorkStream.get();
+    getDataStream.get();
+    commitWorkStream.get();
+    // *stream.get() is all memoized in a threadsafe manner.
+    started.compareAndSet(false, true);
+  }
+
+  public void closeAllStreams() {
+    // Supplier<Stream>.get() starts the stream which is an expensive 
operation as it initiates the
+    // streaming RPCs by possibly making calls over the network. Do not close 
the streams unless
+    // they have already been started.
+    if (started.get()) {
+      getWorkStream.get().close();
+      getDataStream.get().close();
+      commitWorkStream.get().close();
+    }
+  }
+
+  public synchronized void adjustBudget(long itemsDelta, long bytesDelta) {
+    getWorkBudget.set(getWorkBudget.get().apply(itemsDelta, bytesDelta));
+    getWorkStream.get().adjustBudget(itemsDelta, bytesDelta);

Review Comment:
   Only do this if started? Or enforce that startStreams() was called first?
   Otherwise this seems like it will create the GetWorkStream with possibly uncached
   getData/commitWork streams.
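   
   For the first option, roughly:
   
   public synchronized void adjustBudget(long itemsDelta, long bytesDelta) {
     getWorkBudget.set(getWorkBudget.get().apply(itemsDelta, bytesDelta));
     // Only touch the stream once startStreams() has run, so we don't lazily create
     // a GetWorkStream whose getData/commitWork streams were never started.
     if (started.get()) {
       getWorkStream.get().adjustBudget(itemsDelta, bytesDelta);
     }
   }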



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java:
##########
@@ -0,0 +1,432 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList;
+import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap;
+
+import com.google.errorprone.annotations.CheckReturnValue;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Random;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Function;
+import java.util.function.Supplier;
+import java.util.stream.Collectors;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.ThreadSafe;
+import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair;
+import 
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub;
+import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints;
+import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkerMetadataStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillStubFactory;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
+import org.apache.beam.runners.dataflow.worker.windmill.work.ProcessWorkItem;
+import 
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import 
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetDistributor;
+import 
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetRefresher;
+import org.apache.beam.sdk.annotations.Internal;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableCollection;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.joda.time.Instant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Client for StreamingEngine. Given a {@link GetWorkBudget}, divides the 
budget and starts the
+ * {@link WindmillStream.GetWorkStream}(s).
+ */
+@Internal
+@CheckReturnValue
+@ThreadSafe
+public class StreamingEngineClient {
+  private static final Logger LOG = 
LoggerFactory.getLogger(StreamingEngineClient.class);
+  private static final String CONSUMER_WORKER_METADATA_THREAD = 
"ConsumeWorkerMetadataThread";
+
+  private final AtomicBoolean started;
+  private final JobHeader jobHeader;
+  private final GetWorkBudget totalGetWorkBudget;
+  private final GrpcWindmillStreamFactory streamFactory;
+  private final ProcessWorkItem processWorkItem;
+  private final WindmillStubFactory stubFactory;
+  private final GetWorkBudgetDistributor getWorkBudgetDistributor;
+  private final GrpcDispatcherClient dispatcherClient;
+  private final AtomicBoolean isBudgetRefreshPaused;
+  /** Writes are guarded by synchronization, reads are lock free. */
+  private final AtomicReference<StreamEngineConnectionState> connections;
+
+  private final GetWorkBudgetRefresher getWorkBudgetRefresher;
+  private final AtomicReference<Instant> lastBudgetRefresh;
+  private final ThrottleTimer getWorkerMetadataThrottleTimer;
+  private final CountDownLatch getWorkerMetadataReady;
+  private final ExecutorService consumeWorkerMetadataExecutor;
+  private final long clientId;
+  /**
+   * Reference to {@link GetWorkerMetadataStream} that is lazily initialized 
via double check
+   * locking, with its initial value being null.
+   */
+  private volatile @Nullable GetWorkerMetadataStream getWorkerMetadataStream;
+
+  private StreamingEngineClient(
+      JobHeader jobHeader,
+      GetWorkBudget totalGetWorkBudget,
+      AtomicReference<StreamEngineConnectionState> connections,
+      GrpcWindmillStreamFactory streamFactory,
+      ProcessWorkItem processWorkItem,
+      WindmillStubFactory stubFactory,
+      GetWorkBudgetDistributor getWorkBudgetDistributor,
+      GrpcDispatcherClient dispatcherClient,
+      long clientId) {
+    this.jobHeader = jobHeader;
+    this.totalGetWorkBudget = totalGetWorkBudget;
+    this.started = new AtomicBoolean();
+    this.streamFactory = streamFactory;
+    this.processWorkItem = processWorkItem;
+    this.connections = connections;
+    this.stubFactory = stubFactory;
+    this.getWorkBudgetDistributor = getWorkBudgetDistributor;
+    this.dispatcherClient = dispatcherClient;
+    this.isBudgetRefreshPaused = new AtomicBoolean(false);
+    this.getWorkerMetadataThrottleTimer = new ThrottleTimer();
+    this.consumeWorkerMetadataExecutor =
+        Executors.newSingleThreadScheduledExecutor(
+            new ThreadFactoryBuilder()
+                
.setNameFormat(StreamingEngineClient.CONSUMER_WORKER_METADATA_THREAD)
+                // The JVM will be responsible for shutting down and garbage
+                // collecting these threads.
+                .setDaemon(true)
+                .setUncaughtExceptionHandler(
+                    (t, e) ->
+                        LOG.error(
+                            "{} failed due to uncaught exception during 
execution. ",
+                            t.getName(),
+                            e))
+                .build());
+    this.getWorkerMetadataStream = null;
+    this.getWorkerMetadataReady = new CountDownLatch(1);
+    this.clientId = clientId;
+    this.lastBudgetRefresh = new AtomicReference<>(Instant.EPOCH);
+    this.getWorkBudgetRefresher =
+        new GetWorkBudgetRefresher(
+            isBudgetRefreshPaused::get,
+            () -> {
+              getWorkBudgetDistributor.distributeBudget(
+                  connections.get().windmillStreams().values(), 
totalGetWorkBudget);
+              lastBudgetRefresh.set(Instant.now());
+            });
+  }
+
+  /**
+   * Creates an instance of {@link StreamingEngineClient} and starts the {@link
+   * GetWorkerMetadataStream} with an RPC to the StreamingEngine backend. 
{@link
+   * GetWorkerMetadataStream} will populate {@link #connections} when a 
response is received. Calls
+   * to {@link #startAndCacheStreams()} will block until {@link #connections} 
are populated.
+   */
+  public static StreamingEngineClient create(
+      JobHeader jobHeader,
+      GetWorkBudget totalGetWorkBudget,
+      GrpcWindmillStreamFactory streamingEngineStreamFactory,
+      ProcessWorkItem processWorkItem,
+      WindmillStubFactory windmillGrpcStubFactory,
+      GetWorkBudgetDistributor getWorkBudgetDistributor,
+      GrpcDispatcherClient dispatcherClient) {
+    StreamingEngineClient streamingEngineClient =
+        new StreamingEngineClient(
+            jobHeader,
+            totalGetWorkBudget,
+            new AtomicReference<>(StreamEngineConnectionState.EMPTY),
+            streamingEngineStreamFactory,
+            processWorkItem,
+            windmillGrpcStubFactory,
+            getWorkBudgetDistributor,
+            dispatcherClient,
+            new Random().nextLong());
+    streamingEngineClient.startGetWorkerMetadataStream();
+    return streamingEngineClient;
+  }
+
+  @VisibleForTesting
+  static StreamingEngineClient forTesting(
+      JobHeader jobHeader,
+      GetWorkBudget totalGetWorkBudget,
+      AtomicReference<StreamEngineConnectionState> connections,
+      GrpcWindmillStreamFactory streamFactory,
+      ProcessWorkItem processWorkItem,
+      WindmillStubFactory stubFactory,
+      GetWorkBudgetDistributor getWorkBudgetDistributor,
+      GrpcDispatcherClient dispatcherClient,
+      long clientId) {
+    StreamingEngineClient streamingEngineClient =
+        new StreamingEngineClient(
+            jobHeader,
+            totalGetWorkBudget,
+            connections,
+            streamFactory,
+            processWorkItem,
+            stubFactory,
+            getWorkBudgetDistributor,
+            dispatcherClient,
+            clientId);
+    streamingEngineClient.startGetWorkerMetadataStream();
+    return streamingEngineClient;
+  }
+
+  /**
+   * Starts the streams with the {@link #connections} values. Does nothing if 
this has already been
+   * called.
+   *
+   * @throws IllegalStateException if trying to start before {@link 
#connections} are set with
+   *     {@link GetWorkerMetadataStream}.
+   */
+  public void startAndCacheStreams() {
+    // Do nothing if we have already initialized the initial streams.
+    if (!started.compareAndSet(false, true)) {
+      return;
+    }
+
+    // This will block if we have not received the first response from 
GetWorkerMetadata.
+    waitForFirstStreamingEngineEndpoints();
+    StreamEngineConnectionState currentConnectionsState = connections.get();
+    Preconditions.checkState(
+        !StreamEngineConnectionState.EMPTY.equals(currentConnectionsState),
+        "Cannot start streams without connections.");
+
+    LOG.info("Starting initial GetWorkStreams with connections={}", 
currentConnectionsState);
+    ImmutableCollection<WindmillStreamSender> windmillStreamSenders =
+        currentConnectionsState.windmillStreams().values();
+    getWorkBudgetDistributor.distributeBudget(
+        currentConnectionsState.windmillStreams().values(), 
totalGetWorkBudget);
+    lastBudgetRefresh.compareAndSet(Instant.EPOCH, Instant.now());
+    windmillStreamSenders.forEach(WindmillStreamSender::startStreams);
+    getWorkBudgetRefresher.start();
+  }
+
+  @VisibleForTesting
+  void finish() {
+    if (!started.compareAndSet(true, false)) {
+      return;
+    }
+
+    
Optional.ofNullable(getWorkerMetadataStream).ifPresent(GetWorkerMetadataStream::close);
+    getWorkBudgetRefresher.stop();
+    consumeWorkerMetadataExecutor.shutdownNow();
+  }
+
+  private void waitForFirstStreamingEngineEndpoints() {
+    try {
+      getWorkerMetadataReady.await();
+    } catch (InterruptedException e) {
+      throw new StreamingEngineClientException(
+          "Error occurred waiting for StreamingEngine backend endpoints.", e);
+    }
+  }
+
+  /**
+   * {@link java.util.function.Consumer< WindmillEndpoints >} used to update 
{@link #connections} on
+   * new backend worker metadata.
+   */
+  private synchronized void consumeWindmillWorkerEndpoints(WindmillEndpoints 
newWindmillEndpoints) {
+    isBudgetRefreshPaused.set(true);
+    LOG.info("Consuming new windmill endpoints: {}", newWindmillEndpoints);
+    ImmutableMap<Endpoint, WindmillConnection> newWindmillConnections =
+        
createNewWindmillConnections(ImmutableSet.copyOf(newWindmillEndpoints.windmillEndpoints()));
+    ImmutableMap<WindmillConnection, WindmillStreamSender> newWindmillStreams =
+        
closeStaleStreamsAndCreateNewStreams(ImmutableSet.copyOf(newWindmillConnections.values()));
+    ImmutableMap<Endpoint, Supplier<GetDataStream>> newGlobalDataStreams =
+        createNewGlobalDataStreams(
+            
ImmutableSet.copyOf(newWindmillEndpoints.globalDataEndpoints().values()));
+
+    StreamEngineConnectionState newConnectionsState =
+        StreamEngineConnectionState.builder()
+            .setWindmillConnections(newWindmillConnections)
+            .setWindmillStreams(newWindmillStreams)
+            .setGlobalDataEndpoints(newWindmillEndpoints.globalDataEndpoints())
+            .setGlobalDataStreams(newGlobalDataStreams)
+            .build();
+
+    LOG.info(
+        "Setting new connections: {}. Previous connections: {}.",
+        newConnectionsState,
+        connections.get());
+    connections.set(newConnectionsState);
+    isBudgetRefreshPaused.set(false);
+
+    // Trigger on the first worker metadata, since startAndCacheStreams blocks until
+    // it is ready. Afterwards, just trigger a budget refresh.
+    if (getWorkerMetadataReady.getCount() > 0) {
+      getWorkerMetadataReady.countDown();

Review Comment:
   Can you instead move the initialization that startAndCacheStreams is doing here?
   
   Then callers wouldn't have to call any blocking method on this object for it to
   start working. The method could be kept around as just a waitForInitialResponse
   if it's useful for testing or clients, but this also allows creating the object
   already in a usable state.
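   
   For example, something along these lines at the end of
   consumeWindmillWorkerEndpoints (a sketch reusing the existing fields;
   bookkeeping like lastBudgetRefresh is omitted):
   
   if (getWorkerMetadataReady.getCount() > 0) {
     // First metadata response: start the initial streams here so callers never
     // have to invoke a blocking start method on this object.
     getWorkBudgetDistributor.distributeBudget(
         newConnectionsState.windmillStreams().values(), totalGetWorkBudget);
     newConnectionsState.windmillStreams().values()
         .forEach(WindmillStreamSender::startStreams);
     getWorkBudgetRefresher.start();
     getWorkerMetadataReady.countDown();
   } else {
     getWorkBudgetRefresher.requestBudgetRefresh();
   }
   
   startAndCacheStreams (or a renamed waitForInitialResponse) could then just
   await the latch.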



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java:
##########
@@ -0,0 +1,432 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList;
+import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap;
+
+import com.google.errorprone.annotations.CheckReturnValue;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Random;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Function;
+import java.util.function.Supplier;
+import java.util.stream.Collectors;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.ThreadSafe;
+import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair;
+import 
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub;
+import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints;
+import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkerMetadataStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillStubFactory;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
+import org.apache.beam.runners.dataflow.worker.windmill.work.ProcessWorkItem;
+import 
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import 
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetDistributor;
+import 
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetRefresher;
+import org.apache.beam.sdk.annotations.Internal;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableCollection;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.joda.time.Instant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Client for StreamingEngine. Given a {@link GetWorkBudget}, divides the 
budget and starts the
+ * {@link WindmillStream.GetWorkStream}(s).
+ */
+@Internal
+@CheckReturnValue
+@ThreadSafe
+public class StreamingEngineClient {
+  private static final Logger LOG = 
LoggerFactory.getLogger(StreamingEngineClient.class);
+  private static final String CONSUMER_WORKER_METADATA_THREAD = 
"ConsumeWorkerMetadataThread";
+
+  private final AtomicBoolean started;
+  private final JobHeader jobHeader;
+  private final GetWorkBudget totalGetWorkBudget;
+  private final GrpcWindmillStreamFactory streamFactory;
+  private final ProcessWorkItem processWorkItem;
+  private final WindmillStubFactory stubFactory;
+  private final GetWorkBudgetDistributor getWorkBudgetDistributor;
+  private final GrpcDispatcherClient dispatcherClient;
+  private final AtomicBoolean isBudgetRefreshPaused;
+  /** Writes are guarded by synchronization, reads are lock free. */
+  private final AtomicReference<StreamEngineConnectionState> connections;
+
+  private final GetWorkBudgetRefresher getWorkBudgetRefresher;
+  private final AtomicReference<Instant> lastBudgetRefresh;
+  private final ThrottleTimer getWorkerMetadataThrottleTimer;
+  private final CountDownLatch getWorkerMetadataReady;
+  private final ExecutorService consumeWorkerMetadataExecutor;
+  private final long clientId;
+  /**
+   * Reference to {@link GetWorkerMetadataStream} that is lazily initialized 
via double check
+   * locking, with its initial value being null.
+   */
+  private volatile @Nullable GetWorkerMetadataStream getWorkerMetadataStream;
+
+  private StreamingEngineClient(
+      JobHeader jobHeader,
+      GetWorkBudget totalGetWorkBudget,
+      AtomicReference<StreamEngineConnectionState> connections,
+      GrpcWindmillStreamFactory streamFactory,
+      ProcessWorkItem processWorkItem,
+      WindmillStubFactory stubFactory,
+      GetWorkBudgetDistributor getWorkBudgetDistributor,
+      GrpcDispatcherClient dispatcherClient,
+      long clientId) {
+    this.jobHeader = jobHeader;
+    this.totalGetWorkBudget = totalGetWorkBudget;
+    this.started = new AtomicBoolean();
+    this.streamFactory = streamFactory;
+    this.processWorkItem = processWorkItem;
+    this.connections = connections;
+    this.stubFactory = stubFactory;
+    this.getWorkBudgetDistributor = getWorkBudgetDistributor;
+    this.dispatcherClient = dispatcherClient;
+    this.isBudgetRefreshPaused = new AtomicBoolean(false);
+    this.getWorkerMetadataThrottleTimer = new ThrottleTimer();
+    this.consumeWorkerMetadataExecutor =
+        Executors.newSingleThreadScheduledExecutor(
+            new ThreadFactoryBuilder()
+                
.setNameFormat(StreamingEngineClient.CONSUMER_WORKER_METADATA_THREAD)
+                // The JVM will be responsible for shutting down and garbage
+                // collecting these threads.
+                .setDaemon(true)
+                .setUncaughtExceptionHandler(
+                    (t, e) ->
+                        LOG.error(
+                            "{} failed due to uncaught exception during 
execution. ",
+                            t.getName(),
+                            e))
+                .build());
+    this.getWorkerMetadataStream = null;
+    this.getWorkerMetadataReady = new CountDownLatch(1);
+    this.clientId = clientId;
+    this.lastBudgetRefresh = new AtomicReference<>(Instant.EPOCH);
+    this.getWorkBudgetRefresher =
+        new GetWorkBudgetRefresher(
+            isBudgetRefreshPaused::get,
+            () -> {
+              getWorkBudgetDistributor.distributeBudget(
+                  connections.get().windmillStreams().values(), 
totalGetWorkBudget);
+              lastBudgetRefresh.set(Instant.now());
+            });
+  }
+
+  /**
+   * Creates an instance of {@link StreamingEngineClient} and starts the {@link
+   * GetWorkerMetadataStream} with an RPC to the StreamingEngine backend. 
{@link
+   * GetWorkerMetadataStream} will populate {@link #connections} when a 
response is received. Calls
+   * to {@link #startAndCacheStreams()} will block until {@link #connections} 
are populated.
+   */
+  public static StreamingEngineClient create(
+      JobHeader jobHeader,
+      GetWorkBudget totalGetWorkBudget,
+      GrpcWindmillStreamFactory streamingEngineStreamFactory,
+      ProcessWorkItem processWorkItem,
+      WindmillStubFactory windmillGrpcStubFactory,
+      GetWorkBudgetDistributor getWorkBudgetDistributor,
+      GrpcDispatcherClient dispatcherClient) {
+    StreamingEngineClient streamingEngineClient =
+        new StreamingEngineClient(
+            jobHeader,
+            totalGetWorkBudget,
+            new AtomicReference<>(StreamEngineConnectionState.EMPTY),
+            streamingEngineStreamFactory,
+            processWorkItem,
+            windmillGrpcStubFactory,
+            getWorkBudgetDistributor,
+            dispatcherClient,
+            new Random().nextLong());
+    streamingEngineClient.startGetWorkerMetadataStream();
+    return streamingEngineClient;
+  }
+
+  @VisibleForTesting
+  static StreamingEngineClient forTesting(
+      JobHeader jobHeader,
+      GetWorkBudget totalGetWorkBudget,
+      AtomicReference<StreamEngineConnectionState> connections,
+      GrpcWindmillStreamFactory streamFactory,
+      ProcessWorkItem processWorkItem,
+      WindmillStubFactory stubFactory,
+      GetWorkBudgetDistributor getWorkBudgetDistributor,
+      GrpcDispatcherClient dispatcherClient,
+      long clientId) {
+    StreamingEngineClient streamingEngineClient =
+        new StreamingEngineClient(
+            jobHeader,
+            totalGetWorkBudget,
+            connections,
+            streamFactory,
+            processWorkItem,
+            stubFactory,
+            getWorkBudgetDistributor,
+            dispatcherClient,
+            clientId);
+    streamingEngineClient.startGetWorkerMetadataStream();
+    return streamingEngineClient;
+  }
+
+  /**
+   * Starts the streams with the {@link #connections} values. Does nothing if this has already been
+   * called.
+   *
+   * @throws IllegalStateException if trying to start before {@link #connections} are set with
+   *     {@link GetWorkerMetadataStream}.
+   */
+  public void startAndCacheStreams() {
+    // Do nothing if we have already initialized the initial streams.
+    if (!started.compareAndSet(false, true)) {
+      return;
+    }
+
+    // This will block if we have not received the first response from GetWorkerMetadata.
+    waitForFirstStreamingEngineEndpoints();
+    StreamEngineConnectionState currentConnectionsState = connections.get();
+    Preconditions.checkState(
+        !StreamEngineConnectionState.EMPTY.equals(currentConnectionsState),
+        "Cannot start streams without connections.");
+
+    LOG.info("Starting initial GetWorkStreams with connections={}", 
currentConnectionsState);
+    ImmutableCollection<WindmillStreamSender> windmillStreamSenders =
+        currentConnectionsState.windmillStreams().values();
+    getWorkBudgetDistributor.distributeBudget(
+        currentConnectionsState.windmillStreams().values(), totalGetWorkBudget);
+    lastBudgetRefresh.compareAndSet(Instant.EPOCH, Instant.now());
+    windmillStreamSenders.forEach(WindmillStreamSender::startStreams);
+    getWorkBudgetRefresher.start();
+  }
+
+  @VisibleForTesting
+  void finish() {
+    if (!started.compareAndSet(true, false)) {
+      return;
+    }
+
+    Optional.ofNullable(getWorkerMetadataStream).ifPresent(GetWorkerMetadataStream::close);
+    getWorkBudgetRefresher.stop();
+    consumeWorkerMetadataExecutor.shutdownNow();
+  }
+
+  private void waitForFirstStreamingEngineEndpoints() {
+    try {
+      getWorkerMetadataReady.await();
+    } catch (InterruptedException e) {
+      throw new StreamingEngineClientException(
+          "Error occurred waiting for StreamingEngine backend endpoints.", e);
+    }
+  }
+
+  /**
+   * {@link java.util.function.Consumer<WindmillEndpoints>} used to update {@link #connections} on
+   * new backend worker metadata.
+   */
+  private synchronized void consumeWindmillWorkerEndpoints(WindmillEndpoints newWindmillEndpoints) {
+    isBudgetRefreshPaused.set(true);
+    LOG.info("Consuming new windmill endpoints: {}", newWindmillEndpoints);
+    ImmutableMap<Endpoint, WindmillConnection> newWindmillConnections =
+        createNewWindmillConnections(ImmutableSet.copyOf(newWindmillEndpoints.windmillEndpoints()));
+    ImmutableMap<WindmillConnection, WindmillStreamSender> newWindmillStreams =
+        closeStaleStreamsAndCreateNewStreams(ImmutableSet.copyOf(newWindmillConnections.values()));
+    ImmutableMap<Endpoint, Supplier<GetDataStream>> newGlobalDataStreams =
+        createNewGlobalDataStreams(
+            ImmutableSet.copyOf(newWindmillEndpoints.globalDataEndpoints().values()));
+
+    StreamEngineConnectionState newConnectionsState =
+        StreamEngineConnectionState.builder()
+            .setWindmillConnections(newWindmillConnections)
+            .setWindmillStreams(newWindmillStreams)
+            .setGlobalDataEndpoints(newWindmillEndpoints.globalDataEndpoints())
+            .setGlobalDataStreams(newGlobalDataStreams)
+            .build();
+
+    LOG.info(
+        "Setting new connections: {}. Previous connections: {}.",
+        newConnectionsState,
+        connections.get());
+    connections.set(newConnectionsState);
+    isBudgetRefreshPaused.set(false);
+
+    // Count down the latch on the first worker metadata, since startAndCacheStreams blocks until
+    // the first worker metadata is ready. Afterwards, just trigger a budget refresh.
+    if (getWorkerMetadataReady.getCount() > 0) {
+      getWorkerMetadataReady.countDown();
+    } else {
+      getWorkBudgetRefresher.requestBudgetRefresh();
+    }
+  }
+
+  public final ImmutableList<Long> getAndResetThrottleTimes() {
+    StreamEngineConnectionState currentConnections = connections.get();
+
+    ImmutableList<Long> keyedWorkStreamThrottleTimes =
+        currentConnections.windmillStreams().values().stream()
+            .map(WindmillStreamSender::getAndResetThrottleTime)
+            .collect(toImmutableList());
+
+    return ImmutableList.<Long>builder()
+        .add(getWorkerMetadataThrottleTimer.getAndResetThrottleTime())
+        .addAll(keyedWorkStreamThrottleTimes)
+        .build();
+  }
+
+  /** Starts {@link GetWorkerMetadataStream}. */
+  @SuppressWarnings({
+    "FutureReturnValueIgnored", // ignoring Future returned from 
Executor.submit()
+    "nullness" // Uninitialized value of getWorkerMetadataStream is null.
+  })
+  private void startGetWorkerMetadataStream() {
+    // We only want to set and start this value once.
+    if (getWorkerMetadataStream == null) {
+      synchronized (this) {
+        if (getWorkerMetadataStream == null) {
+          getWorkerMetadataStream =
+              streamFactory.createGetWorkerMetadataStream(
+                  dispatcherClient.getDispatcherStub(),
+                  getWorkerMetadataThrottleTimer,
+                  endpoints ->
+                      // Run this on a separate thread from the gRPC stream thread.
+                      consumeWorkerMetadataExecutor.submit(
+                          () -> consumeWindmillWorkerEndpoints(endpoints)));
+        }
+      }
+    }
+  }
+
+  private synchronized ImmutableMap<Endpoint, WindmillConnection> createNewWindmillConnections(
+      ImmutableSet<Endpoint> newWindmillEndpoints) {
+    ImmutableMap<Endpoint, WindmillConnection> currentWindmillConnections =
+        connections.get().windmillConnections();
+    Map<Endpoint, WindmillConnection> newWindmillConnections =
+        currentWindmillConnections.entrySet().stream()
+            .filter(entry -> newWindmillEndpoints.contains(entry.getKey()))
+            .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+
+    // Reuse existing stubs if they exist.
+    newWindmillEndpoints.forEach(
+        newEndpoint ->
+            newWindmillConnections.putIfAbsent(

Review Comment:
   how about just a single loop doing this
   
   newWindmillConnections.put(
      newEndpoint, 
      currentWindmillConnections.getOrDefault(
         newEndpoint, WindmillConnection.from(newEndpoint, this::createWindmillStub)));
   
   still constructs WindmillConnection even if not necessary, but that isn't expensive.
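   
   For illustration only, a minimal sketch of how the full helper might read with that single loop. It assumes a plain HashMap collected into an ImmutableMap, and reuses WindmillConnection.from and createWindmillStub exactly as they appear in the snippet above; their real signatures are not shown in this diff, so treat this as a hypothetical shape rather than the PR's code:
   
   private synchronized ImmutableMap<Endpoint, WindmillConnection> createNewWindmillConnections(
       ImmutableSet<Endpoint> newWindmillEndpoints) {
     ImmutableMap<Endpoint, WindmillConnection> currentWindmillConnections =
         connections.get().windmillConnections();
     // Single pass: reuse the existing connection for an endpoint when present, otherwise build one.
     Map<Endpoint, WindmillConnection> newWindmillConnections = new HashMap<>();
     for (Endpoint newEndpoint : newWindmillEndpoints) {
       // Note: WindmillConnection.from(...) is evaluated even when the existing entry is reused,
       // which the comment above accepts as cheap.
       newWindmillConnections.put(
           newEndpoint,
           currentWindmillConnections.getOrDefault(
               newEndpoint, WindmillConnection.from(newEndpoint, this::createWindmillStub)));
     }
     return ImmutableMap.copyOf(newWindmillConnections);
   }
   
   Compared to the filter-then-putIfAbsent version, this keeps the reuse logic in one place at the cost of sometimes constructing a connection that is immediately discarded.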



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java:
##########
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.work.budget;
+
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap;
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.LongMath.divide;
+
+import java.math.RoundingMode;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.function.Function;
+import java.util.function.Supplier;
+import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.WindmillStreamSender;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableCollection;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** Evenly distributes the provided budget across the available {@link WindmillStreamSender}(s). */
+@Internal
+public final class EvenGetWorkBudgetDistributor implements GetWorkBudgetDistributor {
+  private static final Logger LOG = LoggerFactory.getLogger(EvenGetWorkBudgetDistributor.class);
+  private final Supplier<GetWorkBudget> activeWorkBudgetSupplier;
+
+  public EvenGetWorkBudgetDistributor(Supplier<GetWorkBudget> activeWorkBudgetSupplier) {
+    this.activeWorkBudgetSupplier = activeWorkBudgetSupplier;
+  }
+
+  private static boolean isBelowFiftyPercentOfTarget(
+      GetWorkBudget remaining, GetWorkBudget target) {
+    return remaining.items() < (target.items() * 0.5) || remaining.bytes() < (target.bytes() * 0.5);
+  }
+
+  @Override
+  public void distributeBudget(
+      ImmutableCollection<WindmillStreamSender> streams, GetWorkBudget getWorkBudget) {
+    if (streams.isEmpty()) {
+      LOG.debug("Cannot distribute budget to no streams.");
+      return;
+    }
+
+    if (getWorkBudget.equals(GetWorkBudget.noBudget())) {
+      LOG.debug("Cannot distribute 0 budget.");
+      return;
+    }
+
+    Map<WindmillStreamSender, GetWorkBudget> desiredBudgets =
+        computeDesiredBudgets(streams, getWorkBudget);
+
+    for (Entry<WindmillStreamSender, GetWorkBudget> streamAndDesiredBudget :
+        desiredBudgets.entrySet()) {
+      WindmillStreamSender stream = streamAndDesiredBudget.getKey();
+      GetWorkBudget desired = streamAndDesiredBudget.getValue();
+      GetWorkBudget remaining = stream.remainingGetWorkBudget();
+      if (isBelowFiftyPercentOfTarget(remaining, desired)) {
+        GetWorkBudget adjustment = desired.subtract(remaining);
+        LOG.info("Adjusting budget for stream={} by {}", stream, adjustment);
+        stream.adjustBudget(adjustment);
+      }
+    }
+  }
+
+  private ImmutableMap<WindmillStreamSender, GetWorkBudget> computeDesiredBudgets(
+      ImmutableCollection<WindmillStreamSender> streams, GetWorkBudget totalGetWorkBudget) {
+    GetWorkBudget activeWorkBudget = activeWorkBudgetSupplier.get();
+    LOG.info("Current active work budget: {}", activeWorkBudget);
+    GetWorkBudget budgetPerStream =
+        GetWorkBudget.builder()
+            .setItems(
+                divide(
+                    totalGetWorkBudget.items() - activeWorkBudget.items(),
+                    streams.size(),
+                    RoundingMode.CEILING))

Review Comment:
   rounding up here will drift upwards over the lifetime of the stream if the budget is otherwise a closed loop.
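   
   As a toy illustration of the concern (hypothetical numbers, not taken from the PR), using the same vendored Guava LongMath.divide statically imported in the diff above:
   
   long totalItems = 10;
   int streamCount = 3;
   long perStream = divide(totalItems, streamCount, RoundingMode.CEILING); // 4
   long distributed = perStream * streamCount; // 12, i.e. 2 items more than the 10-item target
   // If what was handed out feeds back into the active-work measurement subtracted from the total,
   // each refresh can allocate slightly more than intended, so the budget creeps upward over time.
   
   One hedged alternative is FLOOR rounding with the remainder assigned to a single stream, which never hands out more than the total.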



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamSender.java:
##########
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Supplier;
+import javax.annotation.concurrent.ThreadSafe;
+import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.StreamingEngineThrottleTimers;
+import org.apache.beam.runners.dataflow.worker.windmill.work.ProcessWorkItem;
+import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers;
+
+/**
+ * Owns and maintains a set of streams used to communicate with a specific Windmill worker.
+ * Underlying streams are "cached" in a threadsafe manner so that once {@link Supplier#get} is
+ * called, a stream that is already started is returned.
+ *
+ * <p>Holds references to {@link
+ * Supplier<org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream>} because
+ * initializing the streams automatically starts them, and we want to do so lazily here once the
+ * {@link GetWorkBudget} is set.
+ *
+ * <p>Once started, the underlying streams are "alive" until they are manually closed via {@link
+ * #closeAllStreams()}.
+ *
+ * <p>If closed, it means that the backend endpoint is no longer in the worker set. Once closed,
+ * these instances are not reused.
+ *
+ * @implNote Does not manage streams for fetching {@link
+ *     org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalData} for side inputs.
+ */
+@Internal
+@ThreadSafe
+public class WindmillStreamSender {
+  private final AtomicBoolean started;
+  private final AtomicReference<GetWorkBudget> getWorkBudget;
+  private final Supplier<GetWorkStream> getWorkStream;
+  private final Supplier<GetDataStream> getDataStream;
+  private final Supplier<CommitWorkStream> commitWorkStream;
+  private final StreamingEngineThrottleTimers streamingEngineThrottleTimers;
+
+  private WindmillStreamSender(
+      CloudWindmillServiceV1Alpha1Stub stub,
+      GetWorkRequest getWorkRequest,
+      AtomicReference<GetWorkBudget> getWorkBudget,
+      GrpcWindmillStreamFactory streamingEngineStreamFactory,
+      ProcessWorkItem processWorkItem) {
+    this.started = new AtomicBoolean(false);
+    this.getWorkBudget = getWorkBudget;
+    this.streamingEngineThrottleTimers = StreamingEngineThrottleTimers.create();
+
+    // All streams are memoized/cached since they are expensive to create and some implementations
+    // perform side effects on construction (i.e. sending initial requests to the stream server to
+    // initiate the streaming RPC connection). Stream instances connect/reconnect internally so we
+    // can reuse the same instance through the entire lifecycle of WindmillStreamSender.
+    this.getDataStream =
+        Suppliers.memoize(
+            () ->
+                streamingEngineStreamFactory.createGetDataStream(
+                    stub, streamingEngineThrottleTimers.getDataThrottleTimer()));
+    this.commitWorkStream =
+        Suppliers.memoize(
+            () ->
+                streamingEngineStreamFactory.createCommitWorkStream(
+                    stub, streamingEngineThrottleTimers.commitWorkThrottleTimer()));
+    this.getWorkStream =
+        Suppliers.memoize(
+            () ->
+                streamingEngineStreamFactory.createDirectGetWorkStream(
+                    stub,
+                    withRequestBudget(getWorkRequest, getWorkBudget.get()),
+                    streamingEngineThrottleTimers.getWorkThrottleTimer(),
+                    getDataStream,
+                    commitWorkStream,
+                    processWorkItem));
+  }
+
+  public static WindmillStreamSender create(
+      CloudWindmillServiceV1Alpha1Stub stub,
+      GetWorkRequest getWorkRequest,
+      GetWorkBudget getWorkBudget,
+      GrpcWindmillStreamFactory streamingEngineStreamFactory,
+      ProcessWorkItem processWorkItem) {
+    return new WindmillStreamSender(
+        stub,
+        getWorkRequest,
+        new AtomicReference<>(getWorkBudget),
+        streamingEngineStreamFactory,
+        processWorkItem);
+  }
+
+  private static GetWorkRequest withRequestBudget(GetWorkRequest request, GetWorkBudget budget) {
+    return request.toBuilder().setMaxItems(budget.items()).setMaxBytes(budget.bytes()).build();
+  }
+
+  @SuppressWarnings("ReturnValueIgnored")
+  public void startStreams() {
+    Preconditions.checkState(
+        !getWorkBudget.get().equals(GetWorkBudget.noBudget()), "Cannot GetWork with no budget.");
+    getWorkStream.get();
+    getDataStream.get();
+    commitWorkStream.get();
+    // *stream.get() is all memoized in a threadsafe manner.
+    started.compareAndSet(false, true);
+  }
+
+  public void closeAllStreams() {
+    // Supplier<Stream>.get() starts the stream which is an expensive operation as it initiates the
+    // streaming RPCs by possibly making calls over the network. Do not close the streams unless
+    // they have already been started.
+    if (started.get()) {
+      getWorkStream.get().close();
+      getDataStream.get().close();
+      commitWorkStream.get().close();
+    }
+  }
+
+  public synchronized void adjustBudget(long itemsDelta, long bytesDelta) {
+    getWorkBudget.set(getWorkBudget.get().apply(itemsDelta, bytesDelta));
+    getWorkStream.get().adjustBudget(itemsDelta, bytesDelta);
+  }
+
+  public synchronized void adjustBudget(GetWorkBudget adjustment) {

Review Comment:
   remove synchronized and let delegate method handle it
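   
   For concreteness, a sketch of the suggested change, assuming (as the truncated diff suggests) that this overload simply forwards to the synchronized adjustBudget(long, long) above:
   
   public void adjustBudget(GetWorkBudget adjustment) {
     // No synchronization needed here; the (long, long) overload this delegates to is synchronized.
     adjustBudget(adjustment.items(), adjustment.bytes());
   }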



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java:
##########
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.work.budget;
+
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap;
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.LongMath.divide;
+
+import java.math.RoundingMode;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.function.Function;
+import java.util.function.Supplier;
+import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.WindmillStreamSender;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableCollection;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** Evenly distributes the provided budget across the available {@link WindmillStreamSender}(s). */
+@Internal
+public final class EvenGetWorkBudgetDistributor implements GetWorkBudgetDistributor {
+  private static final Logger LOG = LoggerFactory.getLogger(EvenGetWorkBudgetDistributor.class);
+  private final Supplier<GetWorkBudget> activeWorkBudgetSupplier;
+
+  public EvenGetWorkBudgetDistributor(Supplier<GetWorkBudget> activeWorkBudgetSupplier) {
+    this.activeWorkBudgetSupplier = activeWorkBudgetSupplier;
+  }
+
+  private static boolean isBelowFiftyPercentOfTarget(
+      GetWorkBudget remaining, GetWorkBudget target) {
+    return remaining.items() < (target.items() * 0.5) || remaining.bytes() < (target.bytes() * 0.5);
+  }
+
+  @Override
+  public void distributeBudget(
+      ImmutableCollection<WindmillStreamSender> streams, GetWorkBudget getWorkBudget) {
+    if (streams.isEmpty()) {
+      LOG.debug("Cannot distribute budget to no streams.");
+      return;
+    }
+
+    if (getWorkBudget.equals(GetWorkBudget.noBudget())) {
+      LOG.debug("Cannot distribute 0 budget.");
+      return;
+    }
+
+    Map<WindmillStreamSender, GetWorkBudget> desiredBudgets =
+        computeDesiredBudgets(streams, getWorkBudget);
+
+    for (Entry<WindmillStreamSender, GetWorkBudget> streamAndDesiredBudget :
+        desiredBudgets.entrySet()) {
+      WindmillStreamSender stream = streamAndDesiredBudget.getKey();
+      GetWorkBudget desired = streamAndDesiredBudget.getValue();
+      GetWorkBudget remaining = stream.remainingGetWorkBudget();
+      if (isBelowFiftyPercentOfTarget(remaining, desired)) {
+        GetWorkBudget adjustment = desired.subtract(remaining);
+        LOG.info("Adjusting budget for stream={} by {}", stream, adjustment);
+        stream.adjustBudget(adjustment);
+      }
+    }
+  }
+
+  private ImmutableMap<WindmillStreamSender, GetWorkBudget> computeDesiredBudgets(
+      ImmutableCollection<WindmillStreamSender> streams, GetWorkBudget totalGetWorkBudget) {
+    GetWorkBudget activeWorkBudget = activeWorkBudgetSupplier.get();
+    LOG.info("Current active work budget: {}", activeWorkBudget);
+    GetWorkBudget budgetPerStream =
+        GetWorkBudget.builder()
+            .setItems(
+                divide(
+                    totalGetWorkBudget.items() - activeWorkBudget.items(),
+                    streams.size(),
+                    RoundingMode.CEILING))
+            .setBytes(
+                divide(
+                    totalGetWorkBudget.bytes() - activeWorkBudget.bytes(),
+                    streams.size(),
+                    RoundingMode.CEILING))
+            .build();
+    LOG.info("Desired budgets per stream: {}; stream count: {}", 
budgetPerStream, streams.size());

Review Comment:
   ditto



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

