This is an automated email from the ASF dual-hosted git repository.
tvalentyn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 9c548b34c7d pull out StreamPool/StreamData from WindmillServerStub
file. Organize streaming appliance files into their own directory. (#27593)
9c548b34c7d is described below
commit 9c548b34c7d8190b26840146b19e49e45505a888
Author: martin trieu <[email protected]>
AuthorDate: Mon Aug 14 17:41:59 2023 -0700
pull out StreamPool/StreamData from WindmillServerStub file. Organize
streaming appliance files into their own directory. (#27593)
---
.../worker/MetricTrackingWindmillServerStub.java | 59 +++--
.../dataflow/worker/StreamingDataflowWorker.java | 10 +-
.../options/StreamingDataflowWorkerOptions.java | 4 +-
.../worker/windmill/AbstractWindmillStream.java | 1 -
.../worker/windmill/WindmillServerBase.java | 12 +-
.../worker/windmill/WindmillServerStub.java | 163 +------------
.../dataflow/worker/windmill/WindmillStream.java | 89 ++++++++
.../worker/windmill/WindmillStreamPool.java | 181 +++++++++++++++
.../JniWindmillApplianceServer.java} | 13 +-
.../windmill/grpcclient/GrpcCommitWorkStream.java | 2 +-
.../windmill/grpcclient/GrpcGetDataStream.java | 2 +-
.../windmill/grpcclient/GrpcGetWorkStream.java | 4 +-
.../windmill/grpcclient/GrpcWindmillServer.java | 4 +
.../dataflow/worker/FakeWindmillServer.java | 117 +++++-----
.../worker/windmill/WindmillStreamPoolTest.java | 251 +++++++++++++++++++++
.../grpcclient/GrpcWindmillServerTest.java | 6 +-
16 files changed, 651 insertions(+), 267 deletions(-)
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricTrackingWindmillServerStub.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricTrackingWindmillServerStub.java
index b0624006b65..33b55647213 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricTrackingWindmillServerStub.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricTrackingWindmillServerStub.java
@@ -29,7 +29,8 @@ import
org.apache.beam.runners.dataflow.worker.util.MemoryMonitor;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub;
-import
org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.GetDataStream;
+import
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetDataStream;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillStreamPool;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.SettableFuture;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.joda.time.Duration;
@@ -47,6 +48,10 @@ import org.joda.time.Duration;
})
public class MetricTrackingWindmillServerStub {
+ private static final int MAX_READS_PER_BATCH = 60;
+ private static final int MAX_ACTIVE_READS = 10;
+ private static final int NUM_STREAMS = 1;
+ private static final Duration STREAM_TIMEOUT = Duration.standardSeconds(30);
private final AtomicInteger activeSideInputs = new AtomicInteger();
private final AtomicInteger activeStateReads = new AtomicInteger();
private final AtomicInteger activeHeartbeats = new AtomicInteger();
@@ -54,39 +59,13 @@ public class MetricTrackingWindmillServerStub {
private final MemoryMonitor gcThrashingMonitor;
private final boolean useStreamingRequests;
- private static final class ReadBatch {
- ArrayList<QueueEntry> reads = new ArrayList<>();
- SettableFuture<Boolean> startRead = SettableFuture.create();
- }
-
@GuardedBy("this")
private final List<ReadBatch> pendingReadBatches;
@GuardedBy("this")
private int activeReadThreads = 0;
- private WindmillServerStub.StreamPool<GetDataStream> streamPool;
-
- private static final int MAX_READS_PER_BATCH = 60;
- private static final int MAX_ACTIVE_READS = 10;
- private static final int NUM_STREAMS = 1;
- private static final Duration STREAM_TIMEOUT = Duration.standardSeconds(30);
-
- private static final class QueueEntry {
-
- final String computation;
- final Windmill.KeyedGetDataRequest request;
- final SettableFuture<Windmill.KeyedGetDataResponse> response;
-
- QueueEntry(
- String computation,
- Windmill.KeyedGetDataRequest request,
- SettableFuture<Windmill.KeyedGetDataResponse> response) {
- this.computation = computation;
- this.request = request;
- this.response = response;
- }
- }
+ private WindmillStreamPool<GetDataStream> streamPool;
public MetricTrackingWindmillServerStub(
WindmillServerStub server, MemoryMonitor gcThrashingMonitor, boolean
useStreamingRequests) {
@@ -100,8 +79,7 @@ public class MetricTrackingWindmillServerStub {
public void start() {
if (useStreamingRequests) {
streamPool =
- new WindmillServerStub.StreamPool<>(
- NUM_STREAMS, STREAM_TIMEOUT, this.server::getDataStream);
+ WindmillStreamPool.create(NUM_STREAMS, STREAM_TIMEOUT,
this.server::getDataStream);
}
}
@@ -300,4 +278,25 @@ public class MetricTrackingWindmillServerStub {
}
writer.println("Heartbeat Keys Active: " + activeHeartbeats.get());
}
+
+ private static final class ReadBatch {
+ ArrayList<QueueEntry> reads = new ArrayList<>();
+ SettableFuture<Boolean> startRead = SettableFuture.create();
+ }
+
+ private static final class QueueEntry {
+
+ final String computation;
+ final Windmill.KeyedGetDataRequest request;
+ final SettableFuture<Windmill.KeyedGetDataResponse> response;
+
+ QueueEntry(
+ String computation,
+ Windmill.KeyedGetDataRequest request,
+ SettableFuture<Windmill.KeyedGetDataResponse> response) {
+ this.computation = computation;
+ this.request = request;
+ this.response = response;
+ }
+ }
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
index 1ffb4b94cf0..8629b711697 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
@@ -106,9 +106,9 @@ import
org.apache.beam.runners.dataflow.worker.windmill.Windmill;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribution;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest;
import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub;
-import
org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.CommitWorkStream;
-import
org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.GetWorkStream;
-import
org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.StreamPool;
+import
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.CommitWorkStream;
+import
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetWorkStream;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillStreamPool;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.extensions.gcp.util.Transport;
@@ -1716,8 +1716,8 @@ public class StreamingDataflowWorker {
}
private void streamingCommitLoop() {
- StreamPool<CommitWorkStream> streamPool =
- new StreamPool<>(
+ WindmillStreamPool<CommitWorkStream> streamPool =
+ WindmillStreamPool.create(
NUM_COMMIT_STREAMS, COMMIT_STREAM_TIMEOUT,
windmillServer::commitWorkStream);
Commit initialCommit = null;
while (running.get()) {
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/options/StreamingDataflowWorkerOptions.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/options/StreamingDataflowWorkerOptions.java
index 446934b6958..cc5b3302b01 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/options/StreamingDataflowWorkerOptions.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/options/StreamingDataflowWorkerOptions.java
@@ -19,8 +19,8 @@ package org.apache.beam.runners.dataflow.worker.options;
import java.io.IOException;
import org.apache.beam.runners.dataflow.options.DataflowWorkerHarnessOptions;
-import org.apache.beam.runners.dataflow.worker.windmill.WindmillServer;
import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub;
+import
org.apache.beam.runners.dataflow.worker.windmill.appliance.JniWindmillApplianceServer;
import
org.apache.beam.runners.dataflow.worker.windmill.grpcclient.GrpcWindmillServer;
import org.apache.beam.sdk.options.Default;
import org.apache.beam.sdk.options.DefaultValueFactory;
@@ -212,7 +212,7 @@ public interface StreamingDataflowWorkerOptions extends
DataflowWorkerHarnessOpt
throw new RuntimeException("Failed to create GrpcWindmillServer: ",
e);
}
} else {
- return new WindmillServer(streamingOptions.getLocalWindmillHostport());
+ return new
JniWindmillApplianceServer(streamingOptions.getLocalWindmillHostport());
}
}
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/AbstractWindmillStream.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/AbstractWindmillStream.java
index 79d446e2d4e..d3e7de58931 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/AbstractWindmillStream.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/AbstractWindmillStream.java
@@ -30,7 +30,6 @@ import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.Supplier;
-import
org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.WindmillStream;
import org.apache.beam.sdk.util.BackOff;
import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.Status;
import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.StatusRuntimeException;
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerBase.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerBase.java
index b7d1507d378..fe81eece138 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerBase.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerBase.java
@@ -19,11 +19,17 @@ package org.apache.beam.runners.dataflow.worker.windmill;
import java.io.IOException;
import java.util.Set;
+import
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.CommitWorkStream;
+import
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetDataStream;
+import
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetWorkStream;
+import
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetWorkStream.WorkItemReceiver;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort;
/**
- * Implementation of a WindmillServerStub which communcates with an actual
windmill server at the
- * specified location.
+ * Implementation of a WindmillServerStub which communicates with a Windmill
appliance server.
+ *
+ * @implNote This is only for use in Streaming Appliance. Please do not change
the name or path of
+ * this class as this will break JNI.
*/
@SuppressWarnings({
"nullness" // TODO(https://github.com/apache/beam/issues/20497)
@@ -33,7 +39,7 @@ public class WindmillServerBase extends WindmillServerStub {
/** Pointer to the underlying native windmill client object. */
private final long nativePointer;
- WindmillServerBase(String host) {
+ protected WindmillServerBase(String host) {
this.nativePointer = create(host);
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java
index b77b5a59cdf..1bb5359e06f 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java
@@ -19,26 +19,13 @@ package org.apache.beam.runners.dataflow.worker.windmill;
import java.io.IOException;
import java.io.PrintWriter;
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
import java.util.Set;
-import java.util.concurrent.ThreadLocalRandom;
-import java.util.concurrent.TimeUnit;
-import java.util.function.Consumer;
-import java.util.function.Supplier;
-import javax.annotation.concurrent.ThreadSafe;
import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider;
-import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus;
-import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
-import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataResponse;
-import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribution;
+import
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.CommitWorkStream;
+import
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetDataStream;
+import
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetWorkStream;
+import
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetWorkStream.WorkItemReceiver;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort;
-import org.checkerframework.checker.nullness.qual.Nullable;
-import org.joda.time.Duration;
-import org.joda.time.Instant;
/** Stub for communicating with a Windmill server. */
@SuppressWarnings({
@@ -92,150 +79,10 @@ public abstract class WindmillServerStub implements
StatusDataProvider {
@Override
public void appendSummaryHtml(PrintWriter writer) {}
- /** Functional interface for receiving WorkItems. */
- @FunctionalInterface
- public interface WorkItemReceiver {
-
- void receiveWork(
- String computation,
- @Nullable Instant inputDataWatermark,
- @Nullable Instant synchronizedProcessingTime,
- Windmill.WorkItem workItem,
- Collection<LatencyAttribution> getWorkStreamLatencies);
- }
-
- /** Superclass for streams returned by streaming Windmill methods. */
- @ThreadSafe
- public interface WindmillStream {
- /** Indicates that no more requests will be sent. */
- void close();
-
- /** Waits for the server to close its end of the connection, with timeout.
*/
- boolean awaitTermination(int time, TimeUnit unit) throws
InterruptedException;
-
- /** Returns when the stream was opened. */
- Instant startTime();
- }
-
- /** Handle representing a stream of GetWork responses. */
- @ThreadSafe
- public interface GetWorkStream extends WindmillStream {}
-
- /** Interface for streaming GetDataRequests to Windmill. */
- @ThreadSafe
- public interface GetDataStream extends WindmillStream {
- /** Issues a keyed GetData fetch, blocking until the result is ready. */
- KeyedGetDataResponse requestKeyedData(String computation,
Windmill.KeyedGetDataRequest request);
-
- /** Issues a global GetData fetch, blocking until the result is ready. */
- Windmill.GlobalData requestGlobalData(Windmill.GlobalDataRequest request);
-
- /** Tells windmill processing is ongoing for the given keys. */
- void refreshActiveWork(Map<String, List<KeyedGetDataRequest>> active);
- }
-
- /** Interface for streaming CommitWorkRequests to Windmill. */
- @ThreadSafe
- public interface CommitWorkStream extends WindmillStream {
-
- /**
- * Commits a work item and running onDone when the commit has been
processed by the server.
- * Returns true if the request was accepted. If false is returned the
stream should be flushed
- * and the request recommitted.
- *
- * <p>onDone will be called with the status of the commit.
- */
- boolean commitWorkItem(
- String computation, Windmill.WorkItemCommitRequest request,
Consumer<CommitStatus> onDone);
-
- /** Flushes any pending work items to the wire. */
- void flush();
- }
-
- /**
- * Pool of homogeneous streams to Windmill.
- *
- * <p>The pool holds a fixed total number of streams, and keeps each stream
open for a specified
- * time to allow for better load-balancing.
- */
- @ThreadSafe
- public static class StreamPool<S extends WindmillStream> {
-
- private final Duration streamTimeout;
- private final List<StreamData> streams;
-
- private final Supplier<S> supplier;
- private final HashMap<S, StreamData> holds;
-
- public StreamPool(int numStreams, Duration streamTimeout, Supplier<S>
supplier) {
- this.streams = new ArrayList<>(numStreams);
- for (int i = 0; i < numStreams; i++) {
- streams.add(null);
- }
- this.streamTimeout = streamTimeout;
- this.supplier = supplier;
- this.holds = new HashMap<>();
- }
-
- // Returns a stream for use that may be cached from a previous call. Each
call of getStream
- // must be matched with a call of releaseStream.
- public S getStream() {
- int index = ThreadLocalRandom.current().nextInt(streams.size());
- S result;
- S closeStream = null;
- synchronized (this) {
- StreamData streamData = streams.get(index);
- if (streamData == null
- ||
streamData.stream.startTime().isBefore(Instant.now().minus(streamTimeout))) {
- if (streamData != null && --streamData.holds == 0) {
- holds.remove(streamData.stream);
- closeStream = streamData.stream;
- }
- streamData = new StreamData();
- streams.set(index, streamData);
- holds.put(streamData.stream, streamData);
- }
- streamData.holds++;
- result = streamData.stream;
- }
- if (closeStream != null) {
- closeStream.close();
- }
- return result;
- }
-
- // Releases a stream that was obtained with getStream.
- public void releaseStream(S stream) {
- boolean closeStream = false;
- synchronized (this) {
- if (--holds.get(stream).holds == 0) {
- closeStream = true;
- holds.remove(stream);
- }
- }
- if (closeStream) {
- stream.close();
- }
- }
-
- private final class StreamData {
- final S stream = supplier.get();
- int holds = 1;
- }
- }
-
/** Generic Exception type for implementors to use to represent errors while
making RPCs. */
- public static class RpcException extends RuntimeException {
- public RpcException() {
- super();
- }
-
+ public static final class RpcException extends RuntimeException {
public RpcException(Throwable cause) {
super(cause);
}
-
- public RpcException(String message, Throwable cause) {
- super(message, cause);
- }
}
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillStream.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillStream.java
new file mode 100644
index 00000000000..70c7cc36ba3
--- /dev/null
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillStream.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Consumer;
+import javax.annotation.concurrent.ThreadSafe;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.joda.time.Instant;
+
+/** Superclass for streams returned by streaming Windmill methods. */
+@ThreadSafe
+public interface WindmillStream {
+ /** Indicates that no more requests will be sent. */
+ void close();
+
+ /** Waits for the server to close its end of the connection, with timeout. */
+ boolean awaitTermination(int time, TimeUnit unit) throws
InterruptedException;
+
+ /** Returns when the stream was opened. */
+ Instant startTime();
+
+ /** Handle representing a stream of GetWork responses. */
+ @ThreadSafe
+ interface GetWorkStream extends WindmillStream {
+ /** Functional interface for receiving WorkItems. */
+ @FunctionalInterface
+ interface WorkItemReceiver {
+ void receiveWork(
+ String computation,
+ @Nullable Instant inputDataWatermark,
+ @Nullable Instant synchronizedProcessingTime,
+ Windmill.WorkItem workItem,
+ Collection<Windmill.LatencyAttribution> getWorkStreamLatencies);
+ }
+ }
+
+ /** Interface for streaming GetDataRequests to Windmill. */
+ @ThreadSafe
+ interface GetDataStream extends WindmillStream {
+ /** Issues a keyed GetData fetch, blocking until the result is ready. */
+ Windmill.KeyedGetDataResponse requestKeyedData(
+ String computation, Windmill.KeyedGetDataRequest request);
+
+ /** Issues a global GetData fetch, blocking until the result is ready. */
+ Windmill.GlobalData requestGlobalData(Windmill.GlobalDataRequest request);
+
+ /** Tells windmill processing is ongoing for the given keys. */
+ void refreshActiveWork(Map<String, List<Windmill.KeyedGetDataRequest>>
active);
+ }
+
+ /** Interface for streaming CommitWorkRequests to Windmill. */
+ @ThreadSafe
+ interface CommitWorkStream extends WindmillStream {
+
+ /**
+ * Commits a work item and running onDone when the commit has been
processed by the server.
+ * Returns true if the request was accepted. If false is returned the
stream should be flushed
+ * and the request recommitted.
+ *
+ * <p>onDone will be called with the status of the commit.
+ */
+ boolean commitWorkItem(
+ String computation,
+ Windmill.WorkItemCommitRequest request,
+ Consumer<Windmill.CommitStatus> onDone);
+
+ /** Flushes any pending work items to the wire. */
+ void flush();
+ }
+}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillStreamPool.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillStreamPool.java
new file mode 100644
index 00000000000..9cd4ab0ea4a
--- /dev/null
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillStreamPool.java
@@ -0,0 +1,181 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ThreadLocalRandom;
+import java.util.function.Supplier;
+import javax.annotation.concurrent.GuardedBy;
+import javax.annotation.concurrent.ThreadSafe;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+
+/**
+ * Pool of homogeneous streams to Windmill.
+ *
+ * <p>The pool holds a fixed total number of streams, and keeps each stream
open for a specified
+ * time to allow for better load-balancing.
+ */
+@ThreadSafe
+public class WindmillStreamPool<StreamT extends WindmillStream> {
+
+ private final Duration streamTimeout;
+ private final Supplier<StreamT> streamSupplier;
+
+ /** @implNote Size of streams never changes once initialized. */
+ private final List<@Nullable StreamData<StreamT>> streams;
+
+ @GuardedBy("this")
+ private final Map<StreamT, StreamData<StreamT>> holds;
+
+ private WindmillStreamPool(
+ Duration streamTimeout,
+ Supplier<StreamT> streamSupplier,
+ List<@Nullable StreamData<StreamT>> streams,
+ Map<StreamT, StreamData<StreamT>> holds) {
+ this.streams = streams;
+ this.streamTimeout = streamTimeout;
+ this.streamSupplier = streamSupplier;
+ this.holds = holds;
+ }
+
+ public static <StreamT extends WindmillStream> WindmillStreamPool<StreamT>
create(
+ int numStreams, Duration streamTimeout, Supplier<StreamT>
streamSupplier) {
+ return new WindmillStreamPool<>(
+ streamTimeout, streamSupplier, newStreamList(numStreams), new
HashMap<>());
+ }
+
+ @VisibleForTesting
+ static <StreamT extends WindmillStream> WindmillStreamPool<StreamT>
forTesting(
+ Duration streamTimeout,
+ Supplier<StreamT> streamSupplier,
+ List<@Nullable StreamData<StreamT>> streamPool,
+ Map<StreamT, StreamData<StreamT>> holds) {
+ return new WindmillStreamPool<>(streamTimeout, streamSupplier, streamPool,
holds);
+ }
+
+ /**
+ * Creates a new list of streams of the given capacity with all values
initialized to null. This
+ * is because we randomly load balance across all the streams in the pool.
+ */
+ @VisibleForTesting
+ static <StreamT extends WindmillStream> List<@Nullable StreamData<StreamT>>
newStreamList(
+ int numStreams) {
+ List<@Nullable StreamData<StreamT>> streamPool = new
ArrayList<>(numStreams);
+ for (int i = 0; i < numStreams; i++) {
+ streamPool.add(null);
+ }
+ return streamPool;
+ }
+
+ /**
+ * Returns a stream for use that may be cached from a previous call. Each
call of getStream must
+ * be matched with a call of {@link
WindmillStreamPool#releaseStream(WindmillStream)}. If the
+ * stream has been cached but has timed out and drained (no longer has any
holds), the stream will
+ * be closed.
+ */
+ public StreamT getStream() {
+ int index = streams.size() == 1 ? 0 :
ThreadLocalRandom.current().nextInt(streams.size());
+ // We will return this stream
+ StreamT resultStream;
+ StreamT closeThisStream = null;
+ try {
+ synchronized (this) {
+ WindmillStreamPool.StreamData<StreamT> existingStreamData =
streams.get(index);
+ // There are 3 possible states that can result from fetching the
stream from the cache.
+ if (existingStreamData == null) {
+ // 1. Stream doesn't exist create and cache a new one.
+ resultStream = createAndCacheStream(index).stream;
+ } else if (existingStreamData.hasTimedOut(streamTimeout)) {
+ // 2. The stream exists, but has timed out. The timed out stream is
not returned (a new
+ // one is created and returned here) and evicted from the cache if
the stream has
+ // completely drained. Every call to getStream(), is matched with a
call to
+ // releaseStream(), so the stream will eventually drain and be
closed.
+ if (--existingStreamData.holds == 0) {
+ holds.remove(existingStreamData.stream);
+ closeThisStream = existingStreamData.stream;
+ }
+ // Create and cache a new stream at the timed out stream's index.
+ resultStream = createAndCacheStream(index).stream;
+ } else {
+ // 3. The stream exists and is in a valid state.
+ existingStreamData.holds++;
+ resultStream = existingStreamData.stream;
+ }
+ }
+ return resultStream;
+ } finally {
+ if (closeThisStream != null) {
+ closeThisStream.close();
+ }
+ }
+ }
+
+ private synchronized WindmillStreamPool.StreamData<StreamT>
createAndCacheStream(int cacheKey) {
+ WindmillStreamPool.StreamData<StreamT> newStreamData =
+ new WindmillStreamPool.StreamData<>(streamSupplier.get());
+ newStreamData.holds++;
+ streams.set(cacheKey, newStreamData);
+ holds.put(newStreamData.stream, newStreamData);
+ return newStreamData;
+ }
+
+ /** Releases a stream that was obtained with {@link
WindmillStreamPool#getStream()}. */
+ public void releaseStream(StreamT stream) {
+ boolean closeStream = false;
+ synchronized (this) {
+ StreamData<StreamT> streamData = holds.get(stream);
+ // All streams that are created by an instance of a pool will be present.
+ if (streamData == null) {
+ throw new IllegalStateException(
+ "Attempted to release stream that does not exist in this pool.
This stream "
+ + "may not have been created by this pool or may have been
released more "
+ + "times than acquired.");
+ }
+ if (--streamData.holds == 0) {
+ closeStream = true;
+ holds.remove(stream);
+ }
+ }
+
+ if (closeStream) {
+ stream.close();
+ }
+ }
+
+ @VisibleForTesting
+ static final class StreamData<StreamT extends WindmillStream> {
+ final StreamT stream;
+ int holds;
+
+ @VisibleForTesting
+ StreamData(StreamT stream) {
+ this.stream = stream;
+ holds = 1;
+ }
+
+ private boolean hasTimedOut(Duration timeout) {
+ return stream.startTime().isBefore(Instant.now().minus(timeout));
+ }
+ }
+}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServer.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/appliance/JniWindmillApplianceServer.java
similarity index 83%
rename from
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServer.java
rename to
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/appliance/JniWindmillApplianceServer.java
index fbc67987811..8385c7cb597 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServer.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/appliance/JniWindmillApplianceServer.java
@@ -15,19 +15,24 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.beam.runners.dataflow.worker.windmill;
+package org.apache.beam.runners.dataflow.worker.windmill.appliance;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.StandardCopyOption;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerBase;
-/** Implementation of a WindmillServerBase. */
+/**
+ * JNI Implementation of a {@link WindmillServerBase}.
+ *
+ * @implNote This is only for use in Streaming Appliance.
+ */
@SuppressWarnings({
"nullness" // TODO(https://github.com/apache/beam/issues/20497)
})
-public class WindmillServer extends WindmillServerBase {
+public class JniWindmillApplianceServer extends WindmillServerBase {
private static final String WINDMILL_SERVER_JNI_LIBRARY_PROPERTY =
"windmill.jni_library";
private static final String DEFAULT_SHUFFLE_CLIENT_LIBRARY =
"libwindmill_service_jni.so";
@@ -50,7 +55,7 @@ public class WindmillServer extends WindmillServerBase {
* The host should be specified as protocol://address:port to connect to a
windmill server through
* rpcz.
*/
- public WindmillServer(String host) {
+ public JniWindmillApplianceServer(String host) {
super(host);
}
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcCommitWorkStream.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcCommitWorkStream.java
index 0e444844f0f..74bd93a5474 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcCommitWorkStream.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcCommitWorkStream.java
@@ -34,7 +34,7 @@ import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingCommit
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingCommitResponse;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingCommitWorkRequest;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest;
-import
org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.CommitWorkStream;
+import
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.CommitWorkStream;
import org.apache.beam.sdk.util.BackOff;
import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetDataStream.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetDataStream.java
index 3f4f6b0a922..b51daabb1a2 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetDataStream.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetDataStream.java
@@ -41,7 +41,7 @@ import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataReq
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataResponse;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetDataRequest;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetDataResponse;
-import
org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.GetDataStream;
+import
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetDataStream;
import
org.apache.beam.runners.dataflow.worker.windmill.grpcclient.GrpcGetDataStreamRequests.QueuedBatch;
import
org.apache.beam.runners.dataflow.worker.windmill.grpcclient.GrpcGetDataStreamRequests.QueuedRequest;
import org.apache.beam.sdk.util.BackOff;
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetWorkStream.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetWorkStream.java
index d0edcc45828..6e35beccdb6 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetWorkStream.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetWorkStream.java
@@ -36,8 +36,8 @@ import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribut
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetWorkRequest;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetWorkRequestExtension;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetWorkResponseChunk;
-import
org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.GetWorkStream;
-import
org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.WorkItemReceiver;
+import
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetWorkStream;
+import
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetWorkStream.WorkItemReceiver;
import org.apache.beam.sdk.util.BackOff;
import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString;
import org.joda.time.Instant;
diff --git
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcWindmillServer.java
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcWindmillServer.java
index 986a7e22eeb..5003f9948c5 100644
---
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcWindmillServer.java
+++
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcWindmillServer.java
@@ -52,6 +52,10 @@ import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ReportStatsRequ
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.ReportStatsResponse;
import org.apache.beam.runners.dataflow.worker.windmill.WindmillApplianceGrpc;
import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub;
+import
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.CommitWorkStream;
+import
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetDataStream;
+import
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetWorkStream;
+import
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetWorkStream.WorkItemReceiver;
import org.apache.beam.sdk.extensions.gcp.options.GcpOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.util.BackOff;
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
index 83046db3f89..4700217dc8a 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
@@ -53,6 +53,9 @@ import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribut
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.LatencyAttribution.State;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest;
import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub;
+import
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.CommitWorkStream;
+import
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetDataStream;
+import
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetWorkStream;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Uninterruptibles;
import org.joda.time.Duration;
@@ -64,58 +67,6 @@ import org.slf4j.LoggerFactory;
/** An in-memory Windmill server that offers provided work and data. */
class FakeWindmillServer extends WindmillServerStub {
private static final Logger LOG =
LoggerFactory.getLogger(FakeWindmillServer.class);
-
- static class ResponseQueue<T, U> {
- private final Queue<Function<T, U>> responses = new
ConcurrentLinkedQueue<>();
- private Function<T, U> defaultResponse;
- Duration sleep = Duration.ZERO;
-
- // (Fluent) interface for response producers, accessible from tests.
-
- ResponseQueue<T, U> thenAnswer(Function<T, U> mapFun) {
- responses.add(mapFun);
- return this;
- }
-
- ResponseQueue<T, U> thenReturn(U response) {
- return thenAnswer((request) -> response);
- }
-
- ResponseQueue<T, U> answerByDefault(Function<T, U> mapFun) {
- defaultResponse = mapFun;
- return this;
- }
-
- ResponseQueue<T, U> returnByDefault(U response) {
- return answerByDefault((request) -> response);
- }
-
- ResponseQueue<T, U> delayEachResponseBy(Duration sleep) {
- this.sleep = sleep;
- return this;
- }
-
- // Interface for response consumers, accessible from the enclosing class.
-
- private U getOrDefault(T request) {
- Function<T, U> mapFun = responses.poll();
- U response = mapFun == null ? defaultResponse.apply(request) :
mapFun.apply(request);
- Uninterruptibles.sleepUninterruptibly(sleep.getMillis(),
TimeUnit.MILLISECONDS);
- return response;
- }
-
- private U get(T request) {
- Function<T, U> mapFun = responses.poll();
- U response = mapFun == null ? null : mapFun.apply(request);
- Uninterruptibles.sleepUninterruptibly(sleep.getMillis(),
TimeUnit.MILLISECONDS);
- return response;
- }
-
- private boolean isEmpty() {
- return responses.isEmpty();
- }
- }
-
private final ResponseQueue<Windmill.GetWorkRequest,
Windmill.GetWorkResponse> workToOffer;
private final ResponseQueue<GetDataRequest, GetDataResponse> dataToOffer;
private final ResponseQueue<Windmill.CommitWorkRequest, CommitWorkResponse>
commitsToOffer;
@@ -123,13 +74,13 @@ class FakeWindmillServer extends WindmillServerStub {
private final Map<Long, WorkItemCommitRequest> commitsReceived;
private final ArrayList<Windmill.ReportStatsRequest> statsReceived;
private final LinkedBlockingQueue<Windmill.Exception> exceptions;
- private int commitsRequested = 0;
- private int numGetDataRequests = 0;
private final AtomicInteger expectedExceptionCount;
private final ErrorCollector errorCollector;
+ private final ConcurrentHashMap<Long, Consumer<Windmill.CommitStatus>>
droppedStreamingCommits;
+ private int commitsRequested = 0;
+ private int numGetDataRequests = 0;
private boolean isReady = true;
private boolean dropStreamingCommits = false;
- private final ConcurrentHashMap<Long, Consumer<Windmill.CommitStatus>>
droppedStreamingCommits;
public FakeWindmillServer(ErrorCollector errorCollector) {
workToOffer =
@@ -243,11 +194,12 @@ class FakeWindmillServer extends WindmillServerStub {
@Override
public long getAndResetThrottleTime() {
- return (long) 0;
+ return 0;
}
@Override
- public GetWorkStream getWorkStream(Windmill.GetWorkRequest request,
WorkItemReceiver receiver) {
+ public GetWorkStream getWorkStream(
+ Windmill.GetWorkRequest request, GetWorkStream.WorkItemReceiver
receiver) {
LOG.debug("getWorkStream: {}", request.toString());
Instant startTime = Instant.now();
final CountDownLatch done = new CountDownLatch(1);
@@ -485,4 +437,55 @@ class FakeWindmillServer extends WindmillServerStub {
public void setIsReady(boolean ready) {
this.isReady = ready;
}
+
+ static class ResponseQueue<T, U> {
+ private final Queue<Function<T, U>> responses = new
ConcurrentLinkedQueue<>();
+ Duration sleep = Duration.ZERO;
+ private Function<T, U> defaultResponse;
+
+ // (Fluent) interface for response producers, accessible from tests.
+
+ ResponseQueue<T, U> thenAnswer(Function<T, U> mapFun) {
+ responses.add(mapFun);
+ return this;
+ }
+
+ ResponseQueue<T, U> thenReturn(U response) {
+ return thenAnswer((request) -> response);
+ }
+
+ ResponseQueue<T, U> answerByDefault(Function<T, U> mapFun) {
+ defaultResponse = mapFun;
+ return this;
+ }
+
+ ResponseQueue<T, U> returnByDefault(U response) {
+ return answerByDefault((request) -> response);
+ }
+
+ ResponseQueue<T, U> delayEachResponseBy(Duration sleep) {
+ this.sleep = sleep;
+ return this;
+ }
+
+ // Interface for response consumers, accessible from the enclosing class.
+
+ private U getOrDefault(T request) {
+ Function<T, U> mapFun = responses.poll();
+ U response = mapFun == null ? defaultResponse.apply(request) :
mapFun.apply(request);
+ Uninterruptibles.sleepUninterruptibly(sleep.getMillis(),
TimeUnit.MILLISECONDS);
+ return response;
+ }
+
+ private U get(T request) {
+ Function<T, U> mapFun = responses.poll();
+ U response = mapFun == null ? null : mapFun.apply(request);
+ Uninterruptibles.sleepUninterruptibly(sleep.getMillis(),
TimeUnit.MILLISECONDS);
+ return response;
+ }
+
+ private boolean isEmpty() {
+ return responses.isEmpty();
+ }
+ }
}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillStreamPoolTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillStreamPoolTest.java
new file mode 100644
index 00000000000..9924bb7d2b2
--- /dev/null
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillStreamPoolTest.java
@@ -0,0 +1,251 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotSame;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+
+import java.util.List;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.TimeUnit;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class WindmillStreamPoolTest {
+ private static final int DEFAULT_NUM_STREAMS = 10;
+ private static final int NEW_STREAM_HOLDS = 2;
+ private final ConcurrentHashMap<
+ TestWindmillStream,
WindmillStreamPool.StreamData<TestWindmillStream>>
+ holds = new ConcurrentHashMap<>();
+ private List<WindmillStreamPool.@Nullable StreamData<TestWindmillStream>>
streams;
+
+ @Before
+ public void setUp() {
+ streams = WindmillStreamPool.newStreamList(DEFAULT_NUM_STREAMS);
+ holds.clear();
+ }
+
+ @Test
+ public void testGetStream_returnsAndCachesNewStream() {
+ Duration streamTimeout = Duration.standardSeconds(1);
+ WindmillStreamPool<TestWindmillStream> streamPool =
+ WindmillStreamPool.forTesting(
+ streamTimeout, () -> new TestWindmillStream(Instant.now()),
streams, holds);
+ TestWindmillStream stream = streamPool.getStream();
+ assertTrue(holds.containsKey(stream));
+ assertEquals(2, holds.get(stream).holds);
+ assertTrue(streams.contains(holds.get(stream)));
+ }
+
+ @Test
+ public void testGetStream_returnsCachedStreamAndIncrementsHolds() {
+ Duration streamTimeout = Duration.standardDays(1);
+ int cachedStreamHolds = 2;
+ // Populate the stream data.
+ for (int i = 0; i < DEFAULT_NUM_STREAMS; i++) {
+ WindmillStreamPool.StreamData<TestWindmillStream> streamData =
+ new WindmillStreamPool.StreamData<>(new
TestWindmillStream(Instant.now()));
+ streamData.holds = cachedStreamHolds;
+ streams.set(i, streamData);
+ holds.put(streamData.stream, streamData);
+ }
+
+ WindmillStreamPool<TestWindmillStream> streamPool =
+ WindmillStreamPool.forTesting(
+ streamTimeout, () -> new TestWindmillStream(Instant.now()),
streams, holds);
+ TestWindmillStream stream = streamPool.getStream();
+ assertEquals(cachedStreamHolds + 1, holds.get(stream).holds);
+ }
+
+ @Test
+ public void
testGetStream_returnsAndCachesNewStream_whenOldStreamTimedOutAndDrained() {
+ Duration streamTimeout = Duration.ZERO;
+ Instant expired = Instant.EPOCH;
+ // Populate the stream data.
+ for (int i = 0; i < DEFAULT_NUM_STREAMS; i++) {
+ WindmillStreamPool.StreamData<TestWindmillStream> streamData =
+ new WindmillStreamPool.StreamData<>(new TestWindmillStream(expired));
+ streams.set(i, streamData);
+ holds.put(streamData.stream, streamData);
+ }
+
+ WindmillStreamPool<TestWindmillStream> streamPool =
+ WindmillStreamPool.forTesting(
+ streamTimeout, () -> new TestWindmillStream(Instant.now()),
streams, holds);
+ TestWindmillStream stream = streamPool.getStream();
+ assertEquals(NEW_STREAM_HOLDS, holds.get(stream).holds);
+ }
+
+ @Test
+ public void testGetStream_closesInvalidStream() {
+ Duration streamTimeout = Duration.ZERO;
+ Instant expired = Instant.EPOCH;
+ WindmillStreamPool.StreamData<TestWindmillStream> streamData =
+ new WindmillStreamPool.StreamData<>(new TestWindmillStream(expired));
+ List<WindmillStreamPool.@Nullable StreamData<TestWindmillStream>> streams =
+ WindmillStreamPool.newStreamList(1);
+ streams.set(0, streamData);
+ holds.put(streamData.stream, streamData);
+
+ WindmillStreamPool<TestWindmillStream> streamPool =
+ WindmillStreamPool.forTesting(
+ streamTimeout, () -> new TestWindmillStream(Instant.now()),
streams, holds);
+
+ TestWindmillStream stream = streamPool.getStream();
+ assertEquals(NEW_STREAM_HOLDS, holds.get(stream).holds);
+ assertTrue(streamData.stream.closed);
+ assertEquals(1, streams.size());
+ assertEquals(1, holds.size());
+ }
+
+ @Test
+ public void
testGetStream_returnsNewAndCachesNewStream_whenOldStreamTimedOutAndNotDrained()
{
+ int notDrained = 4;
+ Duration streamTimeout = Duration.ZERO;
+ Instant expired = Instant.EPOCH;
+
+ // Populate the stream data.
+ List<WindmillStreamPool.@Nullable StreamData<TestWindmillStream>> streams =
+ WindmillStreamPool.newStreamList(1);
+ WindmillStreamPool.StreamData<TestWindmillStream> expiredStreamData =
+ new WindmillStreamPool.StreamData<>(new TestWindmillStream(expired));
+ expiredStreamData.holds = notDrained;
+ streams.set(0, expiredStreamData);
+ holds.put(expiredStreamData.stream, expiredStreamData);
+
+ WindmillStreamPool<TestWindmillStream> streamPool =
+ WindmillStreamPool.forTesting(
+ streamTimeout, () -> new TestWindmillStream(Instant.now()),
streams, holds);
+ TestWindmillStream newStream = streamPool.getStream();
+
+ assertEquals(NEW_STREAM_HOLDS, holds.get(newStream).holds);
+ assertEquals(2, holds.size());
+ assertEquals(1, streams.size());
+ }
+
+ @Test
+ public void testGetStream_doesNotCloseExpiredStream_whenNotDrained() {
+ int notDrained = 4;
+ Duration streamTimeout = Duration.ZERO;
+ Instant expired = Instant.EPOCH;
+
+ // Populate the stream data.
+ List<WindmillStreamPool.@Nullable StreamData<TestWindmillStream>> streams =
+ WindmillStreamPool.newStreamList(1);
+ WindmillStreamPool.StreamData<TestWindmillStream> expiredStreamData =
+ new WindmillStreamPool.StreamData<>(new TestWindmillStream(expired));
+ expiredStreamData.holds = notDrained;
+ streams.set(0, expiredStreamData);
+ holds.put(expiredStreamData.stream, expiredStreamData);
+
+ WindmillStreamPool<TestWindmillStream> streamPool =
+ WindmillStreamPool.forTesting(
+ streamTimeout, () -> new TestWindmillStream(Instant.now()),
streams, holds);
+ TestWindmillStream newStream = streamPool.getStream();
+
+ assertNotSame(expiredStreamData.stream, newStream);
+ assertFalse(expiredStreamData.stream.closed);
+ assertEquals(notDrained - 1, expiredStreamData.holds);
+ assertEquals(2, holds.size());
+ assertEquals(1, streams.size());
+ assertTrue(holds.containsKey(expiredStreamData.stream));
+ assertFalse(streams.contains(expiredStreamData));
+ }
+
+ @Test
+ public void testReleaseStream_closesStream() {
+ Duration streamTimeout = Duration.standardDays(1);
+ WindmillStreamPool.StreamData<TestWindmillStream> streamData =
+ new WindmillStreamPool.StreamData<>(new
TestWindmillStream(Instant.now()));
+ List<WindmillStreamPool.@Nullable StreamData<TestWindmillStream>> streams =
+ WindmillStreamPool.newStreamList(1);
+ streams.set(0, streamData);
+ holds.put(streamData.stream, streamData);
+
+ WindmillStreamPool<TestWindmillStream> streamPool =
+ WindmillStreamPool.forTesting(
+ streamTimeout, () -> new TestWindmillStream(Instant.now()),
streams, holds);
+ TestWindmillStream stream = streamPool.getStream();
+ holds.get(stream).holds = 1;
+ streamPool.releaseStream(stream);
+ assertFalse(holds.containsKey(stream));
+ assertTrue(stream.closed);
+ }
+
+ @Test
+ public void testReleaseStream_doesNotCloseStream_ifStreamHasHolds() {
+ Duration streamTimeout = Duration.standardDays(1);
+ WindmillStreamPool.StreamData<TestWindmillStream> streamData =
+ new WindmillStreamPool.StreamData<>(new
TestWindmillStream(Instant.now()));
+ List<WindmillStreamPool.@Nullable StreamData<TestWindmillStream>> streams =
+ WindmillStreamPool.newStreamList(1);
+ streams.set(0, streamData);
+ holds.put(streamData.stream, streamData);
+
+ WindmillStreamPool<TestWindmillStream> streamPool =
+ WindmillStreamPool.forTesting(
+ streamTimeout, () -> new TestWindmillStream(Instant.now()),
streams, holds);
+ TestWindmillStream stream = streamPool.getStream();
+ streamPool.releaseStream(stream);
+ assertTrue(holds.containsKey(stream));
+ assertFalse(stream.closed);
+ }
+
+ @Test
+ public void
testReleaseStream_throwsExceptionWhenAttemptingToReleaseUnheldStream() {
+ WindmillStreamPool<TestWindmillStream> streamPool =
+ WindmillStreamPool.forTesting(
+ Duration.ZERO, () -> new TestWindmillStream(Instant.now()),
streams, holds);
+ TestWindmillStream unheldStream = new TestWindmillStream(Instant.now());
+ assertThrows(IllegalStateException.class, () ->
streamPool.releaseStream(unheldStream));
+ }
+
+ private static class TestWindmillStream implements WindmillStream {
+ private final Instant startTime;
+ private boolean closed;
+
+ private TestWindmillStream(Instant startTime) {
+ this.startTime = startTime;
+ this.closed = false;
+ }
+
+ @Override
+ public void close() {
+ closed = true;
+ }
+
+ @Override
+ public boolean awaitTermination(int time, TimeUnit unit) {
+ return false;
+ }
+
+ @Override
+ public Instant startTime() {
+ return startTime;
+ }
+ }
+}
diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcWindmillServerTest.java
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcWindmillServerTest.java
index 77f046e7384..511b4e4895a 100644
---
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcWindmillServerTest.java
+++
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcWindmillServerTest.java
@@ -65,9 +65,9 @@ import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.TagValue;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.Value;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem;
import
org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest;
-import
org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.CommitWorkStream;
-import
org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.GetDataStream;
-import
org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.GetWorkStream;
+import
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.CommitWorkStream;
+import
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetDataStream;
+import
org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetWorkStream;
import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString;
import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.Server;
import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.Status;