scwhittle commented on code in PR #32774:
URL: https://github.com/apache/beam/pull/32774#discussion_r1833955728


##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java:
##########
@@ -412,15 +420,17 @@ private void recordStreamStatus(Status status) {
     }
 
     /** Returns true if the stream was torn down and should not be restarted internally. */
-    private synchronized boolean maybeTeardownStream() {
-      if (hasReceivedShutdownSignal() || (clientClosed && !hasPendingRequests())) {
-        streamRegistry.remove(AbstractWindmillStream.this);
-        finishLatch.countDown();
-        executor.shutdownNow();
-        return true;
-      }
+    private boolean maybeTeardownStream() {
+      synchronized (AbstractWindmillStream.this) {

Review Comment:
   just make a synchronized method?



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetrics.java:
##########
@@ -59,6 +59,9 @@ final class StreamDebugMetrics {
   @GuardedBy("this")
   private DateTime shutdownTime = null;

Review Comment:
   mark nullable, here and above
   
   I think you still need to fix the over-exemption of windmill classes from the checker.
   https://github.com/apache/beam/issues/30183



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java:
##########
@@ -272,7 +280,7 @@ public final void appendSummaryHtml(PrintWriter writer) {
         summaryMetrics.timeSinceLastSend(),
         summaryMetrics.timeSinceLastResponse(),
         requestObserver.isClosed(),
-        hasReceivedShutdownSignal(),
+        summaryMetrics.shutdownTime().isPresent(),

Review Comment:
   nit: could remove this part since it is duplicated with null.  Could render null differently, maybe something like
   
   summaryMetrics.shutdownTime().map(Instant::toString).orElse("not shutdown")
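   
   A minimal sketch of that rendering, assuming shutdownTime() returns an Optional. Here java.time.Instant stands in for whatever time type the real method returns, and ShutdownTimeRendering and render are hypothetical names used only for illustration.

```java
import java.time.Instant;
import java.util.Optional;

public class ShutdownTimeRendering {
  // Hypothetical helper: one column that shows either the shutdown time or a
  // placeholder, so a separate "is shutdown" boolean column is not needed.
  static String render(Optional<Instant> shutdownTime) {
    return shutdownTime.map(Instant::toString).orElse("not shutdown");
  }

  public static void main(String[] args) {
    System.out.println(render(Optional.empty())); // not shutdown
    System.out.println(render(Optional.of(Instant.EPOCH))); // 1970-01-01T00:00:00Z
  }
}
```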
   
   



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java:
##########
@@ -412,15 +420,17 @@ private void recordStreamStatus(Status status) {
     }
 
     /** Returns true if the stream was torn down and should not be restarted internally. */
-    private synchronized boolean maybeTeardownStream() {
-      if (hasReceivedShutdownSignal() || (clientClosed && !hasPendingRequests())) {
-        streamRegistry.remove(AbstractWindmillStream.this);
-        finishLatch.countDown();
-        executor.shutdownNow();
-        return true;
-      }
+    private boolean maybeTeardownStream() {
+      synchronized (AbstractWindmillStream.this) {
+        if (isShutdown || (clientClosed && !hasPendingRequests())) {
+          streamRegistry.remove(AbstractWindmillStream.this);

Review Comment:
   why doesn't just "this" work?
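   
   For context on the question, here is a small sketch of how the two locks differ if maybeTeardownStream() lives in a non-static inner class, which the use of AbstractWindmillStream.this suggests but the hunk does not confirm. OuterStreamSketch and its members are hypothetical names, not the PR's classes.

```java
public class OuterStreamSketch {
  /** Non-static inner class: each instance keeps a reference to its enclosing object. */
  final class ResponseObserver {
    /** Takes the enclosing object's monitor, like synchronized (Outer.this). */
    boolean locksOuterOnly() {
      synchronized (OuterStreamSketch.this) {
        return Thread.holdsLock(OuterStreamSketch.this) && !Thread.holdsLock(this);
      }
    }

    /** A synchronized method on the inner class takes the inner object's monitor. */
    synchronized boolean locksInnerOnly() {
      return Thread.holdsLock(this) && !Thread.holdsLock(OuterStreamSketch.this);
    }
  }

  ResponseObserver newObserver() {
    return new ResponseObserver();
  }

  public static void main(String[] args) {
    ResponseObserver observer = new OuterStreamSketch().newObserver();
    System.out.println(observer.locksOuterOnly()); // true
    System.out.println(observer.locksInnerOnly()); // true
  }
}
```

   If the method is in fact declared on the outer class, plain "this" and a synchronized method would lock the same object.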



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserver.java:
##########
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client;
+
+import java.util.function.Supplier;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.GuardedBy;
+import javax.annotation.concurrent.ThreadSafe;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverCancelledException;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.TerminatingStreamObserver;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
+import org.slf4j.Logger;
+
+/**
+ * Request observer that allows resetting its internal delegate using the 
given {@link
+ * #streamObserverFactory}.
+ *
+ * @implNote {@link StreamObserver}s generated by {@link 
#streamObserverFactory} are expected to be
+ *     {@link ThreadSafe}. Has same methods declared in {@link 
StreamObserver}, but they throw
+ *     {@link StreamClosedException} and {@link 
WindmillStreamShutdownException}, which much be
+ *     handled by callers.
+ */
+@ThreadSafe
+@Internal
+final class ResettableThrowingStreamObserver<T> {
+  private final Supplier<TerminatingStreamObserver<T>> streamObserverFactory;
+  private final Logger logger;
+
+  @GuardedBy("this")
+  private @Nullable TerminatingStreamObserver<T> delegateStreamObserver;
+
+  @GuardedBy("this")
+  private boolean isPoisoned = false;
+
+  /**
+   * Indicates that the current delegate is closed via {@link #poison() or 
{@link #onCompleted()}}.
+   * If not poisoned, a call to {@link #reset()} is required to perform future 
operations on the
+   * StreamObserver.
+   */
+  @GuardedBy("this")
+  private boolean isCurrentStreamClosed = false;
+
+  ResettableThrowingStreamObserver(
+      Supplier<TerminatingStreamObserver<T>> streamObserverFactory, Logger 
logger) {
+    this.streamObserverFactory = streamObserverFactory;
+    this.logger = logger;
+    this.delegateStreamObserver = null;
+  }
+
+  private synchronized StreamObserver<T> delegate()
+      throws WindmillStreamShutdownException, StreamClosedException {
+    if (isPoisoned) {
+      throw new WindmillStreamShutdownException("Stream is already shutdown.");
+    }
+
+    if (isCurrentStreamClosed) {
+      throw new StreamClosedException(
+          "Current stream is closed, requires reset for future stream 
operations.");
+    }
+
+    return Preconditions.checkNotNull(
+        delegateStreamObserver,
+        "requestObserver cannot be null. Missing a call to startStream() to 
initialize.");
+  }
+
+  /** Creates a new delegate to use for future {@link StreamObserver} methods. 
*/
+  synchronized void reset() throws WindmillStreamShutdownException {
+    if (isPoisoned) {
+      throw new WindmillStreamShutdownException("Stream is already shutdown.");
+    }
+
+    delegateStreamObserver = streamObserverFactory.get();
+    isCurrentStreamClosed = false;
+  }
+
+  /**
+   * Indicates that the request observer should no longer be used. Attempts to 
perform operations on
+   * the request observer will throw an {@link 
WindmillStreamShutdownException}.
+   */
+  synchronized void poison() {
+    if (!isPoisoned) {
+      isPoisoned = true;
+      if (delegateStreamObserver != null) {
+        delegateStreamObserver.terminate(
+            new WindmillStreamShutdownException("Explicit call to shutdown 
stream."));
+        delegateStreamObserver = null;
+        isCurrentStreamClosed = true;
+      }
+    }
+  }
+
+  public void onNext(T t) throws StreamClosedException, 
WindmillStreamShutdownException {
+    // Make sure onNext and onError below to be called on the same 
StreamObserver instance.
+    StreamObserver<T> delegate = delegate();
+    try {
+      // Do NOT lock while sending message over the stream as this will block 
other StreamObserver
+      // operations.
+      delegate.onNext(t);
+    } catch (StreamObserverCancelledException e) {
+      synchronized (this) {
+        if (isPoisoned) {
+          logger.debug("Stream was shutdown during send.", e);
+          return;
+        }
+      }
+
+      try {
+        delegate.onError(e);
+      } catch (RuntimeException ignored) {
+        // If the delegate above was already terminated via onError or 
onComplete from another
+        // thread.
+        logger.warn("StreamObserver was previously cancelled.", e);
+      }
+    }
+  }
+
+  public void onError(Throwable throwable)
+      throws StreamClosedException, WindmillStreamShutdownException {
+    delegate().onError(throwable);
+  }
+
+  public synchronized void onCompleted()

Review Comment:
   the synchronization is inconsistent for onCompleted and onError.  Both also seem like they should mark the current stream closed, since we don't want to send more after either of them is called.
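   
   One possible shape for this, as a sketch only: both terminal calls go through the same synchronized helper, which marks the stream closed before handing back the delegate. Observer is a hypothetical stand-in for StreamObserver, and IllegalStateException stands in for the PR's StreamClosedException.

```java
public class ClosingObserverSketch<T> {
  /** Hypothetical stand-in for gRPC's StreamObserver. */
  interface Observer<T> {
    void onNext(T value);
    void onError(Throwable t);
    void onCompleted();
  }

  private final Observer<T> delegate;
  private boolean closed = false; // guarded by "this"

  ClosingObserverSketch(Observer<T> delegate) {
    this.delegate = delegate;
  }

  private synchronized Observer<T> delegateOrThrow() {
    if (closed) {
      throw new IllegalStateException("Current stream is closed, requires reset.");
    }
    return delegate;
  }

  /** Shared by onError and onCompleted so both close the stream the same way. */
  private synchronized Observer<T> closeAndGetDelegate() {
    Observer<T> current = delegateOrThrow();
    closed = true;
    return current;
  }

  // The delegate calls themselves run outside the lock in all three methods.
  public void onNext(T value) {
    delegateOrThrow().onNext(value);
  }

  public void onError(Throwable t) {
    closeAndGetDelegate().onError(t);
  }

  public void onCompleted() {
    closeAndGetDelegate().onCompleted();
  }
}
```

   After either terminal call, a later onNext, onError, or onCompleted fails fast instead of reaching the closed delegate.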



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserver.java:
##########
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client;
+
+import java.util.function.Supplier;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.GuardedBy;
+import javax.annotation.concurrent.ThreadSafe;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverCancelledException;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.TerminatingStreamObserver;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
+import org.slf4j.Logger;
+
+/**
+ * Request observer that allows resetting its internal delegate using the 
given {@link
+ * #streamObserverFactory}.
+ *
+ * @implNote {@link StreamObserver}s generated by {@link 
#streamObserverFactory} are expected to be
+ *     {@link ThreadSafe}. Has same methods declared in {@link 
StreamObserver}, but they throw
+ *     {@link StreamClosedException} and {@link 
WindmillStreamShutdownException}, which much be
+ *     handled by callers.
+ */
+@ThreadSafe
+@Internal
+final class ResettableThrowingStreamObserver<T> {
+  private final Supplier<TerminatingStreamObserver<T>> streamObserverFactory;
+  private final Logger logger;
+
+  @GuardedBy("this")
+  private @Nullable TerminatingStreamObserver<T> delegateStreamObserver;
+
+  @GuardedBy("this")
+  private boolean isPoisoned = false;
+
+  /**
+   * Indicates that the current delegate is closed via {@link #poison() or 
{@link #onCompleted()}}.
+   * If not poisoned, a call to {@link #reset()} is required to perform future 
operations on the
+   * StreamObserver.
+   */
+  @GuardedBy("this")
+  private boolean isCurrentStreamClosed = false;
+
+  ResettableThrowingStreamObserver(
+      Supplier<TerminatingStreamObserver<T>> streamObserverFactory, Logger 
logger) {
+    this.streamObserverFactory = streamObserverFactory;
+    this.logger = logger;
+    this.delegateStreamObserver = null;
+  }
+
+  private synchronized StreamObserver<T> delegate()
+      throws WindmillStreamShutdownException, StreamClosedException {
+    if (isPoisoned) {
+      throw new WindmillStreamShutdownException("Stream is already shutdown.");
+    }
+
+    if (isCurrentStreamClosed) {
+      throw new StreamClosedException(
+          "Current stream is closed, requires reset for future stream 
operations.");
+    }
+
+    return Preconditions.checkNotNull(
+        delegateStreamObserver,
+        "requestObserver cannot be null. Missing a call to startStream() to 
initialize.");

Review Comment:
   reset() is the method in this class that is needed, not startStream().
   
   However, maybe we should just begin with isCurrentStreamClosed = true? Then the above check will handle it, and its message seems correct about reset() being required.
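   
   A sketch of that idea, with hypothetical names and IllegalStateException in place of StreamClosedException: because the flag starts out true, the closed check fires before the delegate is ever read, so no separate null check is needed.

```java
import java.util.function.Supplier;

public class StartsClosedSketch<T> {
  private final Supplier<T> factory;
  private T delegate = null; // guarded by "this"
  private boolean isCurrentStreamClosed = true; // starts closed until reset() is called

  StartsClosedSketch(Supplier<T> factory) {
    this.factory = factory;
  }

  synchronized T delegate() {
    if (isCurrentStreamClosed) {
      throw new IllegalStateException(
          "Current stream is closed, requires reset for future stream operations.");
    }
    return delegate; // non-null here as long as the factory returns non-null
  }

  synchronized void reset() {
    delegate = factory.get();
    isCurrentStreamClosed = false;
  }

  public static void main(String[] args) {
    StartsClosedSketch<String> sketch = new StartsClosedSketch<>(() -> "observer");
    sketch.reset();
    System.out.println(sketch.delegate()); // observer
  }
}
```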



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java:
##########
@@ -67,76 +72,128 @@ public DirectStreamObserver(
   }
 
   @Override
-  public void onNext(T value) {
+  public void onNext(T value) throws StreamObserverCancelledException {
     int awaitPhase = -1;
     long totalSecondsWaited = 0;
     long waitSeconds = 1;
     while (true) {
       try {
         synchronized (lock) {
+          int currentPhase = isReadyNotifier.getPhase();
+          // Phaser is terminated so don't use the outboundObserver. Since 
onError and onCompleted
+          // are synchronized after terminating the phaser if we observe that 
the phaser is not
+          // terminated the onNext calls below are guaranteed to not be called 
on a closed observer.
+          if (currentPhase < 0) return;
+
+          // If we awaited previously and timed out, wait for the same phase. 
Otherwise we're
+          // careful to observe the phase before observing isReady.
+          if (awaitPhase < 0) {
+            awaitPhase = isReadyNotifier.getPhase();
+            // If getPhase() returns a value less than 0, the phaser has been 
terminated.
+            if (awaitPhase < 0) {
+              return;
+            }
+          }
+
           // We only check isReady periodically to effectively allow for 
increasing the outbound
           // buffer periodically. This reduces the overhead of blocking while 
still restricting
           // memory because there is a limited # of streams, and we have a max 
messages size of 2MB.
           if (++messagesSinceReady <= messagesBetweenIsReadyChecks) {
             outboundObserver.onNext(value);
             return;
           }
-          // If we awaited previously and timed out, wait for the same phase. 
Otherwise we're
-          // careful to observe the phase before observing isReady.
-          if (awaitPhase < 0) {
-            awaitPhase = phaser.getPhase();
-          }
+
           if (outboundObserver.isReady()) {
             messagesSinceReady = 0;
             outboundObserver.onNext(value);
             return;
           }
         }
+
         // A callback has been registered to advance the phaser whenever the 
observer
         // transitions to  is ready. Since we are waiting for a phase observed 
before the
         // outboundObserver.isReady() returned false, we expect it to advance 
after the
         // channel has become ready.  This doesn't always seem to be the case 
(despite
         // documentation stating otherwise) so we poll periodically and 
enforce an overall
         // timeout related to the stream deadline.
-        phaser.awaitAdvanceInterruptibly(awaitPhase, waitSeconds, 
TimeUnit.SECONDS);
+        int nextPhase =
+            isReadyNotifier.awaitAdvanceInterruptibly(awaitPhase, waitSeconds, 
TimeUnit.SECONDS);
+        // If nextPhase is a value less than 0, the phaser has been terminated.
+        if (nextPhase < 0) {
+          throw new StreamObserverCancelledException("StreamObserver was 
terminated.");
+        }
+
         synchronized (lock) {
+          int currentPhase = isReadyNotifier.getPhase();
+          // Phaser is terminated so don't use the outboundObserver. Since 
onError and onCompleted
+          // are synchronized after terminating the phaser if we observe that 
the phaser is not
+          // terminated the onNext calls below are guaranteed to not be called 
on a closed observer.
+          if (currentPhase < 0) return;

Review Comment:
   throw exception?



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java:
##########
@@ -67,76 +72,128 @@ public DirectStreamObserver(
   }
 
   @Override
-  public void onNext(T value) {
+  public void onNext(T value) throws StreamObserverCancelledException {
     int awaitPhase = -1;
     long totalSecondsWaited = 0;
     long waitSeconds = 1;
     while (true) {
       try {
         synchronized (lock) {
+          int currentPhase = isReadyNotifier.getPhase();
+          // Phaser is terminated so don't use the outboundObserver. Since 
onError and onCompleted
+          // are synchronized after terminating the phaser if we observe that 
the phaser is not
+          // terminated the onNext calls below are guaranteed to not be called 
on a closed observer.
+          if (currentPhase < 0) return;
+
+          // If we awaited previously and timed out, wait for the same phase. 
Otherwise we're
+          // careful to observe the phase before observing isReady.
+          if (awaitPhase < 0) {
+            awaitPhase = isReadyNotifier.getPhase();

Review Comment:
   can use currentPhase here instead of calling getPhase() again.  Then the negative check isn't needed either, since it is handled above.
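   
   Roughly what that simplification could look like, pulled out into a hypothetical helper (phaseToAwait is not a method in the PR). It relies only on java.util.concurrent.Phaser reporting a negative phase once terminated.

```java
import java.util.concurrent.Phaser;

public class PhaseCheckSketch {
  /**
   * Returns the phase to wait on, or -1 if the phaser has been terminated.
   * The phase is read once and reused, so a second negative check is redundant.
   */
  static int phaseToAwait(Phaser isReadyNotifier, int awaitPhase) {
    int currentPhase = isReadyNotifier.getPhase();
    if (currentPhase < 0) {
      return -1; // terminated: the caller should not touch the outbound observer
    }
    // If we awaited previously and timed out, keep waiting for the same phase.
    return awaitPhase < 0 ? currentPhase : awaitPhase;
  }

  public static void main(String[] args) {
    Phaser notifier = new Phaser(1);
    System.out.println(phaseToAwait(notifier, -1)); // 0
    notifier.forceTermination();
    System.out.println(phaseToAwait(notifier, -1)); // -1
  }
}
```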



##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequestsTest.java:
##########
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException;
+import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class GrpcGetDataStreamRequestsTest {
+
+  @Test
+  public void testQueuedRequest_globalRequestsFirstComparator() {
+    List<GrpcGetDataStreamRequests.QueuedRequest> requests = new ArrayList<>();
+    Windmill.KeyedGetDataRequest keyedGetDataRequest1 =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(ByteString.EMPTY)
+            .setCacheToken(1L)
+            .setShardingKey(1L)
+            .setWorkToken(1L)
+            .setMaxBytes(Long.MAX_VALUE)
+            .build();
+    requests.add(
+        GrpcGetDataStreamRequests.QueuedRequest.forComputation(
+            1, "computation1", keyedGetDataRequest1));
+
+    Windmill.KeyedGetDataRequest keyedGetDataRequest2 =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(ByteString.EMPTY)
+            .setCacheToken(2L)
+            .setShardingKey(2L)
+            .setWorkToken(2L)
+            .setMaxBytes(Long.MAX_VALUE)
+            .build();
+    requests.add(
+        GrpcGetDataStreamRequests.QueuedRequest.forComputation(
+            2, "computation2", keyedGetDataRequest2));
+
+    Windmill.GlobalDataRequest globalDataRequest =
+        Windmill.GlobalDataRequest.newBuilder()
+            .setDataId(
+                Windmill.GlobalDataId.newBuilder()
+                    .setTag("globalData")
+                    .setVersion(ByteString.EMPTY)
+                    .build())
+            .setComputationId("computation1")
+            .build();
+    requests.add(GrpcGetDataStreamRequests.QueuedRequest.global(3, 
globalDataRequest));
+
+    
requests.sort(GrpcGetDataStreamRequests.QueuedRequest.globalRequestsFirst());
+
+    // First one should be the global request.
+    assertTrue(requests.get(0).getDataRequest().isGlobal());
+  }
+
+  @Test
+  public void testQueuedBatch_asGetDataRequest() {
+    GrpcGetDataStreamRequests.QueuedBatch queuedBatch = new 
GrpcGetDataStreamRequests.QueuedBatch();
+
+    Windmill.KeyedGetDataRequest keyedGetDataRequest1 =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(ByteString.EMPTY)
+            .setCacheToken(1L)
+            .setShardingKey(1L)
+            .setWorkToken(1L)
+            .setMaxBytes(Long.MAX_VALUE)
+            .build();
+    queuedBatch.addRequest(
+        GrpcGetDataStreamRequests.QueuedRequest.forComputation(
+            1, "computation1", keyedGetDataRequest1));
+
+    Windmill.KeyedGetDataRequest keyedGetDataRequest2 =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(ByteString.EMPTY)
+            .setCacheToken(2L)
+            .setShardingKey(2L)
+            .setWorkToken(2L)
+            .setMaxBytes(Long.MAX_VALUE)
+            .build();
+    queuedBatch.addRequest(
+        GrpcGetDataStreamRequests.QueuedRequest.forComputation(
+            2, "computation2", keyedGetDataRequest2));
+
+    Windmill.GlobalDataRequest globalDataRequest =
+        Windmill.GlobalDataRequest.newBuilder()
+            .setDataId(
+                Windmill.GlobalDataId.newBuilder()
+                    .setTag("globalData")
+                    .setVersion(ByteString.EMPTY)
+                    .build())
+            .setComputationId("computation1")
+            .build();
+    queuedBatch.addRequest(GrpcGetDataStreamRequests.QueuedRequest.global(3, 
globalDataRequest));
+
+    Windmill.StreamingGetDataRequest getDataRequest = 
queuedBatch.asGetDataRequest();
+
+    assertThat(getDataRequest.getRequestIdCount()).isEqualTo(3);
+    
assertThat(getDataRequest.getGlobalDataRequestList()).containsExactly(globalDataRequest);
+    assertThat(getDataRequest.getStateRequestList())
+        .containsExactly(
+            Windmill.ComputationGetDataRequest.newBuilder()
+                .setComputationId("computation1")
+                .addRequests(keyedGetDataRequest1)
+                .build(),
+            Windmill.ComputationGetDataRequest.newBuilder()
+                .setComputationId("computation2")
+                .addRequests(keyedGetDataRequest2)
+                .build());
+  }
+
+  @Test
+  public void 
testQueuedBatch_notifyFailed_throwsWindmillStreamShutdownExceptionOnWaiters() {
+    GrpcGetDataStreamRequests.QueuedBatch queuedBatch = new 
GrpcGetDataStreamRequests.QueuedBatch();
+    CompletableFuture<WindmillStreamShutdownException> waitFuture =
+        CompletableFuture.supplyAsync(
+            () ->
+                assertThrows(
+                    WindmillStreamShutdownException.class,
+                    queuedBatch::waitForSendOrFailNotification));
+

Review Comment:
   sleep to allow above future to schedule?
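   
   As a possible complement to a sleep, the async task can signal that it has been scheduled before it blocks. This is only a sketch with hypothetical names; it shows that the task has started, not that it has already reached the blocking wait, so a short sleep may still be wanted afterwards.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

public class AwaitWaiterStartSketch {
  /** Starts a task that blocks on "release" and returns once that task is running. */
  static CompletableFuture<String> startWaiter(CountDownLatch release) {
    CountDownLatch started = new CountDownLatch(1);
    CompletableFuture<String> waiter =
        CompletableFuture.supplyAsync(
            () -> {
              started.countDown(); // the task has been scheduled on a pool thread
              try {
                release.await();
                return "released";
              } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return "interrupted";
              }
            });
    try {
      if (!started.await(10, TimeUnit.SECONDS)) {
        throw new IllegalStateException("waiter never started");
      }
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      throw new IllegalStateException("interrupted while waiting for waiter", e);
    }
    return waiter;
  }

  public static void main(String[] args) {
    CountDownLatch release = new CountDownLatch(1);
    CompletableFuture<String> waiter = startWaiter(release);
    release.countDown();
    System.out.println(waiter.join()); // released
  }
}
```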



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java:
##########
@@ -305,77 +341,96 @@ public void onNext(ResponseT response) {
       } catch (IOException e) {
         // Ignore.
       }
-      lastResponseTimeMs.set(Instant.now().getMillis());
+      debugMetrics.recordResponse();
       onResponse(response);
     }
 
     @Override
     public void onError(Throwable t) {
-      onStreamFinished(t);
+      if (maybeTeardownStream()) {
+        return;
+      }
+
+      recordStreamStatus(Status.fromThrowable(t));
+
+      try {
+        long sleep = backoff.nextBackOffMillis();
+        debugMetrics.recordSleep(sleep);
+        sleeper.sleep(sleep);
+      } catch (InterruptedException e) {
+        Thread.currentThread().interrupt();
+        return;
+      } catch (IOException e) {
+        // Ignore.
+      }
+
+      executeSafely(AbstractWindmillStream.this::startStream);
     }
 
     @Override
     public void onCompleted() {
-      onStreamFinished(null);
+      if (maybeTeardownStream()) {
+        return;
+      }
+      recordStreamStatus(OK_STATUS);
+      executeSafely(AbstractWindmillStream.this::startStream);
     }
 
-    private void onStreamFinished(@Nullable Throwable t) {
-      synchronized (this) {
-        if (isShutdown.get() || (clientClosed.get() && !hasPendingRequests())) 
{
-          streamRegistry.remove(AbstractWindmillStream.this);
-          finishLatch.countDown();
-          return;
-        }
-      }
-      if (t != null) {
-        Status status = null;
-        if (t instanceof StatusRuntimeException) {
-          status = ((StatusRuntimeException) t).getStatus();
-        }
-        String statusError = status == null ? "" : status.toString();
-        setLastError(statusError);
-        if (errorCount.getAndIncrement() % logEveryNStreamFailures == 0) {
+    private void recordStreamStatus(Status status) {
+      int currentRestartCount = debugMetrics.incrementAndGetRestarts();
+      if (status.isOk()) {
+        String restartReason =
+            "Stream completed successfully but did not complete requested 
operations, "
+                + "recreating";
+        logger.warn(restartReason);
+        debugMetrics.recordRestartReason(restartReason);
+      } else {
+        int currentErrorCount = debugMetrics.incrementAndGetErrors();
+        debugMetrics.recordRestartReason(status.toString());
+        Throwable t = status.getCause();
+        if (t instanceof StreamObserverCancelledException) {
+          logger.error(
+              "StreamObserver was unexpectedly cancelled for stream={}, 
worker={}. stacktrace={}",
+              getClass(),
+              backendWorkerToken,
+              t.getStackTrace(),
+              t);
+        } else if (currentRestartCount % logEveryNStreamFailures == 0) {
+          // Don't log every restart since it will get noisy, and many errors 
transient.
           long nowMillis = Instant.now().getMillis();
-          String responseDebug;
-          if (lastResponseTimeMs.get() == 0) {
-            responseDebug = "never received response";
-          } else {
-            responseDebug =
-                "received response " + (nowMillis - lastResponseTimeMs.get()) 
+ "ms ago";
-          }
-          LOG.debug(
-              "{} streaming Windmill RPC errors for {}, last was: {} with 
status {}."
-                  + " created {}ms ago, {}. This is normal with autoscaling.",
+          logger.debug(
+              "{} has been restarted {} times. Streaming Windmill RPC Error 
Count: {}; last was: {}"
+                  + " with status: {}. created {}ms ago; {}. This is normal 
with autoscaling.",
               AbstractWindmillStream.this.getClass(),
-              errorCount.get(),
+              currentRestartCount,
+              currentErrorCount,
               t,
-              statusError,
-              nowMillis - startTimeMs.get(),
-              responseDebug);
+              status,
+              nowMillis - debugMetrics.getStartTimeMs(),
+              debugMetrics
+                  .responseDebugString(nowMillis)
+                  .orElse(NEVER_RECEIVED_RESPONSE_LOG_STRING));
         }
+
         // If the stream was stopped due to a resource exhausted error then we 
are throttled.
-        if (status != null && status.getCode() == 
Status.Code.RESOURCE_EXHAUSTED) {
+        if (status.getCode() == Status.Code.RESOURCE_EXHAUSTED) {

Review Comment:
   nit: I would maybe keep this in the onError case, since the other stuff is mostly logs/debug page and this is more functional.



##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java:
##########
@@ -0,0 +1,246 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.inOrder;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+import java.io.IOException;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import javax.annotation.Nullable;
+import 
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
+import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server;
+import 
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder;
+import 
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder;
+import 
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.ServerCallStreamObserver;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.Timeout;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.InOrder;
+
+@RunWith(JUnit4.class)
+public class GrpcCommitWorkStreamTest {
+  private static final String FAKE_SERVER_NAME = "Fake server for 
GrpcCommitWorkStreamTest";
+  private static final Windmill.JobHeader TEST_JOB_HEADER =
+      Windmill.JobHeader.newBuilder()
+          .setJobId("test_job")
+          .setWorkerId("test_worker")
+          .setProjectId("test_project")
+          .build();
+  private static final String COMPUTATION_ID = "computationId";
+
+  @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
+  private final MutableHandlerRegistry serviceRegistry = new 
MutableHandlerRegistry();
+  @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
+  private ManagedChannel inProcessChannel;
+
+  private static Windmill.WorkItemCommitRequest workItemCommitRequest(long 
value) {
+    return Windmill.WorkItemCommitRequest.newBuilder()
+        .setKey(ByteString.EMPTY)
+        .setShardingKey(value)
+        .setWorkToken(value)
+        .setCacheToken(value)
+        .build();
+  }
+
+  @Before
+  public void setUp() throws IOException {
+    Server server =
+        InProcessServerBuilder.forName(FAKE_SERVER_NAME)
+            .fallbackHandlerRegistry(serviceRegistry)
+            .directExecutor()
+            .build()
+            .start();
+
+    inProcessChannel =
+        grpcCleanup.register(
+            
InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build());
+    grpcCleanup.register(server);
+    grpcCleanup.register(inProcessChannel);
+  }
+
+  @After
+  public void cleanUp() {
+    inProcessChannel.shutdownNow();
+  }
+
+  private GrpcCommitWorkStream createCommitWorkStream(CommitWorkStreamTestStub 
testStub) {
+    serviceRegistry.addService(testStub);
+    GrpcCommitWorkStream commitWorkStream =
+        (GrpcCommitWorkStream)
+            GrpcWindmillStreamFactory.of(TEST_JOB_HEADER)
+                .build()
+                .createCommitWorkStream(
+                    CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel),
+                    new ThrottleTimer());
+    commitWorkStream.start();
+    return commitWorkStream;
+  }
+
+  @Test
+  public void testShutdown_abortsQueuedCommits() throws InterruptedException {
+    int numCommits = 5;
+    CountDownLatch commitProcessed = new CountDownLatch(numCommits);
+    Set<Windmill.CommitStatus> onDone = new HashSet<>();
+
+    TestCommitWorkStreamRequestObserver requestObserver =
+        spy(new TestCommitWorkStreamRequestObserver());
+    CommitWorkStreamTestStub testStub = new 
CommitWorkStreamTestStub(requestObserver);
+    GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(testStub);
+    try (WindmillStream.CommitWorkStream.RequestBatcher batcher = 
commitWorkStream.batcher()) {
+      for (int i = 0; i < numCommits; i++) {
+        batcher.commitWorkItem(
+            COMPUTATION_ID,
+            workItemCommitRequest(i),
+            commitStatus -> {
+              onDone.add(commitStatus);
+              commitProcessed.countDown();
+            });
+      }
+    }
+
+    // Verify that we sent the commits above in a request + the initial header.
+    verify(requestObserver, 
times(2)).onNext(any(Windmill.StreamingCommitWorkRequest.class));

Review Comment:
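   Can we make this verification stricter? For example, assert on the contents
   of the sent requests instead of only the number of onNext calls.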



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java:
##########
@@ -278,23 +318,19 @@ public String backendWorkerToken() {
   }
 
   @Override
-  public void shutdown() {
-    if (isShutdown.compareAndSet(false, true)) {
-      requestObserver()
-          .onError(new WindmillStreamShutdownException("Explicit call to 
shutdown stream."));
+  public final void shutdown() {
+    // Don't lock on "this" before poisoning the request observer as allow IO 
to block shutdown.

Review Comment:
   The ending of this comment is a little confusing. How about "since otherwise
   the observer may be blocking in send"?



##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java:
##########
@@ -0,0 +1,210 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import javax.annotation.Nullable;
+import 
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
+import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server;
+import 
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder;
+import 
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder;
+import 
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.ServerCallStreamObserver;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.Timeout;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class GrpcGetDataStreamTest {
+  private static final String FAKE_SERVER_NAME = "Fake server for 
GrpcGetDataStreamTest";
+  private static final Windmill.JobHeader TEST_JOB_HEADER =
+      Windmill.JobHeader.newBuilder()
+          .setJobId("test_job")
+          .setWorkerId("test_worker")
+          .setProjectId("test_project")
+          .build();
+
+  @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
+  private final MutableHandlerRegistry serviceRegistry = new 
MutableHandlerRegistry();
+  @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
+  private ManagedChannel inProcessChannel;
+
+  @Before
+  public void setUp() throws IOException {
+    Server server =
+        InProcessServerBuilder.forName(FAKE_SERVER_NAME)
+            .fallbackHandlerRegistry(serviceRegistry)
+            .directExecutor()
+            .build()
+            .start();
+
+    inProcessChannel =
+        grpcCleanup.register(
+            
InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build());
+    grpcCleanup.register(server);
+    grpcCleanup.register(inProcessChannel);
+  }
+
+  @After
+  public void cleanUp() {
+    inProcessChannel.shutdownNow();
+  }
+
+  private GrpcGetDataStream createGetDataStream(GetDataStreamTestStub 
testStub) {
+    serviceRegistry.addService(testStub);
+    GrpcGetDataStream getDataStream =
+        (GrpcGetDataStream)
+            GrpcWindmillStreamFactory.of(TEST_JOB_HEADER)
+                .setSendKeyedGetDataRequests(false)
+                .build()
+                .createGetDataStream(
+                    CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel),
+                    new ThrottleTimer());
+    getDataStream.start();
+    return getDataStream;
+  }
+
+  @Test
+  public void 
testRequestKeyedData_sendOnShutdownStreamThrowsWindmillStreamShutdownException()
 {

Review Comment:
   How about also adding a simple test that just sends a single message without
   errors or shutdown? Smaller tests are nice to have: if this one fails it
   could be caused by lots of things, whereas a failing simple test is easier
   to debug.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java:
##########
@@ -301,39 +341,58 @@ public void appendSpecificHtml(PrintWriter writer) {
     writer.append("]");
   }
 
-  private <ResponseT> ResponseT issueRequest(QueuedRequest request, 
ParseFn<ResponseT> parseFn) {
-    while (true) {
+  private <ResponseT> ResponseT issueRequest(QueuedRequest request, 
ParseFn<ResponseT> parseFn)
+      throws WindmillStreamShutdownException {
+    while (!isShutdownLocked()) {

Review Comment:
   Remove this check? The code below needs to handle shutdown anyway.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java:
##########
@@ -271,12 +297,26 @@ public void 
onHeartbeatResponse(List<Windmill.ComputationHeartbeatResponse> resp
   }
 
   @Override
-  public void sendHealthCheck() {
+  public void sendHealthCheck() throws WindmillStreamShutdownException {
     if (hasPendingRequests()) {
-      send(StreamingGetDataRequest.newBuilder().build());
+      trySend(HEALTH_CHECK_REQUEST);
     }
   }
 
+  @Override
+  protected void shutdownInternal() {
+    // Stream has been explicitly closed. Drain pending input streams and 
request batches.
+    // Future calls to send RPCs will fail.
+    pending.values().forEach(AppendableInputStream::cancel);
+    pending.clear();
+    batches.forEach(
+        batch -> {
+          batch.markFinalized();

Review Comment:
   Batch usage should be synchronized on this.
   
   Should shutdownInternal be marked synchronized? It is currently called from
   within the synchronized shutdown block anyway.
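
   A minimal sketch of what I mean, using a hypothetical stand-in class rather
   than the real stream (BatchingStreamSketch and its members are assumptions
   for illustration). Java monitors are reentrant, so marking the method
   synchronized costs nothing when it is called from the synchronized block,
   and it documents that batches is guarded by this:

   ```java
   import java.util.ArrayDeque;
   import java.util.Deque;

   // Hypothetical stand-in for the stream; not the real GrpcGetDataStream.
   final class BatchingStreamSketch {
     // Guarded by "this": every read and write happens while holding the monitor.
     private final Deque<String> batches = new ArrayDeque<>();
     private boolean isShutdown = false;

     synchronized boolean addBatch(String batch) {
       if (isShutdown) {
         return false;
       }
       batches.addLast(batch);
       return true;
     }

     void shutdown() {
       synchronized (this) {
         if (isShutdown) {
           return;
         }
         isShutdown = true;
         // Re-entering a monitor this thread already holds does not block.
         shutdownInternal();
       }
     }

     // Marked synchronized so the locking requirement is explicit for any caller.
     private synchronized void shutdownInternal() {
       batches.clear();
     }

     synchronized int pendingBatches() {
       return batches.size();
     }
   }
   ```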



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java:
##########
@@ -187,21 +208,26 @@ private long uniqueId() {
   }
 
   @Override
-  public KeyedGetDataResponse requestKeyedData(String computation, 
KeyedGetDataRequest request) {
+  public KeyedGetDataResponse requestKeyedData(String computation, 
KeyedGetDataRequest request)
+      throws WindmillStreamShutdownException {
     return issueRequest(
         QueuedRequest.forComputation(uniqueId(), computation, request),
         KeyedGetDataResponse::parseFrom);
   }
 
   @Override
-  public GlobalData requestGlobalData(GlobalDataRequest request) {
+  public GlobalData requestGlobalData(GlobalDataRequest request)
+      throws WindmillStreamShutdownException {
     return issueRequest(QueuedRequest.global(uniqueId(), request), 
GlobalData::parseFrom);
   }
 
   @Override
-  public void refreshActiveWork(Map<String, Collection<HeartbeatRequest>> 
heartbeats) {
-    if (isShutdown()) {
-      throw new WindmillStreamShutdownException("Unable to refresh work for 
shutdown stream.");
+  public void refreshActiveWork(Map<String, Collection<HeartbeatRequest>> 
heartbeats)
+      throws WindmillStreamShutdownException {
+    synchronized (this) {
+      if (isShutdown) {

Review Comment:
   Remove? The check is racy and we need to handle exceptions thrown by the
   sends below anyway.
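
   To illustrate with a simplified sketch (RefreshSketch and
   ShutdownSketchException are hypothetical stand-ins, and the exception is
   unchecked here only to keep the example short): the only check that cannot
   race is the one made in the same critical section as the send, so an extra
   check up front adds nothing.

   ```java
   import java.util.List;

   // Hypothetical, unchecked stand-in for WindmillStreamShutdownException.
   final class ShutdownSketchException extends RuntimeException {
     ShutdownSketchException(String message) {
       super(message);
     }
   }

   final class RefreshSketch {
     private boolean isShutdown = false;
     private int sent = 0;

     synchronized void shutdown() {
       isShutdown = true;
     }

     // The check and the send share one critical section, so shutdown() cannot
     // slip in between them.
     synchronized void send(String request) {
       if (isShutdown) {
         throw new ShutdownSketchException("Cannot send on a shutdown stream.");
       }
       sent++;
     }

     // No up-front isShutdown check: another thread could call shutdown() right
     // after it, so callers must handle the exception from send() regardless.
     void refreshActiveWork(List<String> heartbeats) {
       for (String heartbeat : heartbeats) {
         send(heartbeat);
       }
     }

     synchronized int sentCount() {
       return sent;
     }
   }
   ```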



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java:
##########
@@ -301,39 +341,58 @@ public void appendSpecificHtml(PrintWriter writer) {
     writer.append("]");
   }
 
-  private <ResponseT> ResponseT issueRequest(QueuedRequest request, 
ParseFn<ResponseT> parseFn) {
-    while (true) {
+  private <ResponseT> ResponseT issueRequest(QueuedRequest request, 
ParseFn<ResponseT> parseFn)
+      throws WindmillStreamShutdownException {
+    while (!isShutdownLocked()) {
       request.resetResponseStream();
       try {
         queueRequestAndWait(request);
         return parseFn.parse(request.getResponseStream());
-      } catch (CancellationException e) {
-        // Retry issuing the request since the response stream was cancelled.
-        continue;
+      } catch (AppendableInputStream.InvalidInputStreamStateException | 
CancellationException e) {
+        handleShutdown(request, e);
+        if (!(e instanceof CancellationException)) {
+          throw e;
+        }
       } catch (IOException e) {
         LOG.error("Parsing GetData response failed: ", e);
-        continue;
       } catch (InterruptedException e) {
         Thread.currentThread().interrupt();
+        handleShutdown(request, e);
         throw new RuntimeException(e);
       } finally {
         pending.remove(request.id());
       }
     }
+
+    throw shutdownExceptionFor(request);
   }
 
-  private void queueRequestAndWait(QueuedRequest request) throws 
InterruptedException {
+  private synchronized void handleShutdown(QueuedRequest request, Throwable 
cause)

Review Comment:
   nit: name throwIfShutdown?



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java:
##########
@@ -342,62 +401,86 @@ private void queueRequestAndWait(QueuedRequest request) 
throws InterruptedExcept
       batch.addRequest(request);
     }
     if (responsibleForSend) {
-      if (waitForSendLatch == null) {
+      if (prevBatch == null) {
         // If there was not a previous batch wait a little while to improve
         // batching.
-        Thread.sleep(1);
+        sleeper.sleep(1);
       } else {
-        waitForSendLatch.await();
+        prevBatch.waitForSendOrFailNotification();
       }
       // Finalize the batch so that no additional requests will be added.  
Leave the batch in the
       // queue so that a subsequent batch will wait for its completion.
-      synchronized (batches) {
-        verify(batch == batches.peekFirst());
+      synchronized (this) {
+        if (isShutdown) {
+          throw shutdownExceptionFor(batch);
+        }
+
+        verify(batch == batches.peekFirst(), "GetDataStream request batch 
removed before send().");
         batch.markFinalized();
       }
-      sendBatch(batch.requests());
-      synchronized (batches) {
-        verify(batch == batches.pollFirst());
+      trySendBatch(batch);
+    } else {
+      // Wait for this batch to be sent before parsing the response.
+      batch.waitForSendOrFailNotification();
+    }
+  }
+
+  void trySendBatch(QueuedBatch batch) throws WindmillStreamShutdownException {
+    try {
+      sendBatch(batch);
+      synchronized (this) {
+        if (isShutdown) {
+          throw shutdownExceptionFor(batch);
+        }
+
+        verify(
+            batch == batches.pollFirst(),
+            "Sent GetDataStream request batch removed before send() was 
complete.");
       }
       // Notify all waiters with requests in this batch as well as the sender
       // of the next batch (if one exists).
-      batch.countDown();
-    } else {
-      // Wait for this batch to be sent before parsing the response.
-      batch.await();
+      batch.notifySent();
+    } catch (Exception e) {
+      // Free waiters if the send() failed.
+      batch.notifyFailed();
+      // Propagate the exception to the calling thread.
+      throw e;
     }
   }
 
-  @SuppressWarnings("NullableProblems")
-  private void sendBatch(List<QueuedRequest> requests) {
-    StreamingGetDataRequest batchedRequest = flushToBatch(requests);
+  private void sendBatch(QueuedBatch batch) throws 
WindmillStreamShutdownException {
+    if (batch.isEmpty()) {
+      return;
+    }
+
+    // Synchronization of pending inserts is necessary with send to ensure 
duplicates are not
+    // sent on stream reconnect.
     synchronized (this) {
-      // Synchronization of pending inserts is necessary with send to ensure 
duplicates are not
-      // sent on stream reconnect.
-      for (QueuedRequest request : requests) {
+      if (isShutdown) {
+        throw shutdownExceptionFor(batch);
+      }
+
+      for (QueuedRequest request : batch.requestsReadOnly()) {
         // Map#put returns null if there was no previous mapping for the key, 
meaning we have not
         // seen it before.
-        verify(pending.put(request.id(), request.getResponseStream()) == null);
+        verify(
+            pending.put(request.id(), request.getResponseStream()) == null,
+            "Request already sent.");
       }
-      try {
-        send(batchedRequest);
-      } catch (IllegalStateException e) {
+
+      if (!trySend(batch.asGetDataRequest())) {
         // The stream broke before this call went through; onNewStream will 
retry the fetch.
-        LOG.warn("GetData stream broke before call started.", e);
+        LOG.warn("GetData stream broke before call started.");
       }
     }
   }
 
-  @SuppressWarnings("argument")
-  private StreamingGetDataRequest flushToBatch(List<QueuedRequest> requests) {
-    // Put all global data requests first because there is only a single 
repeated field for
-    // request ids and the initial ids correspond to global data requests if 
they are present.
-    requests.sort(QueuedRequest.globalRequestsFirst());
-    StreamingGetDataRequest.Builder builder = 
StreamingGetDataRequest.newBuilder();
-    for (QueuedRequest request : requests) {
-      request.addToStreamingGetDataRequest(builder);
-    }
-    return builder.build();
+  private synchronized void verify(boolean condition, String message) {

Review Comment:
   Can this be removed? In the cases I see, you already have a shutdown check
   first (and if not, I think it would be clearer to make the shutdown check
   explicit at the call site than to hide it in a method whose name doesn't
   suggest that it examines shutdown).



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java:
##########
@@ -152,114 +149,157 @@ private static long debugDuration(long nowMs, long 
startMs) {
    */
   protected abstract void startThrottleTimer();
 
-  /** Reflects that {@link #shutdown()} was explicitly called. */
-  protected boolean isShutdown() {
-    return isShutdown.get();
-  }
-
-  private StreamObserver<RequestT> requestObserver() {
-    if (requestObserver == null) {
-      throw new NullPointerException(
-          "requestObserver cannot be null. Missing a call to startStream() to 
initialize.");
+  /** Try to send a request to the server. Returns true if the request was 
successfully sent. */
+  @CanIgnoreReturnValue
+  protected final synchronized boolean trySend(RequestT request)
+      throws WindmillStreamShutdownException {
+    debugMetrics.recordSend();
+    try {
+      requestObserver.onNext(request);
+      return true;
+    } catch (StreamClosedException e) {
+      // Stream was broken, requests may be retried when stream is reopened.
     }
 
-    return requestObserver;
+    return false;
   }
 
-  /** Send a request to the server. */
-  protected final void send(RequestT request) {
-    lastSendTimeMs.set(Instant.now().getMillis());
+  @Override
+  public final void start() {
+    boolean shouldStartStream = false;
     synchronized (this) {
-      if (streamClosed.get()) {
-        throw new IllegalStateException("Send called on a client closed 
stream.");
+      if (!isShutdown && !started) {
+        started = true;
+        shouldStartStream = true;
       }
+    }
 
-      requestObserver().onNext(request);
+    if (shouldStartStream) {
+      startStream();
     }
   }
 
   /** Starts the underlying stream. */
-  protected final void startStream() {
+  private void startStream() {
     // Add the stream to the registry after it has been fully constructed.
     streamRegistry.add(this);
     while (true) {
       try {
         synchronized (this) {
-          startTimeMs.set(Instant.now().getMillis());
-          lastResponseTimeMs.set(0);
-          streamClosed.set(false);
-          // lazily initialize the requestObserver. Gets reset whenever the 
stream is reopened.
-          requestObserver = requestObserverSupplier.get();
+          debugMetrics.recordStart();
+          requestObserver.reset();
           onNewStream();
-          if (clientClosed.get()) {
+          if (clientClosed) {
             halfClose();
           }
           return;
         }
+      } catch (WindmillStreamShutdownException e) {
+        // shutdown() is responsible for cleaning up pending requests.
+        logger.debug("Stream was shutdown while creating new stream.", e);
       } catch (Exception e) {
-        LOG.error("Failed to create new stream, retrying: ", e);
+        logger.error("Failed to create new stream, retrying: ", e);
         try {
           long sleep = backoff.nextBackOffMillis();
-          sleepUntil.set(Instant.now().getMillis() + sleep);
-          Thread.sleep(sleep);
-        } catch (InterruptedException | IOException i) {
+          debugMetrics.recordSleep(sleep);
+          sleeper.sleep(sleep);
+        } catch (InterruptedException ie) {
+          Thread.currentThread().interrupt();
+          logger.info(
+              "Interrupted during {} creation backoff. The stream will not be 
created.",
+              getClass());
+          // Shutdown the stream to clean up any dangling resources and 
pending requests.
+          shutdown();
+          break;
+        } catch (IOException ioe) {
           // Keep trying to create the stream.
         }
       }
     }
+
+    // We were never able to start the stream, remove it from the stream 
registry. Otherwise, it is
+    // removed when closed.
+    streamRegistry.remove(this);
   }
 
-  protected final Executor executor() {
-    return executor;
+  /**
+   * Execute the runnable using the {@link #executor} handling the executor 
being in a shutdown
+   * state.
+   */
+  protected final void executeSafely(Runnable runnable) {
+    try {
+      executor.execute(runnable);
+    } catch (RejectedExecutionException e) {
+      logger.debug("{}-{} has been shutdown.", getClass(), backendWorkerToken);
+    }
   }
 
   public final synchronized void maybeSendHealthCheck(Instant 
lastSendThreshold) {
-    if (lastSendTimeMs.get() < lastSendThreshold.getMillis() && 
!clientClosed.get()) {
+    if (!clientClosed && debugMetrics.getLastSendTimeMs() < 
lastSendThreshold.getMillis()) {
       try {
         sendHealthCheck();
-      } catch (RuntimeException e) {
-        LOG.debug("Received exception sending health check.", e);
+      } catch (Exception e) {
+        logger.debug("Received exception sending health check.", e);
       }
     }
   }
 
-  protected abstract void sendHealthCheck();
+  protected abstract void sendHealthCheck() throws 
WindmillStreamShutdownException;
 
-  // Care is taken that synchronization on this is unnecessary for all status 
page information.
-  // Blocking sends are made beneath this stream object's lock which could 
block status page
-  // rendering.
+  /**
+   * @implNote Care is taken that synchronization on this is unnecessary for 
all status page
+   *     information. Blocking sends are made beneath this stream object's 
lock which could block
+   *     status page rendering.
+   */
   public final void appendSummaryHtml(PrintWriter writer) {
     appendSpecificHtml(writer);
-    if (errorCount.get() > 0) {
-      writer.format(
-          ", %d errors, last error [ %s ] at [%s]",
-          errorCount.get(), lastError.get(), lastErrorTime.get());
-    }
-    if (clientClosed.get()) {
+    StreamDebugMetrics.Snapshot summaryMetrics = 
debugMetrics.getSummaryMetrics();
+    summaryMetrics
+        .restartMetrics()
+        .ifPresent(
+            metrics ->
+                writer.format(
+                    ", %d restarts, last restart reason [ %s ] at [%s], %d 
errors",
+                    metrics.restartCount(),
+                    metrics.lastRestartReason(),
+                    metrics.lastRestartTime(),
+                    metrics.errorCount()));
+
+    if (summaryMetrics.isClientClosed()) {
       writer.write(", client closed");
     }
-    long nowMs = Instant.now().getMillis();
-    long sleepLeft = sleepUntil.get() - nowMs;
-    if (sleepLeft > 0) {
-      writer.format(", %dms backoff remaining", sleepLeft);
+
+    if (summaryMetrics.sleepLeft() > 0) {
+      writer.format(", %dms backoff remaining", summaryMetrics.sleepLeft());
     }
+
     writer.format(
-        ", current stream is %dms old, last send %dms, last response %dms, 
closed: %s",
-        debugDuration(nowMs, startTimeMs.get()),
-        debugDuration(nowMs, lastSendTimeMs.get()),
-        debugDuration(nowMs, lastResponseTimeMs.get()),
-        streamClosed.get());
+        ", current stream is %dms old, last send %dms, last response %dms, 
closed: %s, "
+            + "isShutdown: %s, shutdown time: %s",
+        summaryMetrics.streamAge(),
+        summaryMetrics.timeSinceLastSend(),
+        summaryMetrics.timeSinceLastResponse(),
+        requestObserver.isClosed(),
+        summaryMetrics.shutdownTime().isPresent(),
+        summaryMetrics.shutdownTime().orElse(null));
   }
 
-  // Don't require synchronization on stream, see the appendSummaryHtml 
comment.
+  /**
+   * @implNote Don't require synchronization on stream, see the {@link
+   *     #appendSummaryHtml(PrintWriter)} comment.
+   */
   protected abstract void appendSpecificHtml(PrintWriter writer);
 
   @Override
   public final synchronized void halfClose() {
     // Synchronization of close and onCompleted necessary for correct retry 
logic in onNewStream.
-    clientClosed.set(true);
-    requestObserver().onCompleted();
-    streamClosed.set(true);
+    debugMetrics.recordHalfClose();
+    clientClosed = true;
+    try {
+      requestObserver.onCompleted();
+    } catch (StreamClosedException | WindmillStreamShutdownException e) {
+      logger.warn("Stream was previously closed or shutdown.");

Review Comment:
   Separate the catches so each logs its own message instead of having an "or"?
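
   Roughly like the sketch below. The exception types are hypothetical
   unchecked stand-ins for StreamClosedException and
   WindmillStreamShutdownException, and the messages are collected in a list
   instead of going to the logger so the example has no dependencies:

   ```java
   import java.util.ArrayList;
   import java.util.List;

   // Hypothetical unchecked stand-ins for StreamClosedException and
   // WindmillStreamShutdownException.
   final class ClosedSketchException extends RuntimeException {}

   final class ShutdownSignalSketchException extends RuntimeException {}

   final class HalfCloseSketch {
     // Stands in for logger.warn(...) calls.
     final List<String> warnings = new ArrayList<>();

     void halfClose(Runnable onCompleted) {
       try {
         onCompleted.run();
       } catch (ClosedSketchException e) {
         warnings.add("Stream was previously closed.");
       } catch (ShutdownSignalSketchException e) {
         warnings.add("Stream was previously shutdown.");
       }
     }
   }
   ```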
   



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java:
##########
@@ -156,29 +163,44 @@ public void sendHealthCheck() {
   protected void onResponse(StreamingCommitResponse response) {
     commitWorkThrottleTimer.stop();
 
-    RuntimeException finalException = null;
+    CommitCompletionException failures = new CommitCompletionException();

Review Comment:
   I was thinking the builder would not be an exception itself, and then the 
exception would just be a simple class without mutating methods.
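
   Something along these lines, as a sketch only. The class names, the message
   format and the Optional return type are my assumptions, not a required
   shape:

   ```java
   import java.util.ArrayList;
   import java.util.Collections;
   import java.util.List;
   import java.util.Optional;

   // Hypothetical immutable exception: all state is fixed at construction.
   final class CommitFailuresSketchException extends RuntimeException {
     private final List<Throwable> failures;

     CommitFailuresSketchException(List<Throwable> failures) {
       super(failures.size() + " commit callback(s) failed.");
       // Defensive copy so later builder mutations cannot leak in.
       this.failures = Collections.unmodifiableList(new ArrayList<>(failures));
     }

     List<Throwable> failures() {
       return failures;
     }
   }

   // The only mutable piece; it is not itself an exception.
   final class CommitFailuresSketchBuilder {
     private final List<Throwable> failures = new ArrayList<>();

     void add(Throwable failure) {
       failures.add(failure);
     }

     // Empty when nothing failed, so callers only throw when there is a reason to.
     Optional<CommitFailuresSketchException> build() {
       if (failures.isEmpty()) {
         return Optional.empty();
       }
       return Optional.of(new CommitFailuresSketchException(failures));
     }
   }
   ```

   onResponse would then add to the builder inside the loop and throw the
   built exception, if present, once at the end.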



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java:
##########
@@ -67,76 +72,128 @@ public DirectStreamObserver(
   }
 
   @Override
-  public void onNext(T value) {
+  public void onNext(T value) throws StreamObserverCancelledException {
     int awaitPhase = -1;
     long totalSecondsWaited = 0;
     long waitSeconds = 1;
     while (true) {
       try {
         synchronized (lock) {
+          int currentPhase = isReadyNotifier.getPhase();
+          // Phaser is terminated so don't use the outboundObserver. Since 
onError and onCompleted
+          // are synchronized after terminating the phaser if we observe that 
the phaser is not
+          // terminated the onNext calls below are guaranteed to not be called 
on a closed observer.
+          if (currentPhase < 0) return;
+
+          // If we awaited previously and timed out, wait for the same phase. 
Otherwise we're
+          // careful to observe the phase before observing isReady.
+          if (awaitPhase < 0) {
+            awaitPhase = isReadyNotifier.getPhase();
+            // If getPhase() returns a value less than 0, the phaser has been 
terminated.
+            if (awaitPhase < 0) {
+              return;
+            }
+          }
+
           // We only check isReady periodically to effectively allow for 
increasing the outbound
           // buffer periodically. This reduces the overhead of blocking while 
still restricting
           // memory because there is a limited # of streams, and we have a max 
messages size of 2MB.
           if (++messagesSinceReady <= messagesBetweenIsReadyChecks) {
             outboundObserver.onNext(value);
             return;
           }
-          // If we awaited previously and timed out, wait for the same phase. 
Otherwise we're
-          // careful to observe the phase before observing isReady.
-          if (awaitPhase < 0) {
-            awaitPhase = phaser.getPhase();
-          }
+
           if (outboundObserver.isReady()) {
             messagesSinceReady = 0;
             outboundObserver.onNext(value);
             return;
           }
         }
+
         // A callback has been registered to advance the phaser whenever the 
observer
         // transitions to  is ready. Since we are waiting for a phase observed 
before the
         // outboundObserver.isReady() returned false, we expect it to advance 
after the
         // channel has become ready.  This doesn't always seem to be the case 
(despite
         // documentation stating otherwise) so we poll periodically and 
enforce an overall
         // timeout related to the stream deadline.
-        phaser.awaitAdvanceInterruptibly(awaitPhase, waitSeconds, 
TimeUnit.SECONDS);
+        int nextPhase =
+            isReadyNotifier.awaitAdvanceInterruptibly(awaitPhase, waitSeconds, 
TimeUnit.SECONDS);
+        // If nextPhase is a value less than 0, the phaser has been terminated.
+        if (nextPhase < 0) {
+          throw new StreamObserverCancelledException("StreamObserver was 
terminated.");
+        }
+
         synchronized (lock) {
+          int currentPhase = isReadyNotifier.getPhase();
+          // Phaser is terminated so don't use the outboundObserver. Since 
onError and onCompleted
+          // are synchronized after terminating the phaser if we observe that 
the phaser is not
+          // terminated the onNext calls below are guaranteed to not be called 
on a closed observer.
+          if (currentPhase < 0) return;
           messagesSinceReady = 0;
           outboundObserver.onNext(value);
           return;
         }
       } catch (TimeoutException e) {
         totalSecondsWaited += waitSeconds;
         if (totalSecondsWaited > deadlineSeconds) {
-          LOG.error(
-              "Exceeded timeout waiting for the outboundObserver to become 
ready meaning "
-                  + "that the stream deadline was not respected.");
-          throw new RuntimeException(e);
+          String errorMessage = 
constructStreamCancelledErrorMessage(totalSecondsWaited);
+          LOG.error(errorMessage);
+          throw new StreamObserverCancelledException(errorMessage, e);
         }
-        if (totalSecondsWaited > 30) {
+
+        if (totalSecondsWaited > OUTPUT_CHANNEL_CONSIDERED_STALLED_SECONDS) {
           LOG.info(
               "Output channel stalled for {}s, outbound thread {}.",
               totalSecondsWaited,
               Thread.currentThread().getName());
         }
+
         waitSeconds = waitSeconds * 2;
       } catch (InterruptedException e) {
         Thread.currentThread().interrupt();
-        throw new RuntimeException(e);
+        throw new StreamObserverCancelledException(e);
       }
     }
   }
 
   @Override
   public void onError(Throwable t) {
+    isReadyNotifier.forceTermination();
     synchronized (lock) {
+      isClosed = true;

Review Comment:
   I think that in onError/onCompleted we shouldn't call the outboundObserver 
method if we've already closed.
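   
   A minimal sketch of that guard, assuming the `isClosed` flag from this diff is read and written under the same lock. `Sink` and `GuardedObserver` are hypothetical stand-ins for the gRPC outbound observer and for DirectStreamObserver, so the snippet compiles on its own:

```java
import java.util.concurrent.Phaser;

/** Hypothetical stand-in for the outbound gRPC observer. */
interface Sink {
  void onError(Throwable t);

  void onCompleted();
}

/** Sketch only: forwards onError/onCompleted to the outbound observer at most once. */
final class GuardedObserver {
  private final Object lock = new Object();
  private final Phaser isReadyNotifier = new Phaser(1);
  private final Sink outboundObserver;
  private boolean isClosed = false; // guarded by lock

  GuardedObserver(Sink outboundObserver) {
    this.outboundObserver = outboundObserver;
  }

  public void onError(Throwable t) {
    // Terminating a phaser that is already terminated has no further effect.
    isReadyNotifier.forceTermination();
    synchronized (lock) {
      if (isClosed) {
        return; // Already closed: do not call the outbound observer again.
      }
      isClosed = true;
      outboundObserver.onError(t);
    }
  }

  public void onCompleted() {
    isReadyNotifier.forceTermination();
    synchronized (lock) {
      if (isClosed) {
        return;
      }
      isClosed = true;
      outboundObserver.onCompleted();
    }
  }
}
```

   With this shape, a second onError or a late onCompleted becomes a no-op instead of reaching an observer that was already closed.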



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/TerminatingStreamObserver.java:
##########
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers;
+
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Internal;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
+
+@Internal
+public interface TerminatingStreamObserver<T> extends StreamObserver<T> {
+
+  /** Terminates the StreamObserver. */

Review Comment:
   Add more comments on what this means and how it differs from onError.
   
   It seems that, compared to onError, which we expect to be called once, terminate may be called multiple times and interleaved with other stream operations.
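   
   One possible wording, sketched under the assumption that the semantics guessed above are right. The `terminate(Throwable)` signature, the `Observer` stand-in for gRPC's StreamObserver, and the toy implementation are all hypothetical and only there to make the snippet self-contained:

```java
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

/** Hypothetical stand-in for gRPC's StreamObserver. */
interface Observer<T> {
  void onNext(T value);

  void onError(Throwable t);

  void onCompleted();
}

interface TerminatingObserver<T> extends Observer<T> {
  /**
   * Terminates the observer so that blocked or future onNext calls fail fast.
   *
   * <p>Unlike onError, which is expected to be called once, terminate may be called multiple
   * times and may be interleaved with other stream operations, including an onNext that is in
   * progress on another thread. Calls after the first are no-ops.
   */
  void terminate(Throwable cause); // Signature is an assumption, not taken from the PR.
}

/** Toy implementation showing the idempotent behavior the Javadoc describes. */
final class ToyTerminatingObserver<T> implements TerminatingObserver<T> {
  private final AtomicReference<Throwable> termination = new AtomicReference<>();
  private final AtomicInteger delivered = new AtomicInteger();

  @Override
  public void onNext(T value) {
    Throwable cause = termination.get();
    if (cause != null) {
      throw new IllegalStateException("Observer was terminated.", cause);
    }
    delivered.incrementAndGet();
  }

  @Override
  public void onError(Throwable t) {}

  @Override
  public void onCompleted() {}

  @Override
  public void terminate(Throwable cause) {
    // Only the first termination is recorded; later calls change nothing.
    termination.compareAndSet(null, cause);
  }

  Throwable terminationCause() {
    return termination.get();
  }

  int delivered() {
    return delivered.get();
  }
}
```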



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java:
##########
@@ -67,76 +72,128 @@ public DirectStreamObserver(
   }
 
   @Override
-  public void onNext(T value) {
+  public void onNext(T value) throws StreamObserverCancelledException {
     int awaitPhase = -1;
     long totalSecondsWaited = 0;
     long waitSeconds = 1;
     while (true) {
       try {
         synchronized (lock) {
+          int currentPhase = isReadyNotifier.getPhase();
+          // Phaser is terminated so don't use the outboundObserver. Since 
onError and onCompleted
+          // are synchronized after terminating the phaser if we observe that 
the phaser is not
+          // terminated the onNext calls below are guaranteed to not be called 
on a closed observer.
+          if (currentPhase < 0) return;

Review Comment:
   In other cases we throw StreamObserverCancelledException once the phaser is terminated. Should we do the same for this return, to be consistent?
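   
   Roughly what I mean, as a standalone sketch. `CancelledException` is a hypothetical unchecked stand-in for StreamObserverCancelledException and `phaseOrThrow` is an illustrative helper, not something in the PR. The only library fact relied on is that `Phaser.getPhase()` returns a negative value after termination:

```java
import java.util.concurrent.Phaser;

/** Hypothetical unchecked stand-in for StreamObserverCancelledException. */
final class CancelledException extends RuntimeException {
  CancelledException(String message) {
    super(message);
  }
}

final class PhaseCheck {
  private PhaseCheck() {}

  /** Returns the current phase, or throws if the phaser has been terminated. */
  static int phaseOrThrow(Phaser isReadyNotifier) {
    int currentPhase = isReadyNotifier.getPhase();
    // getPhase() is negative once the phaser is terminated, so surface that to the
    // caller the same way the awaitAdvanceInterruptibly path does instead of returning.
    if (currentPhase < 0) {
      throw new CancelledException("StreamObserver was terminated.");
    }
    return currentPhase;
  }
}
```

   Using one helper for every phase check would also remove the repeated comment blocks in onNext.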



##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java:
##########
@@ -0,0 +1,237 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+import java.io.IOException;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import javax.annotation.Nullable;
+import 
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
+import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server;
+import 
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder;
+import 
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder;
+import 
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.ServerCallStreamObserver;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.Timeout;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class GrpcCommitWorkStreamTest {
+  private static final String FAKE_SERVER_NAME = "Fake server for 
GrpcCommitWorkStreamTest";
+  private static final Windmill.JobHeader TEST_JOB_HEADER =
+      Windmill.JobHeader.newBuilder()
+          .setJobId("test_job")
+          .setWorkerId("test_worker")
+          .setProjectId("test_project")
+          .build();
+  private static final String COMPUTATION_ID = "computationId";
+
+  @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
+  private final MutableHandlerRegistry serviceRegistry = new 
MutableHandlerRegistry();
+  @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
+  private ManagedChannel inProcessChannel;
+
+  private static Windmill.WorkItemCommitRequest workItemCommitRequest(long 
value) {
+    return Windmill.WorkItemCommitRequest.newBuilder()
+        .setKey(ByteString.EMPTY)
+        .setShardingKey(value)
+        .setWorkToken(value)
+        .setCacheToken(value)
+        .build();
+  }
+
+  @Before
+  public void setUp() throws IOException {
+    Server server =
+        InProcessServerBuilder.forName(FAKE_SERVER_NAME)
+            .fallbackHandlerRegistry(serviceRegistry)
+            .directExecutor()
+            .build()
+            .start();
+
+    inProcessChannel =
+        grpcCleanup.register(
+            
InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build());
+    grpcCleanup.register(server);
+    grpcCleanup.register(inProcessChannel);
+  }
+
+  @After
+  public void cleanUp() {
+    inProcessChannel.shutdownNow();
+  }
+
+  private GrpcCommitWorkStream createCommitWorkStream(CommitWorkStreamTestStub 
testStub) {
+    serviceRegistry.addService(testStub);
+    GrpcCommitWorkStream commitWorkStream =
+        (GrpcCommitWorkStream)
+            GrpcWindmillStreamFactory.of(TEST_JOB_HEADER)
+                .build()
+                .createCommitWorkStream(
+                    CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel),
+                    new ThrottleTimer());
+    commitWorkStream.start();
+    return commitWorkStream;
+  }
+
+  @Test
+  public void testShutdown_abortsQueuedCommits() throws InterruptedException {
+    int numCommits = 5;
+    CountDownLatch commitProcessed = new CountDownLatch(numCommits);
+    Set<Windmill.CommitStatus> onDone = new HashSet<>();
+
+    TestCommitWorkStreamRequestObserver requestObserver =
+        spy(new TestCommitWorkStreamRequestObserver());
+    CommitWorkStreamTestStub testStub = new 
CommitWorkStreamTestStub(requestObserver);
+    GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(testStub);
+    try (WindmillStream.CommitWorkStream.RequestBatcher batcher = 
commitWorkStream.batcher()) {
+      for (int i = 0; i < numCommits; i++) {
+        batcher.commitWorkItem(
+            COMPUTATION_ID,
+            workItemCommitRequest(i),
+            commitStatus -> {
+              onDone.add(commitStatus);
+              commitProcessed.countDown();
+            });
+      }
+    }
+
+    // Verify that we sent the commits above in a request + the initial header.
+    verify(requestObserver, 
times(2)).onNext(any(Windmill.StreamingCommitWorkRequest.class));
+    // We won't get responses so we will have some pending requests.
+    assertTrue(commitWorkStream.hasPendingRequests());
+
+    commitWorkStream.shutdown();
+    commitProcessed.await();
+
+    assertThat(onDone).containsExactly(Windmill.CommitStatus.ABORTED);
+  }
+
+  @Test
+  public void testCommitWorkItem_afterShutdownFalse() {
+    int numCommits = 5;
+
+    CommitWorkStreamTestStub testStub =
+        new CommitWorkStreamTestStub(new 
TestCommitWorkStreamRequestObserver());
+    GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(testStub);
+
+    try (WindmillStream.CommitWorkStream.RequestBatcher batcher = 
commitWorkStream.batcher()) {
+      for (int i = 0; i < numCommits; i++) {
+        assertTrue(batcher.commitWorkItem(COMPUTATION_ID, 
workItemCommitRequest(i), ignored -> {}));
+      }
+    }
+    commitWorkStream.shutdown();
+
+    Set<Windmill.CommitStatus> commitStatuses = new HashSet<>();

Review Comment:
   Not done. On second reading, though, I think you could get rid of the set.
   Instead you could use an AtomicReference<Windmill.CommitStatus> inside the inner loop, since you expect it to be set inline by commitWorkItem.
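   
   A sketch of the loop shape, assuming the callback really is invoked inline once the stream is shut down. `CommitStatus` and `ShutdownBatcher` are hypothetical stand-ins for Windmill.CommitStatus and the real RequestBatcher, and the plain checks stand in for the JUnit/Truth assertions:

```java
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;

/** Hypothetical stand-in for Windmill.CommitStatus. */
enum CommitStatus {
  OK,
  ABORTED
}

/** Hypothetical stand-in for a batcher whose stream has already been shut down. */
final class ShutdownBatcher {
  boolean commitWorkItem(long request, Consumer<CommitStatus> onDone) {
    // Assumption: after shutdown the callback runs inline with ABORTED.
    onDone.accept(CommitStatus.ABORTED);
    return true;
  }
}

final class AfterShutdownCheck {
  private AfterShutdownCheck() {}

  /** Checks each commit individually instead of collecting statuses in a shared set. */
  static void run(ShutdownBatcher batcher, int numCommits) {
    for (int i = 0; i < numCommits; i++) {
      // A fresh reference per iteration proves that this particular call set the status.
      AtomicReference<CommitStatus> commitStatus = new AtomicReference<>();
      if (!batcher.commitWorkItem(i, commitStatus::set)) {
        throw new AssertionError("commitWorkItem returned false for commit " + i);
      }
      if (commitStatus.get() != CommitStatus.ABORTED) {
        throw new AssertionError("commit " + i + " was " + commitStatus.get());
      }
    }
  }
}
```

   Compared to the set, this fails on the first commit whose callback was not invoked inline, rather than only checking that ABORTED was seen at least once.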



##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java:
##########
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.assertThrows;
+
+import java.io.PrintWriter;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Function;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory;
+import org.apache.beam.sdk.util.FluentBackoff;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.CallStreamObserver;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.slf4j.LoggerFactory;
+
+@RunWith(JUnit4.class)
+public class AbstractWindmillStreamTest {
+  private static final long DEADLINE_SECONDS = 10;
+  private final Set<AbstractWindmillStream<?, ?>> streamRegistry = 
ConcurrentHashMap.newKeySet();
+  private final StreamObserverFactory streamObserverFactory =
+      StreamObserverFactory.direct(DEADLINE_SECONDS, 1);
+
+  @Before
+  public void setUp() {
+    streamRegistry.clear();
+  }
+
+  private TestStream newStream(
+      Function<StreamObserver<Integer>, StreamObserver<Integer>> 
clientFactory) {
+    return new TestStream(clientFactory, streamRegistry, 
streamObserverFactory);
+  }
+
+  @Test
+  public void testShutdown_notBlockedBySend() throws InterruptedException, 
ExecutionException {
+    CountDownLatch sendBlocker = new CountDownLatch(1);
+    Function<StreamObserver<Integer>, StreamObserver<Integer>> clientFactory =
+        ignored ->
+            new CallStreamObserver<Integer>() {
+              @Override
+              public void onNext(Integer integer) {
+                try {
+                  sendBlocker.await();
+                } catch (InterruptedException e) {
+                  throw new RuntimeException(e);
+                }
+              }
+
+              @Override
+              public void onError(Throwable throwable) {}
+
+              @Override
+              public void onCompleted() {}
+
+              @Override
+              public boolean isReady() {
+                return false;
+              }
+
+              @Override
+              public void setOnReadyHandler(Runnable runnable) {}
+
+              @Override
+              public void disableAutoInboundFlowControl() {}
+
+              @Override
+              public void request(int i) {}
+
+              @Override
+              public void setMessageCompression(boolean b) {}
+            };
+
+    TestStream testStream = newStream(clientFactory);
+    testStream.start();
+    ExecutorService sendExecutor = Executors.newSingleThreadExecutor();
+    Future<WindmillStreamShutdownException> sendFuture =
+        sendExecutor.submit(
+            () ->
+                assertThrows(WindmillStreamShutdownException.class, () -> 
testStream.testSend(1)));
+    testStream.shutdown();

Review Comment:
   Add a sleep here, with a comment, to give the executor above time to start running and block.
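   
   Something along these lines, shown as a standalone sketch. The latch is a hypothetical stand-in for the blocked send and for testStream.shutdown(), and the 100 ms value is an arbitrary assumption:

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

final class SleepBeforeShutdown {
  private SleepBeforeShutdown() {}

  /** Returns true if the blocked task was released by the simulated shutdown. */
  static boolean run() {
    CountDownLatch shutdownSignal = new CountDownLatch(1);
    ExecutorService sendExecutor = Executors.newSingleThreadExecutor();
    try {
      // Stand-in for the send that blocks until the stream is shut down.
      Future<Boolean> sendFuture =
          sendExecutor.submit(() -> shutdownSignal.await(10, TimeUnit.SECONDS));
      // Give the executor thread time to start running and block in the send.
      // Without this, shutdown can win the race and the blocked path is never exercised.
      Thread.sleep(100);
      shutdownSignal.countDown(); // Stand-in for testStream.shutdown().
      return sendFuture.get(10, TimeUnit.SECONDS);
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      throw new IllegalStateException(e);
    } catch (ExecutionException | TimeoutException e) {
      throw new IllegalStateException(e);
    } finally {
      sendExecutor.shutdownNow();
    }
  }
}
```

   The sleep only makes the intended interleaving likely, not certain, which is why the comment explaining it matters.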



##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java:
##########
@@ -0,0 +1,246 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.inOrder;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+import java.io.IOException;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import javax.annotation.Nullable;
+import 
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream;
+import 
org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
+import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server;
+import 
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder;
+import 
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder;
+import 
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.ServerCallStreamObserver;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.Timeout;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.InOrder;
+
+@RunWith(JUnit4.class)
+public class GrpcCommitWorkStreamTest {
+  private static final String FAKE_SERVER_NAME = "Fake server for 
GrpcCommitWorkStreamTest";
+  private static final Windmill.JobHeader TEST_JOB_HEADER =
+      Windmill.JobHeader.newBuilder()
+          .setJobId("test_job")
+          .setWorkerId("test_worker")
+          .setProjectId("test_project")
+          .build();
+  private static final String COMPUTATION_ID = "computationId";
+
+  @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
+  private final MutableHandlerRegistry serviceRegistry = new 
MutableHandlerRegistry();
+  @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
+  private ManagedChannel inProcessChannel;
+
+  private static Windmill.WorkItemCommitRequest workItemCommitRequest(long 
value) {
+    return Windmill.WorkItemCommitRequest.newBuilder()
+        .setKey(ByteString.EMPTY)
+        .setShardingKey(value)
+        .setWorkToken(value)
+        .setCacheToken(value)
+        .build();
+  }
+
+  @Before
+  public void setUp() throws IOException {
+    Server server =
+        InProcessServerBuilder.forName(FAKE_SERVER_NAME)
+            .fallbackHandlerRegistry(serviceRegistry)
+            .directExecutor()
+            .build()
+            .start();
+
+    inProcessChannel =
+        grpcCleanup.register(
+            
InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build());
+    grpcCleanup.register(server);
+    grpcCleanup.register(inProcessChannel);
+  }
+
+  @After
+  public void cleanUp() {
+    inProcessChannel.shutdownNow();
+  }
+
+  private GrpcCommitWorkStream createCommitWorkStream(CommitWorkStreamTestStub 
testStub) {
+    serviceRegistry.addService(testStub);
+    GrpcCommitWorkStream commitWorkStream =
+        (GrpcCommitWorkStream)
+            GrpcWindmillStreamFactory.of(TEST_JOB_HEADER)
+                .build()
+                .createCommitWorkStream(
+                    CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel),
+                    new ThrottleTimer());
+    commitWorkStream.start();
+    return commitWorkStream;
+  }
+
+  @Test
+  public void testShutdown_abortsQueuedCommits() throws InterruptedException {
+    int numCommits = 5;
+    CountDownLatch commitProcessed = new CountDownLatch(numCommits);
+    Set<Windmill.CommitStatus> onDone = new HashSet<>();
+
+    TestCommitWorkStreamRequestObserver requestObserver =
+        spy(new TestCommitWorkStreamRequestObserver());
+    CommitWorkStreamTestStub testStub = new 
CommitWorkStreamTestStub(requestObserver);
+    GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(testStub);
+    try (WindmillStream.CommitWorkStream.RequestBatcher batcher = 
commitWorkStream.batcher()) {
+      for (int i = 0; i < numCommits; i++) {
+        batcher.commitWorkItem(
+            COMPUTATION_ID,
+            workItemCommitRequest(i),
+            commitStatus -> {
+              onDone.add(commitStatus);
+              commitProcessed.countDown();
+            });
+      }
+    }
+
+    // Verify that we sent the commits above in a request + the initial header.
+    verify(requestObserver, 
times(2)).onNext(any(Windmill.StreamingCommitWorkRequest.class));
+    // We won't get responses so we will have some pending requests.
+    assertTrue(commitWorkStream.hasPendingRequests());
+
+    commitWorkStream.shutdown();
+    commitProcessed.await();
+
+    assertThat(onDone).containsExactly(Windmill.CommitStatus.ABORTED);
+  }
+
+  @Test
+  public void testCommitWorkItem_afterShutdown() {
+    int numCommits = 5;
+
+    CommitWorkStreamTestStub testStub =
+        new CommitWorkStreamTestStub(new 
TestCommitWorkStreamRequestObserver());
+    GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(testStub);
+
+    try (WindmillStream.CommitWorkStream.RequestBatcher batcher = 
commitWorkStream.batcher()) {
+      for (int i = 0; i < numCommits; i++) {
+        assertTrue(batcher.commitWorkItem(COMPUTATION_ID, 
workItemCommitRequest(i), ignored -> {}));
+      }
+    }
+    commitWorkStream.shutdown();
+
+    Set<Windmill.CommitStatus> commitStatuses = new HashSet<>();
+    try (WindmillStream.CommitWorkStream.RequestBatcher batcher = 
commitWorkStream.batcher()) {
+      for (int i = 0; i < numCommits; i++) {
+        assertTrue(
+            batcher.commitWorkItem(COMPUTATION_ID, workItemCommitRequest(i), 
commitStatuses::add));
+      }
+    }
+
+    assertThat(commitStatuses).containsExactly(Windmill.CommitStatus.ABORTED);
+  }
+
+  @Test
+  public void testSend_notCalledAfterShutdown() {
+    int numCommits = 5;
+    CountDownLatch commitProcessed = new CountDownLatch(numCommits);
+
+    TestCommitWorkStreamRequestObserver requestObserver =
+        spy(new TestCommitWorkStreamRequestObserver());
+    InOrder requestObserverVerifier = inOrder(requestObserver);
+
+    CommitWorkStreamTestStub testStub = new 
CommitWorkStreamTestStub(requestObserver);
+    GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(testStub);
+    try (WindmillStream.CommitWorkStream.RequestBatcher batcher = 
commitWorkStream.batcher()) {
+      for (int i = 0; i < numCommits; i++) {
+        assertTrue(
+            batcher.commitWorkItem(
+                COMPUTATION_ID,
+                workItemCommitRequest(i),
+                commitStatus -> commitProcessed.countDown()));
+      }
+      // Shutdown the stream before we exit the try-with-resources block which 
will try to send()
+      // the batched request.
+      commitWorkStream.shutdown();
+    }
+
+    // send() uses the requestObserver to send requests. We expect 1 send 
since startStream() sends
+    // the header, which happens before we shutdown.
+    requestObserverVerifier
+        .verify(requestObserver)
+        .onNext(any(Windmill.StreamingCommitWorkRequest.class));

Review Comment:
   Verify that it is the header?



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java:
##########
@@ -67,76 +72,128 @@ public DirectStreamObserver(
   }
 
   @Override
-  public void onNext(T value) {
+  public void onNext(T value) throws StreamObserverCancelledException {
     int awaitPhase = -1;
     long totalSecondsWaited = 0;
     long waitSeconds = 1;
     while (true) {
       try {
         synchronized (lock) {
+          int currentPhase = isReadyNotifier.getPhase();
+          // Phaser is terminated so don't use the outboundObserver. Since 
onError and onCompleted
+          // are synchronized after terminating the phaser if we observe that 
the phaser is not
+          // terminated the onNext calls below are guaranteed to not be called 
on a closed observer.
+          if (currentPhase < 0) return;
+
+          // If we awaited previously and timed out, wait for the same phase. 
Otherwise we're
+          // careful to observe the phase before observing isReady.
+          if (awaitPhase < 0) {
+            awaitPhase = isReadyNotifier.getPhase();
+            // If getPhase() returns a value less than 0, the phaser has been 
terminated.
+            if (awaitPhase < 0) {
+              return;
+            }
+          }
+
           // We only check isReady periodically to effectively allow for 
increasing the outbound
           // buffer periodically. This reduces the overhead of blocking while 
still restricting
           // memory because there is a limited # of streams, and we have a max 
messages size of 2MB.
           if (++messagesSinceReady <= messagesBetweenIsReadyChecks) {
             outboundObserver.onNext(value);
             return;
           }
-          // If we awaited previously and timed out, wait for the same phase. 
Otherwise we're
-          // careful to observe the phase before observing isReady.
-          if (awaitPhase < 0) {
-            awaitPhase = phaser.getPhase();
-          }
+
           if (outboundObserver.isReady()) {
             messagesSinceReady = 0;
             outboundObserver.onNext(value);
             return;
           }
         }
+
         // A callback has been registered to advance the phaser whenever the 
observer
         // transitions to  is ready. Since we are waiting for a phase observed 
before the
         // outboundObserver.isReady() returned false, we expect it to advance 
after the
         // channel has become ready.  This doesn't always seem to be the case 
(despite
         // documentation stating otherwise) so we poll periodically and 
enforce an overall
         // timeout related to the stream deadline.
-        phaser.awaitAdvanceInterruptibly(awaitPhase, waitSeconds, 
TimeUnit.SECONDS);
+        int nextPhase =
+            isReadyNotifier.awaitAdvanceInterruptibly(awaitPhase, waitSeconds, 
TimeUnit.SECONDS);
+        // If nextPhase is a value less than 0, the phaser has been terminated.
+        if (nextPhase < 0) {
+          throw new StreamObserverCancelledException("StreamObserver was 
terminated.");
+        }
+
         synchronized (lock) {
+          int currentPhase = isReadyNotifier.getPhase();
+          // Phaser is terminated so don't use the outboundObserver. Since 
onError and onCompleted
+          // are synchronized after terminating the phaser if we observe that 
the phaser is not
+          // terminated the onNext calls below are guaranteed to not be called 
on a closed observer.
+          if (currentPhase < 0) return;
           messagesSinceReady = 0;
           outboundObserver.onNext(value);
           return;
         }
       } catch (TimeoutException e) {
         totalSecondsWaited += waitSeconds;
         if (totalSecondsWaited > deadlineSeconds) {
-          LOG.error(
-              "Exceeded timeout waiting for the outboundObserver to become 
ready meaning "
-                  + "that the stream deadline was not respected.");
-          throw new RuntimeException(e);
+          String errorMessage = 
constructStreamCancelledErrorMessage(totalSecondsWaited);
+          LOG.error(errorMessage);
+          throw new StreamObserverCancelledException(errorMessage, e);

Review Comment:
   Maybe this should remain a different exception, since we catch StreamObserverCancelledException because we think we terminated?
   
   This indicates instead that something is going wrong with gRPC, and we might prefer to bubble up and crash.
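   
   A small sketch of the distinction, with hypothetical names throughout (`ObserverCancelled`, `ObserverStalled`, and `SendLoop` are not in the PR). The caller swallows only the cancellation it expects after shutdown and lets the stall propagate:

```java
/** Hypothetical: thrown when we terminated the observer ourselves. */
final class ObserverCancelled extends RuntimeException {
  ObserverCancelled(String message) {
    super(message);
  }
}

/** Hypothetical: thrown when the channel never became ready within the deadline. */
final class ObserverStalled extends RuntimeException {
  ObserverStalled(String message) {
    super(message);
  }
}

final class SendLoop {
  private SendLoop() {}

  /** Returns true if the send went through, false if the stream was shut down. */
  static boolean trySend(Runnable send) {
    try {
      send.run();
      return true;
    } catch (ObserverCancelled e) {
      // Expected once the stream has been shut down, so it is safe to ignore.
      return false;
    }
    // ObserverStalled is deliberately not caught: it suggests a transport problem,
    // so it propagates to the caller instead of being treated as a shutdown.
  }
}
```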



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java:
##########
@@ -234,8 +229,13 @@ public void appendSpecificHtml(PrintWriter writer) {
   }
 
   @Override
-  public void sendHealthCheck() {
-    send(HEALTH_CHECK_REQUEST);
+  public void sendHealthCheck() throws WindmillStreamShutdownException {
+    trySend(HEALTH_CHECK_REQUEST);
+  }
+
+  @Override
+  protected void shutdownInternal() {
+    workItemAssemblers.clear();

Review Comment:
   Is this needed? I think it could cause issues, since onResponse may run before the shutdown terminates the gRPC stream, which could then cause an error in assembling the work item.
   
   It seems easiest to not clean up until the stream itself is deleted.



##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java:
##########
@@ -460,8 +470,9 @@ private void flushResponse() {
                       "Sending batched response of {} ids", 
responseBuilder.getRequestIdCount());
                   try {
                     responseObserver.onNext(responseBuilder.build());
-                  } catch (IllegalStateException e) {
+                  } catch (Exception e) {
                     // Stream is already closed.
+                    System.out.println("trieu: " + e);

Review Comment:
   Remove this debug print.



##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserverTest.java:
##########
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.same;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicInteger;
+import org.apache.beam.sdk.fn.stream.AdvancingPhaser;
+import 
org.apache.beam.vendor.grpc.v1p60p1.com.google.common.base.VerifyException;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.CallStreamObserver;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class DirectStreamObserverTest {
+
+  @Test

Review Comment:
   Add some simple test coverage as well:
   - send a couple of times, onCompleted
   - send blocking for onReady, become ready, send, onCompleted
   - send a couple of times, onError



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java:
##########
@@ -121,32 +131,44 @@ public static GrpcGetDataStream create(
       int streamingRpcBatchLimit,
       boolean sendKeyedGetDataRequests,
       Consumer<List<Windmill.ComputationHeartbeatResponse>> 
processHeartbeatResponses) {
-    GrpcGetDataStream getDataStream =
-        new GrpcGetDataStream(
-            backendWorkerToken,
-            startGetDataRpcFn,
-            backoff,
-            streamObserverFactory,
-            streamRegistry,
-            logEveryNStreamFailures,
-            getDataThrottleTimer,
-            jobHeader,
-            idGenerator,
-            streamingRpcBatchLimit,
-            sendKeyedGetDataRequests,
-            processHeartbeatResponses);
-    getDataStream.startStream();
-    return getDataStream;
+    return new GrpcGetDataStream(
+        backendWorkerToken,
+        startGetDataRpcFn,
+        backoff,
+        streamObserverFactory,
+        streamRegistry,
+        logEveryNStreamFailures,
+        getDataThrottleTimer,
+        jobHeader,
+        idGenerator,
+        streamingRpcBatchLimit,
+        sendKeyedGetDataRequests,
+        processHeartbeatResponses);
+  }
+
+  private static WindmillStreamShutdownException shutdownExceptionFor(QueuedBatch batch) {
+    return new WindmillStreamShutdownException(
+        "Stream was closed when attempting to send " + batch.requestsCount() + " requests.");
+  }
+
+  private static WindmillStreamShutdownException shutdownExceptionFor(QueuedRequest request) {
+    return new WindmillStreamShutdownException(
+        "Cannot send request=[" + request + "] on closed stream.");
+  }
+
+  private void sendIgnoringClosed(StreamingGetDataRequest getDataRequest)
+      throws WindmillStreamShutdownException {
+    trySend(getDataRequest);
   }
 
   @Override
-  protected synchronized void onNewStream() {
-    send(StreamingGetDataRequest.newBuilder().setHeader(jobHeader).build());
-    if (clientClosed.get()) {
+  protected synchronized void onNewStream() throws WindmillStreamShutdownException {
+    trySend(StreamingGetDataRequest.newBuilder().setHeader(jobHeader).build());
+    if (clientClosed) {
       // We rely on close only occurring after all methods on the stream have returned.
       // Since the requestKeyedData and requestGlobalData methods are blocking this
       // means there should be no pending requests.
-      verify(!hasPendingRequests());
+      verify(!hasPendingRequests(), "Pending requests not expected on stream restart.");

Review Comment:
   nit: it seems more accurate to say that we don't expect pending requests
   once we've half-closed.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java:
##########
@@ -342,62 +401,86 @@ private void queueRequestAndWait(QueuedRequest request) throws InterruptedExcept
       batch.addRequest(request);
     }
     if (responsibleForSend) {
-      if (waitForSendLatch == null) {
+      if (prevBatch == null) {
         // If there was not a previous batch wait a little while to improve
         // batching.
-        Thread.sleep(1);
+        sleeper.sleep(1);
       } else {
-        waitForSendLatch.await();
+        prevBatch.waitForSendOrFailNotification();
       }
       // Finalize the batch so that no additional requests will be added.  Leave the batch in the
       // queue so that a subsequent batch will wait for its completion.
-      synchronized (batches) {
-        verify(batch == batches.peekFirst());
+      synchronized (this) {
+        if (isShutdown) {
+          throw shutdownExceptionFor(batch);
+        }
+
+        verify(batch == batches.peekFirst(), "GetDataStream request batch removed before send().");
         batch.markFinalized();
       }
-      sendBatch(batch.requests());
-      synchronized (batches) {
-        verify(batch == batches.pollFirst());
+      trySendBatch(batch);
+    } else {
+      // Wait for this batch to be sent before parsing the response.
+      batch.waitForSendOrFailNotification();
+    }
+  }
+
+  void trySendBatch(QueuedBatch batch) throws WindmillStreamShutdownException {
+    try {
+      sendBatch(batch);
+      synchronized (this) {
+        if (isShutdown) {
+          throw shutdownExceptionFor(batch);
+        }
+
+        verify(
+            batch == batches.pollFirst(),
+            "Sent GetDataStream request batch removed before send() was complete.");
       }
       // Notify all waiters with requests in this batch as well as the sender
       // of the next batch (if one exists).
-      batch.countDown();
-    } else {
-      // Wait for this batch to be sent before parsing the response.
-      batch.await();
+      batch.notifySent();
+    } catch (Exception e) {
+      // Free waiters if the send() failed.
+      batch.notifyFailed();
+      // Propagate the exception to the calling thread.
+      throw e;
     }
   }
 
-  @SuppressWarnings("NullableProblems")
-  private void sendBatch(List<QueuedRequest> requests) {
-    StreamingGetDataRequest batchedRequest = flushToBatch(requests);
+  private void sendBatch(QueuedBatch batch) throws WindmillStreamShutdownException {
+    if (batch.isEmpty()) {
+      return;
+    }
+
+    // Synchronization of pending inserts is necessary with send to ensure duplicates are not
+    // sent on stream reconnect.
     synchronized (this) {
-      // Synchronization of pending inserts is necessary with send to ensure duplicates are not
-      // sent on stream reconnect.
-      for (QueuedRequest request : requests) {
+      if (isShutdown) {
+        throw shutdownExceptionFor(batch);
+      }
+
+      for (QueuedRequest request : batch.requestsReadOnly()) {
         // Map#put returns null if there was no previous mapping for the key, meaning we have not
         // seen it before.
-        verify(pending.put(request.id(), request.getResponseStream()) == null);
+        verify(
+            pending.put(request.id(), request.getResponseStream()) == null,
+            "Request already sent.");
       }
-      try {
-        send(batchedRequest);
-      } catch (IllegalStateException e) {
+
+      if (!trySend(batch.asGetDataRequest())) {
         // The stream broke before this call went through; onNewStream will retry the fetch.
-        LOG.warn("GetData stream broke before call started.", e);
+        LOG.warn("GetData stream broke before call started.");
       }
     }
   }
 
-  @SuppressWarnings("argument")
-  private StreamingGetDataRequest flushToBatch(List<QueuedRequest> requests) {
-    // Put all global data requests first because there is only a single repeated field for
-    // request ids and the initial ids correspond to global data requests if they are present.
-    requests.sort(QueuedRequest.globalRequestsFirst());
-    StreamingGetDataRequest.Builder builder = StreamingGetDataRequest.newBuilder();
-    for (QueuedRequest request : requests) {
-      request.addToStreamingGetDataRequest(builder);
-    }
-    return builder.build();
+  private synchronized void verify(boolean condition, String message) {
+    Verify.verify(condition || isShutdown, message);
+  }
+
+  private synchronized boolean isShutdownLocked() {

Review Comment:
   Remove this as well if you remove the one above.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java:
##########
@@ -139,4 +166,29 @@ public void onCompleted() {
       outboundObserver.onCompleted();

Review Comment:
   I think we expect the client to close with onCompleted before the server
   closes and triggers ForwardClientResponseObserver#onDone.
   So I don't think we would get an error here in the happy case, where the
   client is initiating a clean shutdown.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java:
##########
@@ -139,4 +166,29 @@ public void onCompleted() {
       outboundObserver.onCompleted();
     }
   }
+
+  @Override
+  public void terminate(Throwable terminationException) {
+    // Free the blocked threads in onNext().
+    isReadyNotifier.forceTermination();

Review Comment:
   Hmm, this now seems equivalent to onError, provided onError itself takes
   care of not closing multiple times.
   
   This DirectStreamObserver behaves a little differently from what the
   StreamObserver interface specifies: that interface expects serial calls
   (DirectStreamObserver synchronizes calls to the child internally) and
   doesn't expect onError/onCompleted to be called multiple times. Could we
   get rid of TerminatingStreamObserver and just use DirectStreamObserver in
   other classes? Maybe that makes testing more difficult than an interface
   would, though. Otherwise, should we make terminate differ in that terminate
   can be called multiple times but onError/onCompleted cannot? Then,
   internally, on the first terminate or onError we would call onError on the
   underlying observer.
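
   To make the last option concrete, here is a rough, self-contained sketch of
   one way to read it. Everything in it is hypothetical: `Observer` is a
   stand-in for gRPC's StreamObserver, and `TerminatingObserver` and
   `RecordingObserver` are illustration-only names, not the classes in this PR.
   It assumes terminate is a no-op once the observer is closed, while a second
   onError/onCompleted throws.

```java
import java.util.ArrayList;
import java.util.List;

public class TerminateOnceSketch {

  /** Hypothetical stand-in for gRPC's StreamObserver; not the vendored Beam class. */
  interface Observer<T> {
    void onNext(T value);

    void onError(Throwable t);

    void onCompleted();
  }

  /** Hypothetical wrapper: terminate() is idempotent, onError/onCompleted are single-use. */
  static final class TerminatingObserver<T> implements Observer<T> {
    private final Observer<T> delegate;
    private boolean closed = false; // guarded by "this"

    TerminatingObserver(Observer<T> delegate) {
      this.delegate = delegate;
    }

    @Override
    public synchronized void onNext(T value) {
      if (closed) {
        throw new IllegalStateException("onNext called on a closed observer.");
      }
      delegate.onNext(value);
    }

    @Override
    public synchronized void onError(Throwable t) {
      markClosedOrThrow("onError");
      delegate.onError(t);
    }

    @Override
    public synchronized void onCompleted() {
      markClosedOrThrow("onCompleted");
      delegate.onCompleted();
    }

    /** Safe to call repeatedly; the delegate sees onError only if nothing closed it earlier. */
    public synchronized void terminate(Throwable t) {
      if (closed) {
        return;
      }
      closed = true;
      delegate.onError(t);
    }

    // Called only from synchronized methods, so the lock is already held.
    private void markClosedOrThrow(String method) {
      if (closed) {
        throw new IllegalStateException(method + " called on a closed observer.");
      }
      closed = true;
    }
  }

  /** Records what reaches the underlying observer. */
  static final class RecordingObserver implements Observer<String> {
    final List<String> values = new ArrayList<>();
    final List<Throwable> errors = new ArrayList<>();
    int completions = 0;

    @Override
    public void onNext(String value) {
      values.add(value);
    }

    @Override
    public void onError(Throwable t) {
      errors.add(t);
    }

    @Override
    public void onCompleted() {
      completions++;
    }
  }

  public static void main(String[] args) {
    RecordingObserver sink = new RecordingObserver();
    TerminatingObserver<String> observer = new TerminatingObserver<>(sink);
    observer.onNext("a");
    observer.terminate(new RuntimeException("first"));
    observer.terminate(new RuntimeException("second")); // no-op: already closed
    System.out.println("errors forwarded: " + sink.errors.size());
  }
}
```

   Whether onError after terminate should throw or be silently dropped is a
   judgment call; the sketch throws so that misuse is visible in tests.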



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
