This is an automated email from the ASF dual-hosted git repository.

scwhittle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 89795c0e7c7 Make GrpcCommitWorkStream thread-safe as documented by 
moving batcher out of it. (#31304)
89795c0e7c7 is described below

commit 89795c0e7c731f37f737c908ede9cb9d3b578a24
Author: Sam Whittle <[email protected]>
AuthorDate: Tue May 21 10:53:04 2024 +0200

    Make GrpcCommitWorkStream thread-safe as documented by moving batcher out 
of it. (#31304)
    
    Also increase the number of streams in the commit cache to the number of commit threads.
---
 .../dataflow/worker/StreamingDataflowWorker.java   |   3 +-
 .../worker/windmill/client/WindmillStream.java     |  39 +++++--
 .../commits/StreamingApplianceWorkCommitter.java   |   3 +-
 .../commits/StreamingEngineWorkCommitter.java      |  43 ++++----
 .../windmill/client/grpc/GrpcCommitWorkStream.java |  72 ++++++------
 .../dataflow/worker/FakeWindmillServer.java        | 104 +++++++++++-------
 .../commits/StreamingEngineWorkCommitterTest.java  |  57 ++++++++--
 .../client/grpc/GrpcWindmillServerTest.java        | 121 +++++++++++++++------
 8 files changed, 290 insertions(+), 152 deletions(-)

diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
index f18d5fac721..78b7d3c9dfe 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
@@ -155,7 +155,6 @@ public class StreamingDataflowWorker {
   // Maximum number of threads for processing.  Currently each thread 
processes one key at a time.
   static final int MAX_PROCESSING_THREADS = 300;
   static final long THREAD_EXPIRATION_TIME_SEC = 60;
-  static final int NUM_COMMIT_STREAMS = 1;
   static final int GET_WORK_STREAM_TIMEOUT_MINUTES = 3;
   static final Duration COMMIT_STREAM_TIMEOUT = Duration.standardMinutes(1);
 
@@ -280,7 +279,7 @@ public class StreamingDataflowWorker {
         windmillServiceEnabled
             ? StreamingEngineWorkCommitter.create(
                 WindmillStreamPool.create(
-                        NUM_COMMIT_STREAMS, COMMIT_STREAM_TIMEOUT, 
windmillServer::commitWorkStream)
+                        numCommitThreads, COMMIT_STREAM_TIMEOUT, 
windmillServer::commitWorkStream)
                     ::getCloseableStream,
                 numCommitThreads,
                 this::onCompleteCommit)
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java
index 7c22f4fb576..d044e930079 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java
@@ -17,10 +17,12 @@
  */
 package org.apache.beam.runners.dataflow.worker.windmill.client;
 
+import java.io.Closeable;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.TimeUnit;
 import java.util.function.Consumer;
+import javax.annotation.concurrent.NotThreadSafe;
 import javax.annotation.concurrent.ThreadSafe;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
 import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest;
@@ -68,21 +70,34 @@ public interface WindmillStream {
   /** Interface for streaming CommitWorkRequests to Windmill. */
   @ThreadSafe
   interface CommitWorkStream extends WindmillStream {
+    @NotThreadSafe
+    interface RequestBatcher extends Closeable {
+      /**
+       * Commits a work item, running onDone when the commit has been 
processed by the server.
+       * Returns true if the request was accepted. If false is returned the 
batcher should be flushed
+       * and the request recommitted.
+       *
+       * <p>onDone will be called with the status of the commit.
+       */
+      boolean commitWorkItem(
+          String computation,
+          Windmill.WorkItemCommitRequest request,
+          Consumer<Windmill.CommitStatus> onDone);
+
+      /** Flushes any pending work items to the wire. */
+      void flush();
+
+      @Override
+      default void close() {
+        flush();
+      }
+    }
 
     /**
-     * Commits a work item and running onDone when the commit has been 
processed by the server.
-     * Returns true if the request was accepted. If false is returned the 
stream should be flushed
-     * and the request recommitted.
-     *
-     * <p>onDone will be called with the status of the commit.
+     * Returns a batcher that can be used for sending requests. Each batcher 
is not thread-safe but
+     * different batchers for the same stream may be used simultaneously.
      */
-    boolean commitWorkItem(
-        String computation,
-        Windmill.WorkItemCommitRequest request,
-        Consumer<Windmill.CommitStatus> onDone);
-
-    /** Flushes any pending work items to the wire. */
-    void flush();
+    RequestBatcher batcher();
   }
 
   /** Interface for streaming GetWorkerMetadata requests to Windmill. */
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java
index 344f04cfd00..d092ebf53fc 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java
@@ -112,8 +112,7 @@ public final class StreamingApplianceWorkCommitter 
implements WorkCommitter {
       try {
         commit = commitQueue.take();
       } catch (InterruptedException e) {
-        Thread.currentThread().interrupt();
-        continue;
+        return;
       }
       while (commit != null) {
         ComputationState computationState = commit.computationState();
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java
index f6088acf011..ed4dcfa212f 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java
@@ -104,14 +104,15 @@ public final class StreamingEngineWorkCommitter 
implements WorkCommitter {
 
   @Override
   public void stop() {
-    if (!commitSenders.isTerminated() || !commitSenders.isShutdown()) {
-      commitSenders.shutdown();
+    if (!commitSenders.isTerminated()) {
+      commitSenders.shutdownNow();
       try {
         commitSenders.awaitTermination(10, TimeUnit.SECONDS);
       } catch (InterruptedException e) {
-        LOG.warn("Could not shut down commitSenders gracefully, forcing 
shutdown.", e);
+        LOG.warn(
+            "Commit senders didn't complete shutdown within 10 seconds, 
continuing to drain queue",
+            e);
       }
-      commitSenders.shutdownNow();
     }
     drainCommitQueue();
   }
@@ -143,9 +144,10 @@ public final class StreamingEngineWorkCommitter implements 
WorkCommitter {
             // Block until we have a commit or are shutting down.
             initialCommit = commitQueue.take();
           } catch (InterruptedException e) {
-            continue;
+            return;
           }
         }
+        Preconditions.checkNotNull(initialCommit);
 
         if (initialCommit.work().isFailed()) {
           onCommitComplete.accept(CompleteCommit.forFailedWork(initialCommit));
@@ -156,15 +158,17 @@ public final class StreamingEngineWorkCommitter 
implements WorkCommitter {
         try (CloseableStream<CommitWorkStream> closeableCommitStream =
             commitWorkStreamFactory.get()) {
           CommitWorkStream commitStream = closeableCommitStream.stream();
-          if (!tryAddToCommitStream(initialCommit, commitStream)) {
-            throw new AssertionError("Initial commit on flushed stream should 
always be accepted.");
+          try (CommitWorkStream.RequestBatcher batcher = 
commitStream.batcher()) {
+            if (!tryAddToCommitBatch(initialCommit, batcher)) {
+              throw new AssertionError(
+                  "Initial commit on flushed stream should always be 
accepted.");
+            }
+            // Batch additional commits to the stream and possibly make an 
un-batched commit the
+            // next initial commit.
+            initialCommit = expandBatch(batcher);
           }
-          // Batch additional commits to the stream and possibly make an 
un-batched commit the next
-          // initial commit.
-          initialCommit = batchCommitsToStream(commitStream);
-          commitStream.flush();
         } catch (Exception e) {
-          LOG.error("Error occurred fetching a CommitWorkStream.", e);
+          LOG.error("Error occurred sending commits.", e);
         }
       }
     } finally {
@@ -174,13 +178,13 @@ public final class StreamingEngineWorkCommitter 
implements WorkCommitter {
     }
   }
 
-  /** Adds the commit to the commitStream if it fits, returning true if it is 
consumed. */
-  private boolean tryAddToCommitStream(Commit commit, CommitWorkStream 
commitStream) {
+  /** Adds the commit to the batch if it fits, returning true if it is 
consumed. */
+  private boolean tryAddToCommitBatch(Commit commit, 
CommitWorkStream.RequestBatcher batcher) {
     Preconditions.checkNotNull(commit);
     commit.work().setState(Work.State.COMMITTING);
     activeCommitBytes.addAndGet(commit.getSize());
     boolean isCommitAccepted =
-        commitStream.commitWorkItem(
+        batcher.commitWorkItem(
             commit.computationId(),
             commit.request(),
             (commitStatus) -> {
@@ -197,9 +201,9 @@ public final class StreamingEngineWorkCommitter implements 
WorkCommitter {
     return isCommitAccepted;
   }
 
-  // Helper to batch additional commits into the commit stream as long as they 
fit.
+  // Helper to batch additional commits into the commit batch as long as they 
fit.
   // Returns a commit that was removed from the queue but not consumed or null.
-  private Commit batchCommitsToStream(CommitWorkStream commitStream) {
+  private Commit expandBatch(CommitWorkStream.RequestBatcher batcher) {
     int commits = 1;
     while (true) {
       Commit commit;
@@ -210,8 +214,7 @@ public final class StreamingEngineWorkCommitter implements 
WorkCommitter {
           commit = commitQueue.poll();
         }
       } catch (InterruptedException e) {
-        // Continue processing until !running.get()
-        continue;
+        return null;
       }
 
       if (commit == null) {
@@ -224,7 +227,7 @@ public final class StreamingEngineWorkCommitter implements 
WorkCommitter {
         continue;
       }
 
-      if (!tryAddToCommitStream(commit, commitStream)) {
+      if (!tryAddToCommitBatch(commit, batcher)) {
         return commit;
       }
       commits++;
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java
index b921160e1a9..f9f579119d6 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java
@@ -51,7 +51,6 @@ public final class GrpcCommitWorkStream
   private static final long HEARTBEAT_REQUEST_ID = Long.MAX_VALUE;
 
   private final Map<Long, PendingRequest> pending;
-  private final Batcher batcher;
   private final AtomicLong idGenerator;
   private final JobHeader jobHeader;
   private final ThrottleTimer commitWorkThrottleTimer;
@@ -75,7 +74,6 @@ public final class GrpcCommitWorkStream
         streamRegistry,
         logEveryNStreamFailures);
     pending = new ConcurrentHashMap<>();
-    batcher = new Batcher();
     this.idGenerator = idGenerator;
     this.jobHeader = jobHeader;
     this.commitWorkThrottleTimer = commitWorkThrottleTimer;
@@ -116,14 +114,23 @@ public final class GrpcCommitWorkStream
   @Override
   protected synchronized void onNewStream() {
     send(StreamingCommitWorkRequest.newBuilder().setHeader(jobHeader).build());
-    Batcher resendBatcher = new Batcher();
-    for (Map.Entry<Long, PendingRequest> entry : pending.entrySet()) {
-      if (!resendBatcher.canAccept(entry.getValue())) {
-        resendBatcher.flush();
+    try (Batcher resendBatcher = new Batcher()) {
+      for (Map.Entry<Long, PendingRequest> entry : pending.entrySet()) {
+        if (!resendBatcher.canAccept(entry.getValue().getBytes())) {
+          resendBatcher.flush();
+        }
+        resendBatcher.add(entry.getKey(), entry.getValue());
       }
-      resendBatcher.add(entry.getKey(), entry.getValue());
     }
-    resendBatcher.flush();
+  }
+
+  /**
+   * Returns a batcher that can be used for sending requests. Each batcher is 
not thread-safe but
+   * different batchers for the same stream may be used simultaneously.
+   */
+  @Override
+  public CommitWorkStream.RequestBatcher batcher() {
+    return new Batcher();
   }
 
   @Override
@@ -175,22 +182,6 @@ public final class GrpcCommitWorkStream
     commitWorkThrottleTimer.start();
   }
 
-  @Override
-  public boolean commitWorkItem(
-      String computation, WorkItemCommitRequest commitRequest, 
Consumer<CommitStatus> onDone) {
-    PendingRequest request = new PendingRequest(computation, commitRequest, 
onDone);
-    if (!batcher.canAccept(request)) {
-      return false;
-    }
-    batcher.add(idGenerator.incrementAndGet(), request);
-    return true;
-  }
-
-  @Override
-  public void flush() {
-    batcher.flush();
-  }
-
   private void flushInternal(Map<Long, PendingRequest> requests) {
     if (requests.isEmpty()) {
       return;
@@ -305,7 +296,7 @@ public final class GrpcCommitWorkStream
     }
   }
 
-  private class Batcher {
+  private class Batcher implements CommitWorkStream.RequestBatcher {
 
     private final Map<Long, PendingRequest> queue;
     private long queuedBytes;
@@ -315,22 +306,35 @@ public final class GrpcCommitWorkStream
       this.queue = new HashMap<>();
     }
 
-    boolean canAccept(PendingRequest request) {
-      return queue.isEmpty()
-          || (queue.size() < streamingRpcBatchLimit
-              && (request.getBytes() + queuedBytes) < 
AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE);
+    @Override
+    public boolean commitWorkItem(
+        String computation, WorkItemCommitRequest commitRequest, 
Consumer<CommitStatus> onDone) {
+      if (!canAccept(commitRequest.getSerializedSize() + 
computation.length())) {
+        return false;
+      }
+      PendingRequest request = new PendingRequest(computation, commitRequest, 
onDone);
+      add(idGenerator.incrementAndGet(), request);
+      return true;
+    }
+
+    /** Flushes any pending work items to the wire. */
+    @Override
+    public void flush() {
+      flushInternal(queue);
+      queuedBytes = 0;
+      queue.clear();
     }
 
     void add(long id, PendingRequest request) {
-      assert (canAccept(request));
+      assert (canAccept(request.getBytes()));
       queuedBytes += request.getBytes();
       queue.put(id, request);
     }
 
-    void flush() {
-      flushInternal(queue);
-      queuedBytes = 0;
-      queue.clear();
+    private boolean canAccept(long requestBytes) {
+      return queue.isEmpty()
+          || (queue.size() < streamingRpcBatchLimit
+              && (requestBytes + queuedBytes) < 
AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE);
     }
   }
 }
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
index 89939d5d341..127d46b7caf 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java
@@ -366,41 +366,69 @@ public final class FakeWindmillServer extends 
WindmillServerStub {
   public CommitWorkStream commitWorkStream() {
     Instant startTime = Instant.now();
     return new CommitWorkStream() {
-      @Override
-      public boolean commitWorkItem(
-          String computation,
-          WorkItemCommitRequest request,
-          Consumer<Windmill.CommitStatus> onDone) {
-        LOG.debug("commitWorkStream::commitWorkItem: {}", request);
-        errorCollector.checkThat(request.hasWorkToken(), equalTo(true));
-        errorCollector.checkThat(
-            request.getShardingKey(), allOf(greaterThan(0L), 
lessThan(Long.MAX_VALUE)));
-        errorCollector.checkThat(request.getCacheToken(), not(equalTo(0L)));
-        // Throws away the result, but allows to inject latency.
-        Windmill.CommitWorkRequest.Builder builder = 
Windmill.CommitWorkRequest.newBuilder();
-        
builder.addRequestsBuilder().setComputationId(computation).addRequests(request);
-        commitsToOffer.getOrDefault(builder.build());
-        if (dropStreamingCommits) {
-          droppedStreamingCommits.put(request.getWorkToken(), onDone);
-        } else {
-          commitsReceived.put(request.getWorkToken(), request);
-          onDone.accept(
-              Optional.ofNullable(
-                      streamingCommitsToOffer.remove(
-                          WorkId.builder()
-                              .setWorkToken(request.getWorkToken())
-                              .setCacheToken(request.getCacheToken())
-                              .build()))
-                  // Default to CommitStatus.OK
-                  .orElse(Windmill.CommitStatus.OK));
-        }
-        // Return true to indicate the request was accepted even if we are 
dropping the commit
-        // to simulate a dropped commit.
-        return true;
-      }
 
       @Override
-      public void flush() {}
+      public RequestBatcher batcher() {
+        return new RequestBatcher() {
+          class RequestAndDone {
+            final Consumer<Windmill.CommitStatus> onDone;
+            final WorkItemCommitRequest request;
+
+            RequestAndDone(WorkItemCommitRequest request, 
Consumer<Windmill.CommitStatus> onDone) {
+              this.request = request;
+              this.onDone = onDone;
+            }
+          }
+
+          final List<RequestAndDone> requests = new ArrayList<>();
+
+          @Override
+          public boolean commitWorkItem(
+              String computation,
+              WorkItemCommitRequest request,
+              Consumer<Windmill.CommitStatus> onDone) {
+            LOG.debug("commitWorkStream::commitWorkItem: {}", request);
+            errorCollector.checkThat(request.hasWorkToken(), equalTo(true));
+            errorCollector.checkThat(
+                request.getShardingKey(), allOf(greaterThan(0L), 
lessThan(Long.MAX_VALUE)));
+            errorCollector.checkThat(request.getCacheToken(), 
not(equalTo(0L)));
+            if (requests.size() > 5) return false;
+
+            // Throws away the result, but allows to inject latency.
+            Windmill.CommitWorkRequest.Builder builder = 
Windmill.CommitWorkRequest.newBuilder();
+            
builder.addRequestsBuilder().setComputationId(computation).addRequests(request);
+            commitsToOffer.getOrDefault(builder.build());
+
+            requests.add(new RequestAndDone(request, onDone));
+            flush();
+            return true;
+          }
+
+          @Override
+          public void flush() {
+            for (RequestAndDone elem : requests) {
+              if (dropStreamingCommits) {
+                droppedStreamingCommits.put(elem.request.getWorkToken(), 
elem.onDone);
+                // The request is still reported as accepted by commitWorkItem 
even though we are
+                // dropping the commit here to simulate a dropped commit.
+                continue;
+              }
+
+              commitsReceived.put(elem.request.getWorkToken(), elem.request);
+              elem.onDone.accept(
+                  Optional.ofNullable(
+                          streamingCommitsToOffer.remove(
+                              WorkId.builder()
+                                  .setWorkToken(elem.request.getWorkToken())
+                                  .setCacheToken(elem.request.getCacheToken())
+                                  .build()))
+                      // Default to CommitStatus.OK
+                      .orElse(Windmill.CommitStatus.OK));
+            }
+            requests.clear();
+          }
+        };
+      }
 
       @Override
       public void close() {}
@@ -419,7 +447,7 @@ public final class FakeWindmillServer extends 
WindmillServerStub {
 
   public void waitForEmptyWorkQueue() {
     while (!workToOffer.isEmpty()) {
-      Uninterruptibles.sleepUninterruptibly(1000, TimeUnit.MILLISECONDS);
+      Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS);
     }
   }
 
@@ -429,7 +457,7 @@ public final class FakeWindmillServer extends 
WindmillServerStub {
     Instant waitStart = Instant.now();
     while (commitsReceived.size() < commitsRequested + numCommits
         && Instant.now().isBefore(waitStart.plus(timeout))) {
-      Uninterruptibles.sleepUninterruptibly(1000, TimeUnit.MILLISECONDS);
+      Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS);
     }
     commitsRequested += numCommits;
     return commitsReceived;
@@ -437,9 +465,9 @@ public final class FakeWindmillServer extends 
WindmillServerStub {
 
   public Map<Long, WorkItemCommitRequest> waitForAndGetCommits(int numCommits) 
{
     LOG.debug("waitForAndGetCommitsRequest: {}", numCommits);
-    int maxTries = 10;
+    int maxTries = 100;
     while (maxTries-- > 0 && commitsReceived.size() < commitsRequested + 
numCommits) {
-      Uninterruptibles.sleepUninterruptibly(1000, TimeUnit.MILLISECONDS);
+      Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS);
     }
 
     assertFalse(
@@ -448,7 +476,7 @@ public final class FakeWindmillServer extends 
WindmillServerStub {
             + " more commits beyond "
             + commitsRequested
             + " commits already seen, but after 10s have only seen "
-            + commitsReceived
+            + commitsReceived.size()
             + ". Exceptions seen: "
             + exceptions,
         commitsReceived.size() < commitsRequested + numCommits);
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java
 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java
index 1bf2e44f9f0..49c61b9b8ab 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java
@@ -33,6 +33,7 @@ import java.util.Map;
 import java.util.Optional;
 import java.util.Random;
 import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.TimeUnit;
 import java.util.function.Consumer;
 import java.util.function.Function;
@@ -249,16 +250,21 @@ public class StreamingEngineWorkCommitterTest {
         () ->
             new CommitWorkStream() {
               @Override
-              public boolean commitWorkItem(
-                  String computation,
-                  WorkItemCommitRequest request,
-                  Consumer<Windmill.CommitStatus> onDone) {
-                return false;
+              public RequestBatcher batcher() {
+                return new RequestBatcher() {
+                  @Override
+                  public boolean commitWorkItem(
+                      String computation,
+                      WorkItemCommitRequest request,
+                      Consumer<Windmill.CommitStatus> onDone) {
+                    return false;
+                  }
+
+                  @Override
+                  public void flush() {}
+                };
               }
 
-              @Override
-              public void flush() {}
-
               @Override
               public void close() {}
 
@@ -305,4 +311,39 @@ public class StreamingEngineWorkCommitterTest {
       assertTrue(commit.work().isFailed());
     }
   }
+
+  @Test
+  public void testMultipleCommitSendersSingleStream() {
+    commitWorkStreamFactory =
+        WindmillStreamPool.create(
+                1, Duration.standardMinutes(1), 
fakeWindmillServer::commitWorkStream)
+            ::getCloseableStream;
+    Set<CompleteCommit> completeCommits = Collections.newSetFromMap(new 
ConcurrentHashMap<>());
+    workCommitter =
+        StreamingEngineWorkCommitter.create(commitWorkStreamFactory, 5, 
completeCommits::add);
+    List<Commit> commits = new ArrayList<>();
+    for (int i = 1; i <= 500; i++) {
+      Work work = createMockWork(i, ignored -> {});
+      WorkItemCommitRequest commitRequest =
+          WorkItemCommitRequest.newBuilder()
+              .setKey(work.getWorkItem().getKey())
+              .setShardingKey(work.getWorkItem().getShardingKey())
+              .setWorkToken(work.getWorkItem().getWorkToken())
+              .setCacheToken(work.getWorkItem().getCacheToken())
+              .build();
+      commits.add(Commit.create(commitRequest, 
createComputationState("computationId-" + i), work));
+    }
+
+    workCommitter.start();
+    commits.parallelStream().forEach(workCommitter::commit);
+    Map<Long, WorkItemCommitRequest> committed =
+        fakeWindmillServer.waitForAndGetCommits(commits.size());
+
+    for (Commit commit : commits) {
+      WorkItemCommitRequest request = 
committed.get(commit.work().getWorkItem().getWorkToken());
+      assertNotNull(request);
+      assertThat(request).isEqualTo(commit.request());
+      assertThat(completeCommits).contains(asCompleteCommit(commit, 
Windmill.CommitStatus.OK));
+    }
+  }
 }
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java
 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java
index fe0822a6067..b1d5309e12d 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java
@@ -32,6 +32,8 @@ import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.ThreadLocalRandom;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
@@ -630,20 +632,47 @@ public class GrpcWindmillServerTest {
     };
   }
 
-  @Test
-  public void testStreamingCommit() throws Exception {
+  private void commitWorkTestHelper(
+      CommitWorkStream stream,
+      ConcurrentHashMap<Long, WorkItemCommitRequest> requestsForService,
+      int requestIdStart,
+      int numRequests)
+      throws InterruptedException {
     List<WorkItemCommitRequest> commitRequestList = new ArrayList<>();
     List<CountDownLatch> latches = new ArrayList<>();
-    Map<Long, WorkItemCommitRequest> commitRequests = new 
ConcurrentHashMap<>();
-    for (int i = 0; i < 500; ++i) {
+    for (int i = 0; i < numRequests; ++i) {
       // Build some requests of varying size with a few big ones.
-      WorkItemCommitRequest request = makeCommitRequest(i, i * (i < 480 ? 8 : 
128));
+      WorkItemCommitRequest request =
+          makeCommitRequest(i + requestIdStart, i * (i < numRequests * .9 ? 8 
: 128));
       commitRequestList.add(request);
-      commitRequests.put((long) i, request);
+      requestsForService.put((long) i + requestIdStart, request);
       latches.add(new CountDownLatch(1));
     }
     Collections.shuffle(commitRequestList);
+    try (CommitWorkStream.RequestBatcher batcher = stream.batcher()) {
+      for (int i = 0; i < commitRequestList.size(); ) {
+        final CountDownLatch latch = latches.get(i);
+        if (batcher.commitWorkItem(
+            "computation",
+            commitRequestList.get(i),
+            (CommitStatus status) -> {
+              assertEquals(status, CommitStatus.OK);
+              latch.countDown();
+            })) {
+          i++;
+        } else {
+          batcher.flush();
+        }
+      }
+    }
+    for (CountDownLatch latch : latches) {
+      assertTrue(latch.await(1, TimeUnit.MINUTES));
+    }
+  }
 
+  @Test
+  public void testStreamingCommit() throws Exception {
+    ConcurrentHashMap<Long, WorkItemCommitRequest> commitRequests = new 
ConcurrentHashMap<>();
     serviceRegistry.addService(
         new CloudWindmillServiceV1Alpha1ImplBase() {
           @Override
@@ -655,26 +684,45 @@ public class GrpcWindmillServerTest {
 
     // Make the commit requests, waiting for each of them to be verified and 
acknowledged.
     CommitWorkStream stream = client.commitWorkStream();
-    for (int i = 0; i < commitRequestList.size(); ) {
-      final CountDownLatch latch = latches.get(i);
-      if (stream.commitWorkItem(
-          "computation",
-          commitRequestList.get(i),
-          (CommitStatus status) -> {
-            assertEquals(status, CommitStatus.OK);
-            latch.countDown();
-          })) {
-        i++;
-      } else {
-        stream.flush();
-      }
-    }
-    stream.flush();
+    commitWorkTestHelper(stream, commitRequests, 0, 500);
     stream.close();
-    for (CountDownLatch latch : latches) {
-      assertTrue(latch.await(1, TimeUnit.MINUTES));
+    assertTrue(stream.awaitTermination(30, TimeUnit.SECONDS));
+  }
+
+  @Test
+  public void testStreamingCommitManyThreads() throws Exception {
+    ConcurrentHashMap<Long, WorkItemCommitRequest> commitRequests = new 
ConcurrentHashMap<>();
+    serviceRegistry.addService(
+        new CloudWindmillServiceV1Alpha1ImplBase() {
+          @Override
+          public StreamObserver<StreamingCommitWorkRequest> commitWorkStream(
+              StreamObserver<StreamingCommitResponse> responseObserver) {
+            return getTestCommitStreamObserver(responseObserver, 
commitRequests);
+          }
+        });
+    ScheduledExecutorService executor = Executors.newScheduledThreadPool(10);
+    // Make the commit requests, waiting for each of them to be verified and 
acknowledged.
+    CommitWorkStream stream = client.commitWorkStream();
+    List<Future<?>> futures = new ArrayList<>();
+    for (int i = 0; i < 10; ++i) {
+      final int startRequestId = i * 50;
+      futures.add(
+          executor.submit(
+              () -> {
+                try {
+                  commitWorkTestHelper(stream, commitRequests, startRequestId, 
50);
+                } catch (InterruptedException e) {
+                  throw new RuntimeException(e);
+                }
+              }));
+    }
+    // Surface any exceptions thrown by the submitted tasks by blocking 
on each future.
+    for (Future<?> f : futures) {
+      f.get();
     }
+    stream.close();
     assertTrue(stream.awaitTermination(30, TimeUnit.SECONDS));
+    executor.shutdown();
   }
 
   @Test
@@ -739,21 +787,22 @@ public class GrpcWindmillServerTest {
 
     // Make the commit requests, waiting for each of them to be verified and 
acknowledged.
     CommitWorkStream stream = client.commitWorkStream();
-    for (int i = 0; i < commitRequestList.size(); ) {
-      final CountDownLatch latch = latches.get(i);
-      if (stream.commitWorkItem(
-          "computation",
-          commitRequestList.get(i),
-          (CommitStatus status) -> {
-            assertEquals(status, CommitStatus.OK);
-            latch.countDown();
-          })) {
-        i++;
-      } else {
-        stream.flush();
+    try (CommitWorkStream.RequestBatcher batcher = stream.batcher()) {
+      for (int i = 0; i < commitRequestList.size(); ) {
+        final CountDownLatch latch = latches.get(i);
+        if (batcher.commitWorkItem(
+            "computation",
+            commitRequestList.get(i),
+            (CommitStatus status) -> {
+              assertEquals(status, CommitStatus.OK);
+              latch.countDown();
+            })) {
+          i++;
+        } else {
+          batcher.flush();
+        }
       }
     }
-    stream.flush();
 
     long deadline = System.currentTimeMillis() + 60_000; // 1 min
     while (true) {

Reply via email to