This is an automated email from the ASF dual-hosted git repository.

tvalentyn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 9c548b34c7d pull out StreamPool/StreamData from WindmillServerStub file. Organize streaming appliance files into their own directory. (#27593)
9c548b34c7d is described below

commit 9c548b34c7d8190b26840146b19e49e45505a888
Author: martin trieu <[email protected]>
AuthorDate: Mon Aug 14 17:41:59 2023 -0700

    pull out StreamPool/StreamData from WindmillServerStub file. Organize streaming appliance files into their own directory. (#27593)
---
 .../worker/MetricTrackingWindmillServerStub.java   |  59 +++--
 .../dataflow/worker/StreamingDataflowWorker.java   |  10 +-
 .../options/StreamingDataflowWorkerOptions.java    |   4 +-
 .../worker/windmill/AbstractWindmillStream.java    |   1 -
 .../worker/windmill/WindmillServerBase.java        |  12 +-
 .../worker/windmill/WindmillServerStub.java        | 163 +------------
 .../dataflow/worker/windmill/WindmillStream.java   |  89 ++++++++
 .../worker/windmill/WindmillStreamPool.java        | 181 +++++++++++++++
 .../JniWindmillApplianceServer.java}               |  13 +-
 .../windmill/grpcclient/GrpcCommitWorkStream.java  |   2 +-
 .../windmill/grpcclient/GrpcGetDataStream.java     |   2 +-
 .../windmill/grpcclient/GrpcGetWorkStream.java     |   4 +-
 .../windmill/grpcclient/GrpcWindmillServer.java    |   4 +
 .../dataflow/worker/FakeWindmillServer.java        | 117 +++++-----
 .../worker/windmill/WindmillStreamPoolTest.java    | 251 +++++++++++++++++++++
 .../grpcclient/GrpcWindmillServerTest.java         |   6 +-
 16 files changed, 651 insertions(+), 267 deletions(-)

diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricTrackingWindmillServerStub.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricTrackingWindmillServerStub.java
index b0624006b65..33b55647213 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricTrackingWindmillServerStub.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricTrackingWindmillServerStub.java
@@ -29,7 +29,8 @@ import org.apache.beam.runners.dataflow.worker.util.MemoryMonitor;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
 import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub;
-import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.GetDataStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetDataStream;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillStreamPool;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.SettableFuture;
 import org.checkerframework.checker.nullness.qual.Nullable;
 import org.joda.time.Duration;
@@ -47,6 +48,10 @@ import org.joda.time.Duration;
 })
 public class MetricTrackingWindmillServerStub {
 
+  private static final int MAX_READS_PER_BATCH = 60;
+  private static final int MAX_ACTIVE_READS = 10;
+  private static final int NUM_STREAMS = 1;
+  private static final Duration STREAM_TIMEOUT = Duration.standardSeconds(30);
   private final AtomicInteger activeSideInputs = new AtomicInteger();
   private final AtomicInteger activeStateReads = new AtomicInteger();
   private final AtomicInteger activeHeartbeats = new AtomicInteger();
@@ -54,39 +59,13 @@ public class MetricTrackingWindmillServerStub {
   private final MemoryMonitor gcThrashingMonitor;
   private final boolean useStreamingRequests;
 
-  private static final class ReadBatch {
-    ArrayList<QueueEntry> reads = new ArrayList<>();
-    SettableFuture<Boolean> startRead = SettableFuture.create();
-  }
-
   @GuardedBy("this")
   private final List<ReadBatch> pendingReadBatches;
 
   @GuardedBy("this")
   private int activeReadThreads = 0;
 
-  private WindmillServerStub.StreamPool<GetDataStream> streamPool;
-
-  private static final int MAX_READS_PER_BATCH = 60;
-  private static final int MAX_ACTIVE_READS = 10;
-  private static final int NUM_STREAMS = 1;
-  private static final Duration STREAM_TIMEOUT = Duration.standardSeconds(30);
-
-  private static final class QueueEntry {
-
-    final String computation;
-    final Windmill.KeyedGetDataRequest request;
-    final SettableFuture<Windmill.KeyedGetDataResponse> response;
-
-    QueueEntry(
-        String computation,
-        Windmill.KeyedGetDataRequest request,
-        SettableFuture<Windmill.KeyedGetDataResponse> response) {
-      this.computation = computation;
-      this.request = request;
-      this.response = response;
-    }
-  }
+  private WindmillStreamPool<GetDataStream> streamPool;
 
   public MetricTrackingWindmillServerStub(
       WindmillServerStub server, MemoryMonitor gcThrashingMonitor, boolean 
useStreamingRequests) {
@@ -100,8 +79,7 @@ public class MetricTrackingWindmillServerStub {
   public void start() {
     if (useStreamingRequests) {
       streamPool =
-          new WindmillServerStub.StreamPool<>(
-              NUM_STREAMS, STREAM_TIMEOUT, this.server::getDataStream);
+          WindmillStreamPool.create(NUM_STREAMS, STREAM_TIMEOUT, 
this.server::getDataStream);
     }
   }
 
@@ -300,4 +278,25 @@ public class MetricTrackingWindmillServerStub {
     }
     writer.println("Heartbeat Keys Active: " + activeHeartbeats.get());
   }
+
+  private static final class ReadBatch {
+    ArrayList<QueueEntry> reads = new ArrayList<>();
+    SettableFuture<Boolean> startRead = SettableFuture.create();
+  }
+
+  private static final class QueueEntry {
+
+    final String computation;
+    final Windmill.KeyedGetDataRequest request;
+    final SettableFuture<Windmill.KeyedGetDataResponse> response;
+
+    QueueEntry(
+        String computation,
+        Windmill.KeyedGetDataRequest request,
+        SettableFuture<Windmill.KeyedGetDataResponse> response) {
+      this.computation = computation;
+      this.request = request;
+      this.response = response;
+    }
+  }
 }
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
index 1ffb4b94cf0..8629b711697 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
@@ -106,9 +106,9 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribution;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest;
 import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub;
-import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.CommitWorkStream;
-import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.GetWorkStream;
-import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.StreamPool;
+import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.CommitWorkStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetWorkStream;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillStreamPool;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.extensions.gcp.util.Transport;
@@ -1716,8 +1716,8 @@ public class StreamingDataflowWorker {
   }
 
   private void streamingCommitLoop() {
-    StreamPool<CommitWorkStream> streamPool =
-        new StreamPool<>(
+    WindmillStreamPool<CommitWorkStream> streamPool =
+        WindmillStreamPool.create(
             NUM_COMMIT_STREAMS, COMMIT_STREAM_TIMEOUT, 
windmillServer::commitWorkStream);
     Commit initialCommit = null;
     while (running.get()) {
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/options/StreamingDataflowWorkerOptions.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/options/StreamingDataflowWorkerOptions.java
index 446934b6958..cc5b3302b01 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/options/StreamingDataflowWorkerOptions.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/options/StreamingDataflowWorkerOptions.java
@@ -19,8 +19,8 @@ package org.apache.beam.runners.dataflow.worker.options;
 
 import java.io.IOException;
 import org.apache.beam.runners.dataflow.options.DataflowWorkerHarnessOptions;
-import org.apache.beam.runners.dataflow.worker.windmill.WindmillServer;
 import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub;
+import 
org.apache.beam.runners.dataflow.worker.windmill.appliance.JniWindmillApplianceServer;
 import 
org.apache.beam.runners.dataflow.worker.windmill.grpcclient.GrpcWindmillServer;
 import org.apache.beam.sdk.options.Default;
 import org.apache.beam.sdk.options.DefaultValueFactory;
@@ -212,7 +212,7 @@ public interface StreamingDataflowWorkerOptions extends DataflowWorkerHarnessOpt
           throw new RuntimeException("Failed to create GrpcWindmillServer: ", 
e);
         }
       } else {
-        return new WindmillServer(streamingOptions.getLocalWindmillHostport());
+        return new 
JniWindmillApplianceServer(streamingOptions.getLocalWindmillHostport());
       }
     }
   }
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/AbstractWindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/AbstractWindmillStream.java
index 79d446e2d4e..d3e7de58931 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/AbstractWindmillStream.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/AbstractWindmillStream.java
@@ -30,7 +30,6 @@ import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Function;
 import java.util.function.Supplier;
-import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.WindmillStream;
 import org.apache.beam.sdk.util.BackOff;
 import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.Status;
 import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.StatusRuntimeException;
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerBase.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerBase.java
index b7d1507d378..fe81eece138 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerBase.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerBase.java
@@ -19,11 +19,17 @@ package org.apache.beam.runners.dataflow.worker.windmill;
 
 import java.io.IOException;
 import java.util.Set;
+import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.CommitWorkStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetDataStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetWorkStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetWorkStream.WorkItemReceiver;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort;
 
 /**
- * Implementation of a WindmillServerStub which communcates with an actual 
windmill server at the
- * specified location.
+ * Implementation of a WindmillServerStub which communicates with a Windmill 
appliance server.
+ *
+ * @implNote This is only for use in Streaming Appliance. Please do not change 
the name or path of
+ *     this class as this will break JNI.
  */
 @SuppressWarnings({
   "nullness" // TODO(https://github.com/apache/beam/issues/20497)
@@ -33,7 +39,7 @@ public class WindmillServerBase extends WindmillServerStub {
   /** Pointer to the underlying native windmill client object. */
   private final long nativePointer;
 
-  WindmillServerBase(String host) {
+  protected WindmillServerBase(String host) {
     this.nativePointer = create(host);
   }
 
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java
index b77b5a59cdf..1bb5359e06f 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java
@@ -19,26 +19,13 @@ package org.apache.beam.runners.dataflow.worker.windmill;
 
 import java.io.IOException;
 import java.io.PrintWriter;
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
 import java.util.Set;
-import java.util.concurrent.ThreadLocalRandom;
-import java.util.concurrent.TimeUnit;
-import java.util.function.Consumer;
-import java.util.function.Supplier;
-import javax.annotation.concurrent.ThreadSafe;
 import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider;
-import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus;
-import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
-import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataResponse;
-import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribution;
+import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.CommitWorkStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetDataStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetWorkStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetWorkStream.WorkItemReceiver;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort;
-import org.checkerframework.checker.nullness.qual.Nullable;
-import org.joda.time.Duration;
-import org.joda.time.Instant;
 
 /** Stub for communicating with a Windmill server. */
 @SuppressWarnings({
@@ -92,150 +79,10 @@ public abstract class WindmillServerStub implements StatusDataProvider {
   @Override
   public void appendSummaryHtml(PrintWriter writer) {}
 
-  /** Functional interface for receiving WorkItems. */
-  @FunctionalInterface
-  public interface WorkItemReceiver {
-
-    void receiveWork(
-        String computation,
-        @Nullable Instant inputDataWatermark,
-        @Nullable Instant synchronizedProcessingTime,
-        Windmill.WorkItem workItem,
-        Collection<LatencyAttribution> getWorkStreamLatencies);
-  }
-
-  /** Superclass for streams returned by streaming Windmill methods. */
-  @ThreadSafe
-  public interface WindmillStream {
-    /** Indicates that no more requests will be sent. */
-    void close();
-
-    /** Waits for the server to close its end of the connection, with timeout. 
*/
-    boolean awaitTermination(int time, TimeUnit unit) throws 
InterruptedException;
-
-    /** Returns when the stream was opened. */
-    Instant startTime();
-  }
-
-  /** Handle representing a stream of GetWork responses. */
-  @ThreadSafe
-  public interface GetWorkStream extends WindmillStream {}
-
-  /** Interface for streaming GetDataRequests to Windmill. */
-  @ThreadSafe
-  public interface GetDataStream extends WindmillStream {
-    /** Issues a keyed GetData fetch, blocking until the result is ready. */
-    KeyedGetDataResponse requestKeyedData(String computation, 
Windmill.KeyedGetDataRequest request);
-
-    /** Issues a global GetData fetch, blocking until the result is ready. */
-    Windmill.GlobalData requestGlobalData(Windmill.GlobalDataRequest request);
-
-    /** Tells windmill processing is ongoing for the given keys. */
-    void refreshActiveWork(Map<String, List<KeyedGetDataRequest>> active);
-  }
-
-  /** Interface for streaming CommitWorkRequests to Windmill. */
-  @ThreadSafe
-  public interface CommitWorkStream extends WindmillStream {
-
-    /**
-     * Commits a work item and running onDone when the commit has been 
processed by the server.
-     * Returns true if the request was accepted. If false is returned the 
stream should be flushed
-     * and the request recommitted.
-     *
-     * <p>onDone will be called with the status of the commit.
-     */
-    boolean commitWorkItem(
-        String computation, Windmill.WorkItemCommitRequest request, 
Consumer<CommitStatus> onDone);
-
-    /** Flushes any pending work items to the wire. */
-    void flush();
-  }
-
-  /**
-   * Pool of homogeneous streams to Windmill.
-   *
-   * <p>The pool holds a fixed total number of streams, and keeps each stream 
open for a specified
-   * time to allow for better load-balancing.
-   */
-  @ThreadSafe
-  public static class StreamPool<S extends WindmillStream> {
-
-    private final Duration streamTimeout;
-    private final List<StreamData> streams;
-
-    private final Supplier<S> supplier;
-    private final HashMap<S, StreamData> holds;
-
-    public StreamPool(int numStreams, Duration streamTimeout, Supplier<S> 
supplier) {
-      this.streams = new ArrayList<>(numStreams);
-      for (int i = 0; i < numStreams; i++) {
-        streams.add(null);
-      }
-      this.streamTimeout = streamTimeout;
-      this.supplier = supplier;
-      this.holds = new HashMap<>();
-    }
-
-    // Returns a stream for use that may be cached from a previous call.  Each 
call of getStream
-    // must be matched with a call of releaseStream.
-    public S getStream() {
-      int index = ThreadLocalRandom.current().nextInt(streams.size());
-      S result;
-      S closeStream = null;
-      synchronized (this) {
-        StreamData streamData = streams.get(index);
-        if (streamData == null
-            || 
streamData.stream.startTime().isBefore(Instant.now().minus(streamTimeout))) {
-          if (streamData != null && --streamData.holds == 0) {
-            holds.remove(streamData.stream);
-            closeStream = streamData.stream;
-          }
-          streamData = new StreamData();
-          streams.set(index, streamData);
-          holds.put(streamData.stream, streamData);
-        }
-        streamData.holds++;
-        result = streamData.stream;
-      }
-      if (closeStream != null) {
-        closeStream.close();
-      }
-      return result;
-    }
-
-    // Releases a stream that was obtained with getStream.
-    public void releaseStream(S stream) {
-      boolean closeStream = false;
-      synchronized (this) {
-        if (--holds.get(stream).holds == 0) {
-          closeStream = true;
-          holds.remove(stream);
-        }
-      }
-      if (closeStream) {
-        stream.close();
-      }
-    }
-
-    private final class StreamData {
-      final S stream = supplier.get();
-      int holds = 1;
-    }
-  }
-
   /** Generic Exception type for implementors to use to represent errors while 
making RPCs. */
-  public static class RpcException extends RuntimeException {
-    public RpcException() {
-      super();
-    }
-
+  public static final class RpcException extends RuntimeException {
     public RpcException(Throwable cause) {
       super(cause);
     }
-
-    public RpcException(String message, Throwable cause) {
-      super(message, cause);
-    }
   }
 }
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillStream.java
new file mode 100644
index 00000000000..70c7cc36ba3
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillStream.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Consumer;
+import javax.annotation.concurrent.ThreadSafe;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.joda.time.Instant;
+
+/** Superclass for streams returned by streaming Windmill methods. */
+@ThreadSafe
+public interface WindmillStream {
+  /** Indicates that no more requests will be sent. */
+  void close();
+
+  /** Waits for the server to close its end of the connection, with timeout. */
+  boolean awaitTermination(int time, TimeUnit unit) throws 
InterruptedException;
+
+  /** Returns when the stream was opened. */
+  Instant startTime();
+
+  /** Handle representing a stream of GetWork responses. */
+  @ThreadSafe
+  interface GetWorkStream extends WindmillStream {
+    /** Functional interface for receiving WorkItems. */
+    @FunctionalInterface
+    interface WorkItemReceiver {
+      void receiveWork(
+          String computation,
+          @Nullable Instant inputDataWatermark,
+          @Nullable Instant synchronizedProcessingTime,
+          Windmill.WorkItem workItem,
+          Collection<Windmill.LatencyAttribution> getWorkStreamLatencies);
+    }
+  }
+
+  /** Interface for streaming GetDataRequests to Windmill. */
+  @ThreadSafe
+  interface GetDataStream extends WindmillStream {
+    /** Issues a keyed GetData fetch, blocking until the result is ready. */
+    Windmill.KeyedGetDataResponse requestKeyedData(
+        String computation, Windmill.KeyedGetDataRequest request);
+
+    /** Issues a global GetData fetch, blocking until the result is ready. */
+    Windmill.GlobalData requestGlobalData(Windmill.GlobalDataRequest request);
+
+    /** Tells windmill processing is ongoing for the given keys. */
+    void refreshActiveWork(Map<String, List<Windmill.KeyedGetDataRequest>> 
active);
+  }
+
+  /** Interface for streaming CommitWorkRequests to Windmill. */
+  @ThreadSafe
+  interface CommitWorkStream extends WindmillStream {
+
+    /**
+     * Commits a work item and running onDone when the commit has been 
processed by the server.
+     * Returns true if the request was accepted. If false is returned the 
stream should be flushed
+     * and the request recommitted.
+     *
+     * <p>onDone will be called with the status of the commit.
+     */
+    boolean commitWorkItem(
+        String computation,
+        Windmill.WorkItemCommitRequest request,
+        Consumer<Windmill.CommitStatus> onDone);
+
+    /** Flushes any pending work items to the wire. */
+    void flush();
+  }
+}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillStreamPool.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillStreamPool.java
new file mode 100644
index 00000000000..9cd4ab0ea4a
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillStreamPool.java
@@ -0,0 +1,181 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ThreadLocalRandom;
+import java.util.function.Supplier;
+import javax.annotation.concurrent.GuardedBy;
+import javax.annotation.concurrent.ThreadSafe;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+
+/**
+ * Pool of homogeneous streams to Windmill.
+ *
+ * <p>The pool holds a fixed total number of streams, and keeps each stream 
open for a specified
+ * time to allow for better load-balancing.
+ */
+@ThreadSafe
+public class WindmillStreamPool<StreamT extends WindmillStream> {
+
+  private final Duration streamTimeout;
+  private final Supplier<StreamT> streamSupplier;
+
+  /** @implNote Size of streams never changes once initialized. */
+  private final List<@Nullable StreamData<StreamT>> streams;
+
+  @GuardedBy("this")
+  private final Map<StreamT, StreamData<StreamT>> holds;
+
+  private WindmillStreamPool(
+      Duration streamTimeout,
+      Supplier<StreamT> streamSupplier,
+      List<@Nullable StreamData<StreamT>> streams,
+      Map<StreamT, StreamData<StreamT>> holds) {
+    this.streams = streams;
+    this.streamTimeout = streamTimeout;
+    this.streamSupplier = streamSupplier;
+    this.holds = holds;
+  }
+
+  public static <StreamT extends WindmillStream> WindmillStreamPool<StreamT> 
create(
+      int numStreams, Duration streamTimeout, Supplier<StreamT> 
streamSupplier) {
+    return new WindmillStreamPool<>(
+        streamTimeout, streamSupplier, newStreamList(numStreams), new 
HashMap<>());
+  }
+
+  @VisibleForTesting
+  static <StreamT extends WindmillStream> WindmillStreamPool<StreamT> 
forTesting(
+      Duration streamTimeout,
+      Supplier<StreamT> streamSupplier,
+      List<@Nullable StreamData<StreamT>> streamPool,
+      Map<StreamT, StreamData<StreamT>> holds) {
+    return new WindmillStreamPool<>(streamTimeout, streamSupplier, streamPool, 
holds);
+  }
+
+  /**
+   * Creates a new list of streams of the given capacity with all values 
initialized to null. This
+   * is because we randomly load balance across all the streams in the pool.
+   */
+  @VisibleForTesting
+  static <StreamT extends WindmillStream> List<@Nullable StreamData<StreamT>> 
newStreamList(
+      int numStreams) {
+    List<@Nullable StreamData<StreamT>> streamPool = new 
ArrayList<>(numStreams);
+    for (int i = 0; i < numStreams; i++) {
+      streamPool.add(null);
+    }
+    return streamPool;
+  }
+
+  /**
+   * Returns a stream for use that may be cached from a previous call. Each 
call of getStream must
+   * be matched with a call of {@link 
WindmillStreamPool#releaseStream(WindmillStream)}. If the
+   * stream has been cached but has timed out and drained (no longer has any 
holds), the stream will
+   * be closed.
+   */
+  public StreamT getStream() {
+    int index = streams.size() == 1 ? 0 : 
ThreadLocalRandom.current().nextInt(streams.size());
+    // We will return this stream
+    StreamT resultStream;
+    StreamT closeThisStream = null;
+    try {
+      synchronized (this) {
+        WindmillStreamPool.StreamData<StreamT> existingStreamData = 
streams.get(index);
+        // There are 3 possible states that can result from fetching the 
stream from the cache.
+        if (existingStreamData == null) {
+          // 1. Stream doesn't exist create and cache a new one.
+          resultStream = createAndCacheStream(index).stream;
+        } else if (existingStreamData.hasTimedOut(streamTimeout)) {
+          // 2. The stream exists, but has timed out. The timed out stream is 
not returned (a new
+          // one is created and returned here) and evicted from the cache if 
the stream has
+          // completely drained. Every call to getStream(), is matched with a 
call to
+          // releaseStream(), so the stream will eventually drain and be 
closed.
+          if (--existingStreamData.holds == 0) {
+            holds.remove(existingStreamData.stream);
+            closeThisStream = existingStreamData.stream;
+          }
+          // Create and cache a new stream at the timed out stream's index.
+          resultStream = createAndCacheStream(index).stream;
+        } else {
+          // 3. The stream exists and is in a valid state.
+          existingStreamData.holds++;
+          resultStream = existingStreamData.stream;
+        }
+      }
+      return resultStream;
+    } finally {
+      if (closeThisStream != null) {
+        closeThisStream.close();
+      }
+    }
+  }
+
+  private synchronized WindmillStreamPool.StreamData<StreamT> 
createAndCacheStream(int cacheKey) {
+    WindmillStreamPool.StreamData<StreamT> newStreamData =
+        new WindmillStreamPool.StreamData<>(streamSupplier.get());
+    newStreamData.holds++;
+    streams.set(cacheKey, newStreamData);
+    holds.put(newStreamData.stream, newStreamData);
+    return newStreamData;
+  }
+
+  /** Releases a stream that was obtained with {@link 
WindmillStreamPool#getStream()}. */
+  public void releaseStream(StreamT stream) {
+    boolean closeStream = false;
+    synchronized (this) {
+      StreamData<StreamT> streamData = holds.get(stream);
+      // All streams that are created by an instance of a pool will be present.
+      if (streamData == null) {
+        throw new IllegalStateException(
+            "Attempted to release stream that does not exist in this pool. 
This stream "
+                + "may not have been created by this pool or may have been 
released more "
+                + "times than acquired.");
+      }
+      if (--streamData.holds == 0) {
+        closeStream = true;
+        holds.remove(stream);
+      }
+    }
+
+    if (closeStream) {
+      stream.close();
+    }
+  }
+
+  @VisibleForTesting
+  static final class StreamData<StreamT extends WindmillStream> {
+    final StreamT stream;
+    int holds;
+
+    @VisibleForTesting
+    StreamData(StreamT stream) {
+      this.stream = stream;
+      holds = 1;
+    }
+
+    private boolean hasTimedOut(Duration timeout) {
+      return stream.startTime().isBefore(Instant.now().minus(timeout));
+    }
+  }
+}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServer.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/appliance/JniWindmillApplianceServer.java
similarity index 83%
rename from runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServer.java
rename to runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/appliance/JniWindmillApplianceServer.java
index fbc67987811..8385c7cb597 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServer.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/appliance/JniWindmillApplianceServer.java
@@ -15,19 +15,24 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.beam.runners.dataflow.worker.windmill;
+package org.apache.beam.runners.dataflow.worker.windmill.appliance;
 
 import java.io.File;
 import java.io.IOException;
 import java.io.InputStream;
 import java.nio.file.Files;
 import java.nio.file.StandardCopyOption;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerBase;
 
-/** Implementation of a WindmillServerBase. */
+/**
+ * JNI Implementation of a {@link WindmillServerBase}.
+ *
+ * @implNote This is only for use in Streaming Appliance.
+ */
 @SuppressWarnings({
   "nullness" // TODO(https://github.com/apache/beam/issues/20497)
 })
-public class WindmillServer extends WindmillServerBase {
+public class JniWindmillApplianceServer extends WindmillServerBase {
   private static final String WINDMILL_SERVER_JNI_LIBRARY_PROPERTY = 
"windmill.jni_library";
   private static final String DEFAULT_SHUFFLE_CLIENT_LIBRARY = 
"libwindmill_service_jni.so";
 
@@ -50,7 +55,7 @@ public class WindmillServer extends WindmillServerBase {
    * The host should be specified as protocol://address:port to connect to a 
windmill server through
    * rpcz.
    */
-  public WindmillServer(String host) {
+  public JniWindmillApplianceServer(String host) {
     super(host);
   }
 }
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcCommitWorkStream.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcCommitWorkStream.java
index 0e444844f0f..74bd93a5474 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcCommitWorkStream.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcCommitWorkStream.java
@@ -34,7 +34,7 @@ import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingCommit
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingCommitResponse;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingCommitWorkRequest;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest;
-import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.CommitWorkStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.CommitWorkStream;
 import org.apache.beam.sdk.util.BackOff;
 import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetDataStream.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetDataStream.java
index 3f4f6b0a922..b51daabb1a2 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetDataStream.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetDataStream.java
@@ -41,7 +41,7 @@ import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataReq
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataResponse;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetDataRequest;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetDataResponse;
-import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.GetDataStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetDataStream;
 import 
org.apache.beam.runners.dataflow.worker.windmill.grpcclient.GrpcGetDataStreamRequests.QueuedBatch;
 import 
org.apache.beam.runners.dataflow.worker.windmill.grpcclient.GrpcGetDataStreamRequests.QueuedRequest;
 import org.apache.beam.sdk.util.BackOff;
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetWorkStream.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetWorkStream.java
index d0edcc45828..6e35beccdb6 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetWorkStream.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetWorkStream.java
@@ -36,8 +36,8 @@ import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribut
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetWorkRequest;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetWorkRequestExtension;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetWorkResponseChunk;
-import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.GetWorkStream;
-import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.WorkItemReceiver;
+import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetWorkStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetWorkStream.WorkItemReceiver;
 import org.apache.beam.sdk.util.BackOff;
 import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString;
 import org.joda.time.Instant;
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcWindmillServer.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcWindmillServer.java
index 986a7e22eeb..5003f9948c5 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcWindmillServer.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcWindmillServer.java
@@ -52,6 +52,10 @@ import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ReportStatsRequ
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ReportStatsResponse;
 import org.apache.beam.runners.dataflow.worker.windmill.WindmillApplianceGrpc;
 import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub;
+import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.CommitWorkStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetDataStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetWorkStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetWorkStream.WorkItemReceiver;
 import org.apache.beam.sdk.extensions.gcp.options.GcpOptions;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.util.BackOff;
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
index 83046db3f89..4700217dc8a 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
@@ -53,6 +53,9 @@ import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribut
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribution.State;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest;
 import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub;
+import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.CommitWorkStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetDataStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetWorkStream;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Uninterruptibles;
 import org.joda.time.Duration;
@@ -64,58 +67,6 @@ import org.slf4j.LoggerFactory;
 /** An in-memory Windmill server that offers provided work and data. */
 class FakeWindmillServer extends WindmillServerStub {
   private static final Logger LOG = 
LoggerFactory.getLogger(FakeWindmillServer.class);
-
-  static class ResponseQueue<T, U> {
-    private final Queue<Function<T, U>> responses = new 
ConcurrentLinkedQueue<>();
-    private Function<T, U> defaultResponse;
-    Duration sleep = Duration.ZERO;
-
-    // (Fluent) interface for response producers, accessible from tests.
-
-    ResponseQueue<T, U> thenAnswer(Function<T, U> mapFun) {
-      responses.add(mapFun);
-      return this;
-    }
-
-    ResponseQueue<T, U> thenReturn(U response) {
-      return thenAnswer((request) -> response);
-    }
-
-    ResponseQueue<T, U> answerByDefault(Function<T, U> mapFun) {
-      defaultResponse = mapFun;
-      return this;
-    }
-
-    ResponseQueue<T, U> returnByDefault(U response) {
-      return answerByDefault((request) -> response);
-    }
-
-    ResponseQueue<T, U> delayEachResponseBy(Duration sleep) {
-      this.sleep = sleep;
-      return this;
-    }
-
-    // Interface for response consumers, accessible from the enclosing class.
-
-    private U getOrDefault(T request) {
-      Function<T, U> mapFun = responses.poll();
-      U response = mapFun == null ? defaultResponse.apply(request) : 
mapFun.apply(request);
-      Uninterruptibles.sleepUninterruptibly(sleep.getMillis(), 
TimeUnit.MILLISECONDS);
-      return response;
-    }
-
-    private U get(T request) {
-      Function<T, U> mapFun = responses.poll();
-      U response = mapFun == null ? null : mapFun.apply(request);
-      Uninterruptibles.sleepUninterruptibly(sleep.getMillis(), 
TimeUnit.MILLISECONDS);
-      return response;
-    }
-
-    private boolean isEmpty() {
-      return responses.isEmpty();
-    }
-  }
-
   private final ResponseQueue<Windmill.GetWorkRequest, 
Windmill.GetWorkResponse> workToOffer;
   private final ResponseQueue<GetDataRequest, GetDataResponse> dataToOffer;
   private final ResponseQueue<Windmill.CommitWorkRequest, CommitWorkResponse> 
commitsToOffer;
@@ -123,13 +74,13 @@ class FakeWindmillServer extends WindmillServerStub {
   private final Map<Long, WorkItemCommitRequest> commitsReceived;
   private final ArrayList<Windmill.ReportStatsRequest> statsReceived;
   private final LinkedBlockingQueue<Windmill.Exception> exceptions;
-  private int commitsRequested = 0;
-  private int numGetDataRequests = 0;
   private final AtomicInteger expectedExceptionCount;
   private final ErrorCollector errorCollector;
+  private final ConcurrentHashMap<Long, Consumer<Windmill.CommitStatus>> 
droppedStreamingCommits;
+  private int commitsRequested = 0;
+  private int numGetDataRequests = 0;
   private boolean isReady = true;
   private boolean dropStreamingCommits = false;
-  private final ConcurrentHashMap<Long, Consumer<Windmill.CommitStatus>> 
droppedStreamingCommits;
 
   public FakeWindmillServer(ErrorCollector errorCollector) {
     workToOffer =
@@ -243,11 +194,12 @@ class FakeWindmillServer extends WindmillServerStub {
 
   @Override
   public long getAndResetThrottleTime() {
-    return (long) 0;
+    return 0;
   }
 
   @Override
-  public GetWorkStream getWorkStream(Windmill.GetWorkRequest request, 
WorkItemReceiver receiver) {
+  public GetWorkStream getWorkStream(
+      Windmill.GetWorkRequest request, GetWorkStream.WorkItemReceiver 
receiver) {
     LOG.debug("getWorkStream: {}", request.toString());
     Instant startTime = Instant.now();
     final CountDownLatch done = new CountDownLatch(1);
@@ -485,4 +437,55 @@ class FakeWindmillServer extends WindmillServerStub {
   public void setIsReady(boolean ready) {
     this.isReady = ready;
   }
+
+  /**
+   * Scripted responses for one kind of fake Windmill request.
+   *
+   * <p>Tests queue one-shot answers with {@link #thenAnswer}/{@link #thenReturn} and may install a
+   * fallback with {@link #answerByDefault}/{@link #returnByDefault}. The enclosing fake consumes
+   * one-shot answers in FIFO order and sleeps for the configured delay after computing each one.
+   *
+   * <p>Answers are kept in a concurrent queue, so producers and consumers may be on different
+   * threads; the mutable configuration fields are {@code volatile} so updates are visible to them.
+   */
+  static class ResponseQueue<T, U> {
+    private final Queue<Function<T, U>> responses = new ConcurrentLinkedQueue<>();
+    // Delay applied after computing every response.
+    volatile Duration sleep = Duration.ZERO;
+    // Fallback used by getOrDefault() when no one-shot answer is queued; null until configured.
+    private volatile Function<T, U> defaultResponse;
+
+    // (Fluent) interface for response producers, accessible from tests.
+
+    /** Queues a one-shot answer computed from the incoming request. */
+    ResponseQueue<T, U> thenAnswer(Function<T, U> mapFun) {
+      responses.add(mapFun);
+      return this;
+    }
+
+    /** Queues a one-shot answer that ignores the request and returns {@code response}. */
+    ResponseQueue<T, U> thenReturn(U response) {
+      return thenAnswer((request) -> response);
+    }
+
+    /** Sets the answer used by {@code getOrDefault} whenever no one-shot answer is queued. */
+    ResponseQueue<T, U> answerByDefault(Function<T, U> mapFun) {
+      defaultResponse = mapFun;
+      return this;
+    }
+
+    /** Sets a constant default response. */
+    ResponseQueue<T, U> returnByDefault(U response) {
+      return answerByDefault((request) -> response);
+    }
+
+    /** Makes every subsequent response wait {@code sleep} after it has been computed. */
+    ResponseQueue<T, U> delayEachResponseBy(Duration sleep) {
+      this.sleep = sleep;
+      return this;
+    }
+
+    // Interface for response consumers, accessible from the enclosing class.
+
+    /**
+     * Returns the next queued answer for {@code request}, or the default answer if none is queued.
+     *
+     * @throws IllegalStateException if nothing is queued and no default answer was configured
+     */
+    private U getOrDefault(T request) {
+      Function<T, U> mapFun = responses.poll();
+      if (mapFun == null) {
+        mapFun = defaultResponse;
+        if (mapFun == null) {
+          // Fail with a clear message instead of an opaque NullPointerException.
+          throw new IllegalStateException(
+              "No queued response and no default response configured in ResponseQueue.");
+        }
+      }
+      U response = mapFun.apply(request);
+      Uninterruptibles.sleepUninterruptibly(sleep.getMillis(), TimeUnit.MILLISECONDS);
+      return response;
+    }
+
+    /**
+     * Returns the next queued answer for {@code request}, or {@code null} if none is queued. The
+     * default answer is not consulted. The configured delay is applied in both cases.
+     */
+    private U get(T request) {
+      Function<T, U> mapFun = responses.poll();
+      U response = mapFun == null ? null : mapFun.apply(request);
+      Uninterruptibles.sleepUninterruptibly(sleep.getMillis(), TimeUnit.MILLISECONDS);
+      return response;
+    }
+
+    /** Returns whether all queued one-shot answers have been consumed. */
+    private boolean isEmpty() {
+      return responses.isEmpty();
+    }
+  }
 }
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillStreamPoolTest.java
 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillStreamPoolTest.java
new file mode 100644
index 00000000000..9924bb7d2b2
--- /dev/null
+++ 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillStreamPoolTest.java
@@ -0,0 +1,251 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotSame;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+
+import java.util.List;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.TimeUnit;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for {@link WindmillStreamPool}'s stream caching, expiry and release behavior. */
+@RunWith(JUnit4.class)
+public class WindmillStreamPoolTest {
+  private static final int DEFAULT_NUM_STREAMS = 10;
+  // Expected hold count of a stream that getStream() has just created and cached.
+  private static final int NEW_STREAM_HOLDS = 2;
+  private final ConcurrentHashMap<
+          TestWindmillStream, WindmillStreamPool.StreamData<TestWindmillStream>>
+      holds = new ConcurrentHashMap<>();
+  private List<WindmillStreamPool.@Nullable StreamData<TestWindmillStream>> streams;
+
+  @Before
+  public void setUp() {
+    streams = WindmillStreamPool.newStreamList(DEFAULT_NUM_STREAMS);
+    holds.clear();
+  }
+
+  @Test
+  public void testGetStream_returnsAndCachesNewStream() {
+    Duration streamTimeout = Duration.standardSeconds(1);
+    WindmillStreamPool<TestWindmillStream> streamPool =
+        WindmillStreamPool.forTesting(
+            streamTimeout, () -> new TestWindmillStream(Instant.now()), streams, holds);
+    TestWindmillStream stream = streamPool.getStream();
+    assertTrue(holds.containsKey(stream));
+    // Use the shared constant rather than a magic literal, consistent with the other tests.
+    assertEquals(NEW_STREAM_HOLDS, holds.get(stream).holds);
+    assertTrue(streams.contains(holds.get(stream)));
+  }
+
+  @Test
+  public void testGetStream_returnsCachedStreamAndIncrementsHolds() {
+    Duration streamTimeout = Duration.standardDays(1);
+    int cachedStreamHolds = 2;
+    // Populate the stream data.
+    for (int i = 0; i < DEFAULT_NUM_STREAMS; i++) {
+      WindmillStreamPool.StreamData<TestWindmillStream> streamData =
+          new WindmillStreamPool.StreamData<>(new TestWindmillStream(Instant.now()));
+      streamData.holds = cachedStreamHolds;
+      streams.set(i, streamData);
+      holds.put(streamData.stream, streamData);
+    }
+
+    WindmillStreamPool<TestWindmillStream> streamPool =
+        WindmillStreamPool.forTesting(
+            streamTimeout, () -> new TestWindmillStream(Instant.now()), streams, holds);
+    TestWindmillStream stream = streamPool.getStream();
+    assertEquals(cachedStreamHolds + 1, holds.get(stream).holds);
+  }
+
+  @Test
+  public void testGetStream_returnsAndCachesNewStream_whenOldStreamTimedOutAndDrained() {
+    Duration streamTimeout = Duration.ZERO;
+    Instant expired = Instant.EPOCH;
+    // Populate the stream data.
+    for (int i = 0; i < DEFAULT_NUM_STREAMS; i++) {
+      WindmillStreamPool.StreamData<TestWindmillStream> streamData =
+          new WindmillStreamPool.StreamData<>(new TestWindmillStream(expired));
+      streams.set(i, streamData);
+      holds.put(streamData.stream, streamData);
+    }
+
+    WindmillStreamPool<TestWindmillStream> streamPool =
+        WindmillStreamPool.forTesting(
+            streamTimeout, () -> new TestWindmillStream(Instant.now()), streams, holds);
+    TestWindmillStream stream = streamPool.getStream();
+    assertEquals(NEW_STREAM_HOLDS, holds.get(stream).holds);
+  }
+
+  @Test
+  public void testGetStream_closesInvalidStream() {
+    Duration streamTimeout = Duration.ZERO;
+    Instant expired = Instant.EPOCH;
+    WindmillStreamPool.StreamData<TestWindmillStream> streamData =
+        new WindmillStreamPool.StreamData<>(new TestWindmillStream(expired));
+    List<WindmillStreamPool.@Nullable StreamData<TestWindmillStream>> streams =
+        WindmillStreamPool.newStreamList(1);
+    streams.set(0, streamData);
+    holds.put(streamData.stream, streamData);
+
+    WindmillStreamPool<TestWindmillStream> streamPool =
+        WindmillStreamPool.forTesting(
+            streamTimeout, () -> new TestWindmillStream(Instant.now()), streams, holds);
+
+    TestWindmillStream stream = streamPool.getStream();
+    assertEquals(NEW_STREAM_HOLDS, holds.get(stream).holds);
+    assertTrue(streamData.stream.closed);
+    assertEquals(1, streams.size());
+    assertEquals(1, holds.size());
+  }
+
+  @Test
+  public void testGetStream_returnsNewAndCachesNewStream_whenOldStreamTimedOutAndNotDrained() {
+    int notDrained = 4;
+    Duration streamTimeout = Duration.ZERO;
+    Instant expired = Instant.EPOCH;
+
+    // Populate the stream data.
+    List<WindmillStreamPool.@Nullable StreamData<TestWindmillStream>> streams =
+        WindmillStreamPool.newStreamList(1);
+    WindmillStreamPool.StreamData<TestWindmillStream> expiredStreamData =
+        new WindmillStreamPool.StreamData<>(new TestWindmillStream(expired));
+    expiredStreamData.holds = notDrained;
+    streams.set(0, expiredStreamData);
+    holds.put(expiredStreamData.stream, expiredStreamData);
+
+    WindmillStreamPool<TestWindmillStream> streamPool =
+        WindmillStreamPool.forTesting(
+            streamTimeout, () -> new TestWindmillStream(Instant.now()), streams, holds);
+    TestWindmillStream newStream = streamPool.getStream();
+
+    assertEquals(NEW_STREAM_HOLDS, holds.get(newStream).holds);
+    assertEquals(2, holds.size());
+    assertEquals(1, streams.size());
+  }
+
+  @Test
+  public void testGetStream_doesNotCloseExpiredStream_whenNotDrained() {
+    int notDrained = 4;
+    Duration streamTimeout = Duration.ZERO;
+    Instant expired = Instant.EPOCH;
+
+    // Populate the stream data.
+    List<WindmillStreamPool.@Nullable StreamData<TestWindmillStream>> streams =
+        WindmillStreamPool.newStreamList(1);
+    WindmillStreamPool.StreamData<TestWindmillStream> expiredStreamData =
+        new WindmillStreamPool.StreamData<>(new TestWindmillStream(expired));
+    expiredStreamData.holds = notDrained;
+    streams.set(0, expiredStreamData);
+    holds.put(expiredStreamData.stream, expiredStreamData);
+
+    WindmillStreamPool<TestWindmillStream> streamPool =
+        WindmillStreamPool.forTesting(
+            streamTimeout, () -> new TestWindmillStream(Instant.now()), streams, holds);
+    TestWindmillStream newStream = streamPool.getStream();
+
+    assertNotSame(expiredStreamData.stream, newStream);
+    assertFalse(expiredStreamData.stream.closed);
+    assertEquals(notDrained - 1, expiredStreamData.holds);
+    assertEquals(2, holds.size());
+    assertEquals(1, streams.size());
+    assertTrue(holds.containsKey(expiredStreamData.stream));
+    assertFalse(streams.contains(expiredStreamData));
+  }
+
+  @Test
+  public void testReleaseStream_closesStream() {
+    Duration streamTimeout = Duration.standardDays(1);
+    WindmillStreamPool.StreamData<TestWindmillStream> streamData =
+        new WindmillStreamPool.StreamData<>(new TestWindmillStream(Instant.now()));
+    List<WindmillStreamPool.@Nullable StreamData<TestWindmillStream>> streams =
+        WindmillStreamPool.newStreamList(1);
+    streams.set(0, streamData);
+    holds.put(streamData.stream, streamData);
+
+    WindmillStreamPool<TestWindmillStream> streamPool =
+        WindmillStreamPool.forTesting(
+            streamTimeout, () -> new TestWindmillStream(Instant.now()), streams, holds);
+    TestWindmillStream stream = streamPool.getStream();
+    holds.get(stream).holds = 1;
+    streamPool.releaseStream(stream);
+    assertFalse(holds.containsKey(stream));
+    assertTrue(stream.closed);
+  }
+
+  @Test
+  public void testReleaseStream_doesNotCloseStream_ifStreamHasHolds() {
+    Duration streamTimeout = Duration.standardDays(1);
+    WindmillStreamPool.StreamData<TestWindmillStream> streamData =
+        new WindmillStreamPool.StreamData<>(new TestWindmillStream(Instant.now()));
+    List<WindmillStreamPool.@Nullable StreamData<TestWindmillStream>> streams =
+        WindmillStreamPool.newStreamList(1);
+    streams.set(0, streamData);
+    holds.put(streamData.stream, streamData);
+
+    WindmillStreamPool<TestWindmillStream> streamPool =
+        WindmillStreamPool.forTesting(
+            streamTimeout, () -> new TestWindmillStream(Instant.now()), streams, holds);
+    TestWindmillStream stream = streamPool.getStream();
+    streamPool.releaseStream(stream);
+    assertTrue(holds.containsKey(stream));
+    assertFalse(stream.closed);
+  }
+
+  @Test
+  public void testReleaseStream_throwsExceptionWhenAttemptingToReleaseUnheldStream() {
+    WindmillStreamPool<TestWindmillStream> streamPool =
+        WindmillStreamPool.forTesting(
+            Duration.ZERO, () -> new TestWindmillStream(Instant.now()), streams, holds);
+    TestWindmillStream unheldStream = new TestWindmillStream(Instant.now());
+    assertThrows(IllegalStateException.class, () -> streamPool.releaseStream(unheldStream));
+  }
+
+  /** Minimal {@link WindmillStream} with a fixed start time that records whether it was closed. */
+  private static class TestWindmillStream implements WindmillStream {
+    private final Instant startTime;
+    private boolean closed;
+
+    private TestWindmillStream(Instant startTime) {
+      this.startTime = startTime;
+      this.closed = false;
+    }
+
+    @Override
+    public void close() {
+      closed = true;
+    }
+
+    @Override
+    public boolean awaitTermination(int time, TimeUnit unit) {
+      return false;
+    }
+
+    @Override
+    public Instant startTime() {
+      return startTime;
+    }
+  }
+}
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcWindmillServerTest.java
 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcWindmillServerTest.java
index 77f046e7384..511b4e4895a 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcWindmillServerTest.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcWindmillServerTest.java
@@ -65,9 +65,9 @@ import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.TagValue;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.Value;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest;
-import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.CommitWorkStream;
-import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.GetDataStream;
-import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.GetWorkStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.CommitWorkStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetDataStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetWorkStream;
 import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString;
 import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.Server;
 import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.Status;

Reply via email to