scwhittle commented on code in PR #32774:
URL: https://github.com/apache/beam/pull/32774#discussion_r1812469715
##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java:
##########
@@ -122,12 +132,62 @@ void addRequest(QueuedRequest request) {
byteSize += request.byteSize();
}
- void countDown() {
+ /** Let waiting threads know that the request has been successfully sent. */
+ synchronized void notifySent() {
Review Comment:
don't think this needs synchronized
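CountDownLatch.countDown() is thread-safe on its own, so a plain method should do. A minimal sketch:
```
/** Let waiting threads know that the request has been successfully sent. */
void notifySent() {
  // The latch handles its own synchronization; no monitor is needed here.
  sent.countDown();
}
```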
##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java:
##########
@@ -17,44 +17,45 @@
*/
package org.apache.beam.runners.dataflow.worker.streaming.harness;
-import java.io.Closeable;
-import java.util.function.Supplier;
import javax.annotation.concurrent.ThreadSafe;
import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint;
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
import org.apache.beam.sdk.annotations.Internal;
-import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers;
@Internal
@ThreadSafe
-// TODO (m-trieu): replace Supplier<Stream> with Stream after github.com/apache/beam/pull/32774/ is
-// merged
-final class GlobalDataStreamSender implements Closeable, Supplier<GetDataStream> {
+final class GlobalDataStreamSender implements StreamSender {
private final Endpoint endpoint;
- private final Supplier<GetDataStream> delegate;
+ private final GetDataStream delegate;
private volatile boolean started;
- GlobalDataStreamSender(Supplier<GetDataStream> delegate, Endpoint endpoint) {
- // Ensures that the Supplier is thread-safe
- this.delegate = Suppliers.memoize(delegate::get);
+ GlobalDataStreamSender(GetDataStream delegate, Endpoint endpoint) {
+ this.delegate = delegate;
this.started = false;
this.endpoint = endpoint;
}
- @Override
- public GetDataStream get() {
+ GetDataStream stream() {
if (!started) {
- started = true;
+ // Starting the stream may perform IO. Start the stream lazily since not all pipeline
+ // implementations need to fetch global/side input data.
+ startStream();
}
- return delegate.get();
+ return delegate;
+ }
+
+ private synchronized void startStream() {
+ // Check started again after we acquire the lock.
+ if (!started) {
+ started = true;
+ delegate.start();
Review Comment:
seems like you should start() and then set started=true.
Otherwise, if there are concurrent calls to stream(), a non-started stream could be exposed.
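For example, a sketch of the reordering (using the fields from the hunk above):
```
private synchronized void startStream() {
  // Check started again after we acquire the lock.
  if (!started) {
    delegate.start();
    // Publish started only after start() returns, so concurrent stream()
    // callers can never observe a non-started stream.
    started = true;
  }
}
```
started is already volatile, so the unsynchronized fast path in stream() still sees the write.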
##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java:
##########
@@ -122,12 +132,62 @@ void addRequest(QueuedRequest request) {
byteSize += request.byteSize();
}
- void countDown() {
+ /** Let waiting threads know that the request has been successfully sent. */
+ synchronized void notifySent() {
sent.countDown();
}
- void await() throws InterruptedException {
+ /** Let waiting threads know that a failure occurred. */
+ synchronized void notifyFailed() {
Review Comment:
doesn't need synchronized since failed is volatile
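i.e. a sketch:
```
/** Let waiting threads know that a failure occurred. */
void notifyFailed() {
  // failed is volatile, and waiters only read it after the latch releases them.
  failed = true;
  sent.countDown();
}
```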
##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java:
##########
@@ -183,19 +182,9 @@ public void testStartStream_onlyStartsStreamsOnceConcurrent() throws Interrupted
eq(workItemScheduler));
verify(streamFactory, times(1))
- .createGetDataStream(eq(connection.stub()), any(ThrottleTimer.class));
+ .createDirectGetDataStream(eq(connection), any(ThrottleTimer.class));
verify(streamFactory, times(1))
- .createCommitWorkStream(eq(connection.stub()), any(ThrottleTimer.class));
- }
-
- @Test
- public void testCloseAllStreams_doesNotCloseUnstartedStreams() {
- WindmillStreamSender windmillStreamSender =
- newWindmillStreamSender(GetWorkBudget.builder().setBytes(1L).setItems(1L).build());
-
- windmillStreamSender.close();
-
- verifyNoInteractions(streamFactory);
Review Comment:
leave this test and verify that the stream is created but not started, then closed?
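Something along these lines, perhaps. This is just a sketch: getDataStream stands in for a hypothetical mock stream returned by the stubbed factory, and it assumes the start()/shutdown() stream API from this PR:
```
@Test
public void testCloseAllStreams_closesUnstartedStreams() {
  WindmillStreamSender windmillStreamSender =
      newWindmillStreamSender(GetWorkBudget.builder().setBytes(1L).setItems(1L).build());

  windmillStreamSender.close();

  // The stream was created but never started; close() should still shut it down.
  verify(getDataStream, never()).start();
  verify(getDataStream).shutdown();
}
```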
##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java:
##########
@@ -332,27 +332,36 @@ public CommitWorkResponse commitWork(CommitWorkRequest request) {
@Override
public GetWorkStream getWorkStream(GetWorkRequest request, WorkItemReceiver receiver) {
Review Comment:
can you make sure StreamingEngineWindmillClient documents that the streams returned by these methods are started, now that it's possible for streams to not be started?
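e.g. a javadoc sketch for the interface (the wording is just a suggestion):
```
/**
 * Opens a stream for the given request.
 *
 * <p>The returned stream has already been started; callers must not call
 * start() on it again.
 */
GetWorkStream getWorkStream(GetWorkRequest request, WorkItemReceiver receiver);
```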
##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java:
##########
@@ -93,7 +98,8 @@ private GrpcWindmillStreamFactory(
int windmillMessagesBetweenIsReadyChecks,
boolean sendKeyedGetDataRequests,
Consumer<List<ComputationHeartbeatResponse>> processHeartbeatResponses,
- Supplier<Duration> maxBackOffSupplier) {
+ Supplier<Duration> maxBackOffSupplier,
+ Set<AbstractWindmillStream<?, ?>> streamRegistry) {
Review Comment:
Can we remove this ability to inject the stream registry? It appears to be used only by tests to peek at registered streams; I think it would be clearer to expose registered streams via a @VisibleForTesting method, or a method to test whether a stream is registered.
Injecting things that are really managed/owned by this class is confusing, since callers could mutate state unexpectedly or inject a non-concurrent set, etc., which would break things if used in non-test situations.
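e.g. keep the registry private and expose something like (method name hypothetical):
```
@VisibleForTesting
boolean hasRegisteredStream(AbstractWindmillStream<?, ?> stream) {
  // Lets tests assert registration without handing out the mutable set.
  return streamRegistry.contains(stream);
}
```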
##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java:
##########
@@ -74,69 +76,104 @@ public void onNext(T value) {
while (true) {
try {
synchronized (lock) {
+ // If we awaited previously and timed out, wait for the same phase. Otherwise we're
+ // careful to observe the phase before observing isReady.
+ if (awaitPhase < 0) {
+ awaitPhase = isReadyNotifier.getPhase();
+ // If getPhase() returns a value less than 0, the phaser has been terminated.
+ if (awaitPhase < 0) {
+ return;
+ }
+ }
+
// We only check isReady periodically to effectively allow for increasing the outbound
// buffer periodically. This reduces the overhead of blocking while still restricting
// memory because there is a limited # of streams, and we have a max message size of 2MB.
if (++messagesSinceReady <= messagesBetweenIsReadyChecks) {
outboundObserver.onNext(value);
return;
}
- // If we awaited previously and timed out, wait for the same phase. Otherwise we're
- // careful to observe the phase before observing isReady.
- if (awaitPhase < 0) {
- awaitPhase = phaser.getPhase();
- }
+
if (outboundObserver.isReady()) {
messagesSinceReady = 0;
outboundObserver.onNext(value);
return;
}
}
+
// A callback has been registered to advance the phaser whenever the observer
// transitions to is ready. Since we are waiting for a phase observed before the
// outboundObserver.isReady() returned false, we expect it to advance after the
// channel has become ready. This doesn't always seem to be the case (despite
// documentation stating otherwise) so we poll periodically and enforce an overall
// timeout related to the stream deadline.
- phaser.awaitAdvanceInterruptibly(awaitPhase, waitSeconds, TimeUnit.SECONDS);
+ int nextPhase =
+ isReadyNotifier.awaitAdvanceInterruptibly(awaitPhase, waitSeconds, TimeUnit.SECONDS);
+ // If nextPhase is a value less than 0, the phaser has been terminated.
+ if (nextPhase < 0) {
+ return;
+ }
+
synchronized (lock) {
messagesSinceReady = 0;
outboundObserver.onNext(value);
return;
}
} catch (TimeoutException e) {
+ if (isReadyNotifier.isTerminated()) {
Review Comment:
I'd remove this:
- it's very unlikely to be true, since the phaser would have to terminate after we timed out; otherwise the await would have returned a negative value
- there is no reason it couldn't terminate right after we check here, so we need to be correct regardless
- an additional return is additional complication to the flow. The code below is fine to run if we're terminated.
##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java:
##########
@@ -122,12 +132,62 @@ void addRequest(QueuedRequest request) {
byteSize += request.byteSize();
}
- void countDown() {
+ /** Let waiting threads know that the request has been successfully sent. */
+ synchronized void notifySent() {
sent.countDown();
}
- void await() throws InterruptedException {
+ /** Let waiting threads know that a failure occurred. */
+ synchronized void notifyFailed() {
+ failed = true;
+ sent.countDown();
+ }
+
+ /**
+ * Block until notified of a successful send via {@link #notifySent()} or a non-retryable
+ * failure via {@link #notifyFailed()}. On failure, throw an exception on calling threads.
+ */
+ void waitForSendOrFailNotification() throws InterruptedException {
sent.await();
+ if (failed) {
+ ImmutableList<String> cancelledRequests = createStreamCancelledErrorMessage();
+ LOG.error("Requests failed for the following batches: {}", cancelledRequests);
+ throw new WindmillStreamShutdownException(
+ "Requests failed for batch containing "
+ + String.join(", ", cancelledRequests)
+ + " ... requests. This is most likely due to the stream being
explicitly closed"
+ + " which happens when the work is marked as invalid on the
streaming"
+ + " backend when key ranges shuffle around. This is transient
and corresponding"
+ + " work will eventually be retried.");
+ }
+ }
+
+ ImmutableList<String> createStreamCancelledErrorMessage() {
+ return requests.stream()
+ .flatMap(
+ request -> {
+ switch (request.getDataRequest().getKind()) {
+ case GLOBAL:
+ return Stream.of("GetSideInput=" + request.getDataRequest().global());
+ case COMPUTATION:
+ return request.getDataRequest().computation().getRequestsList().stream()
+ .map(
+ keyedRequest ->
+ "KeyedGetState=["
+ + "shardingKey="
+ + keyedRequest.getShardingKey()
Review Comment:
nit: print all of these values as hex? It's more compact and matches how they're generally printed on status pages.
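e.g. a sketch with Long.toHexString (the same would apply to the other tokens in this message):
```
// Hex is more compact and matches how keys are rendered on status pages.
"KeyedGetState=[shardingKey=" + Long.toHexString(keyedRequest.getShardingKey())
```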
##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java:
##########
@@ -74,69 +76,104 @@ public void onNext(T value) {
while (true) {
try {
synchronized (lock) {
+ // If we awaited previously and timed out, wait for the same phase. Otherwise we're
Review Comment:
I think we need to check for termination in this synchronized block since we might send on a retry. Otherwise we could have closed the observer and then try to call onNext on it, which might cause an issue.
How about structuring it like:
```
int currentPhase = isReadyNotifier.getPhase();
// The phaser is terminated, so don't use the outboundObserver. Since onError and
// onCompleted are synchronized after terminating the phaser, if we observe here that
// the phaser is not terminated, the onNext calls below are guaranteed not to be
// called on a closed observer.
if (currentPhase < 0) return;
// If we awaited previously and timed out, wait for the same phase. Otherwise use the
// phase we just observed; we record the phase before the isReady check because the
// transition from !isReady to isReady advances the phase.
if (awaitPhase < 0) awaitPhase = currentPhase;
```
##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java:
##########
@@ -122,12 +132,62 @@ void addRequest(QueuedRequest request) {
byteSize += request.byteSize();
}
- void countDown() {
+ /** Let waiting threads know that the request has been successfully sent. */
+ synchronized void notifySent() {
sent.countDown();
}
- void await() throws InterruptedException {
+ /** Let waiting threads know that a failure occurred. */
+ synchronized void notifyFailed() {
+ failed = true;
+ sent.countDown();
+ }
+
+ /**
+ * Block until notified of a successful send via {@link #notifySent()} or a non-retryable
+ * failure via {@link #notifyFailed()}. On failure, throw an exception on calling threads.
+ */
+ void waitForSendOrFailNotification() throws InterruptedException {
sent.await();
+ if (failed) {
+ ImmutableList<String> cancelledRequests = createStreamCancelledErrorMessage();
+ LOG.error("Requests failed for the following batches: {}", cancelledRequests);
+ throw new WindmillStreamShutdownException(
+ "Requests failed for batch containing "
+ + String.join(", ", cancelledRequests)
+ + " ... requests. This is most likely due to the stream being
explicitly closed"
+ + " which happens when the work is marked as invalid on the
streaming"
+ + " backend when key ranges shuffle around. This is transient
and corresponding"
+ + " work will eventually be retried.");
+ }
+ }
+
+ ImmutableList<String> createStreamCancelledErrorMessage() {
Review Comment:
nit: createStreamCancelledErrorMessages
Message sounds like it would return a single string.
##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java:
##########
@@ -0,0 +1,237 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+import java.io.IOException;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import javax.annotation.Nullable;
+import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
+import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.ServerCallStreamObserver;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.Timeout;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class GrpcCommitWorkStreamTest {
+ private static final String FAKE_SERVER_NAME = "Fake server for GrpcCommitWorkStreamTest";
+ private static final Windmill.JobHeader TEST_JOB_HEADER =
+ Windmill.JobHeader.newBuilder()
+ .setJobId("test_job")
+ .setWorkerId("test_worker")
+ .setProjectId("test_project")
+ .build();
+ private static final String COMPUTATION_ID = "computationId";
+
+ @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
+ private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry();
+ @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
+ private ManagedChannel inProcessChannel;
+
+ private static Windmill.WorkItemCommitRequest workItemCommitRequest(long value) {
+ return Windmill.WorkItemCommitRequest.newBuilder()
+ .setKey(ByteString.EMPTY)
+ .setShardingKey(value)
+ .setWorkToken(value)
+ .setCacheToken(value)
+ .build();
+ }
+
+ @Before
+ public void setUp() throws IOException {
+ Server server =
+ InProcessServerBuilder.forName(FAKE_SERVER_NAME)
+ .fallbackHandlerRegistry(serviceRegistry)
+ .directExecutor()
+ .build()
+ .start();
+
+ inProcessChannel =
+ grpcCleanup.register(
+ InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build());
+ grpcCleanup.register(server);
+ grpcCleanup.register(inProcessChannel);
+ }
+
+ @After
+ public void cleanUp() {
+ inProcessChannel.shutdownNow();
+ }
+
+ private GrpcCommitWorkStream createCommitWorkStream(CommitWorkStreamTestStub testStub) {
+ serviceRegistry.addService(testStub);
+ GrpcCommitWorkStream commitWorkStream =
+ (GrpcCommitWorkStream)
+ GrpcWindmillStreamFactory.of(TEST_JOB_HEADER)
+ .build()
+ .createCommitWorkStream(
+ CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel),
+ new ThrottleTimer());
+ commitWorkStream.start();
+ return commitWorkStream;
+ }
+
+ @Test
+ public void testShutdown_abortsQueuedCommits() throws InterruptedException {
+ int numCommits = 5;
+ CountDownLatch commitProcessed = new CountDownLatch(numCommits);
+ Set<Windmill.CommitStatus> onDone = new HashSet<>();
+
+ TestCommitWorkStreamRequestObserver requestObserver =
+ spy(new TestCommitWorkStreamRequestObserver());
+ CommitWorkStreamTestStub testStub = new CommitWorkStreamTestStub(requestObserver);
+ GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(testStub);
+ try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) {
+ for (int i = 0; i < numCommits; i++) {
+ batcher.commitWorkItem(
+ COMPUTATION_ID,
+ workItemCommitRequest(i),
+ commitStatus -> {
+ onDone.add(commitStatus);
+ commitProcessed.countDown();
+ });
+ }
+ }
+
+ // Verify that we sent the commits above in a request + the initial header.
+ verify(requestObserver, times(2)).onNext(any(Windmill.StreamingCommitWorkRequest.class));
+ // We won't get responses so we will have some pending requests.
+ assertTrue(commitWorkStream.hasPendingRequests());
+
+ commitWorkStream.shutdown();
+ commitProcessed.await();
+
+ assertThat(onDone).containsExactly(Windmill.CommitStatus.ABORTED);
+ }
+
+ @Test
+ public void testCommitWorkItem_afterShutdownFalse() {
+ int numCommits = 5;
+
+ CommitWorkStreamTestStub testStub =
+ new CommitWorkStreamTestStub(new TestCommitWorkStreamRequestObserver());
+ GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(testStub);
+
+ try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) {
+ for (int i = 0; i < numCommits; i++) {
+ assertTrue(batcher.commitWorkItem(COMPUTATION_ID, workItemCommitRequest(i), ignored -> {}));
+ }
+ }
+ commitWorkStream.shutdown();
+
+ Set<Windmill.CommitStatus> commitStatuses = new HashSet<>();
Review Comment:
move this inside the loop? We're not trying to verify across iterations, so it keeps things mentally simpler.
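i.e. something like:
```
try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) {
  for (int i = 0; i < numCommits; i++) {
    // Fresh set per iteration; each post-shutdown commit should observe ABORTED.
    Set<Windmill.CommitStatus> commitStatuses = new HashSet<>();
    assertFalse(
        batcher.commitWorkItem(COMPUTATION_ID, workItemCommitRequest(i), commitStatuses::add));
    assertThat(commitStatuses).containsExactly(Windmill.CommitStatus.ABORTED);
  }
}
```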
##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java:
##########
@@ -122,12 +132,62 @@ void addRequest(QueuedRequest request) {
byteSize += request.byteSize();
}
- void countDown() {
+ /** Let waiting threads know that the request has been successfully sent. */
+ synchronized void notifySent() {
sent.countDown();
}
- void await() throws InterruptedException {
+ /** Let waiting threads know that a failure occurred. */
+ synchronized void notifyFailed() {
+ failed = true;
+ sent.countDown();
+ }
+
+ /**
+ * Block until notified of a successful send via {@link #notifySent()} or a non-retryable
+ * failure via {@link #notifyFailed()}. On failure, throw an exception on calling threads.
+ */
+ void waitForSendOrFailNotification() throws InterruptedException {
sent.await();
+ if (failed) {
+ ImmutableList<String> cancelledRequests = createStreamCancelledErrorMessage();
Review Comment:
should we have a special message for the empty case, or not treat it as a failure?
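e.g. a sketch that special-cases the empty batch (exception text hypothetical):
```
if (failed) {
  ImmutableList<String> cancelledRequests = createStreamCancelledErrorMessages();
  if (cancelledRequests.isEmpty()) {
    // Avoid the confusing "batch containing ... requests" text when the batch is empty.
    throw new WindmillStreamShutdownException(
        "Stream was shut down while waiting for a send; the batch contained no requests.");
  }
  // ... existing failure handling for non-empty batches ...
}
```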
##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java:
##########
@@ -0,0 +1,237 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+import java.io.IOException;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import javax.annotation.Nullable;
+import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream;
+import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
+import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.ServerCallStreamObserver;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.Timeout;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class GrpcCommitWorkStreamTest {
+ private static final String FAKE_SERVER_NAME = "Fake server for GrpcCommitWorkStreamTest";
+ private static final Windmill.JobHeader TEST_JOB_HEADER =
+ Windmill.JobHeader.newBuilder()
+ .setJobId("test_job")
+ .setWorkerId("test_worker")
+ .setProjectId("test_project")
+ .build();
+ private static final String COMPUTATION_ID = "computationId";
+
+ @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
+ private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry();
+ @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
+ private ManagedChannel inProcessChannel;
+
+ private static Windmill.WorkItemCommitRequest workItemCommitRequest(long value) {
+ return Windmill.WorkItemCommitRequest.newBuilder()
+ .setKey(ByteString.EMPTY)
+ .setShardingKey(value)
+ .setWorkToken(value)
+ .setCacheToken(value)
+ .build();
+ }
+
+ @Before
+ public void setUp() throws IOException {
+ Server server =
+ InProcessServerBuilder.forName(FAKE_SERVER_NAME)
+ .fallbackHandlerRegistry(serviceRegistry)
+ .directExecutor()
+ .build()
+ .start();
+
+ inProcessChannel =
+ grpcCleanup.register(
+ InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build());
+ grpcCleanup.register(server);
+ grpcCleanup.register(inProcessChannel);
+ }
+
+ @After
+ public void cleanUp() {
+ inProcessChannel.shutdownNow();
+ }
+
+ private GrpcCommitWorkStream createCommitWorkStream(CommitWorkStreamTestStub testStub) {
+ serviceRegistry.addService(testStub);
+ GrpcCommitWorkStream commitWorkStream =
+ (GrpcCommitWorkStream)
+ GrpcWindmillStreamFactory.of(TEST_JOB_HEADER)
+ .build()
+ .createCommitWorkStream(
+ CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel),
+ new ThrottleTimer());
+ commitWorkStream.start();
+ return commitWorkStream;
+ }
+
+ @Test
+ public void testShutdown_abortsQueuedCommits() throws InterruptedException {
+ int numCommits = 5;
+ CountDownLatch commitProcessed = new CountDownLatch(numCommits);
+ Set<Windmill.CommitStatus> onDone = new HashSet<>();
+
+ TestCommitWorkStreamRequestObserver requestObserver =
+ spy(new TestCommitWorkStreamRequestObserver());
+ CommitWorkStreamTestStub testStub = new CommitWorkStreamTestStub(requestObserver);
+ GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(testStub);
+ try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) {
+ for (int i = 0; i < numCommits; i++) {
+ batcher.commitWorkItem(
+ COMPUTATION_ID,
+ workItemCommitRequest(i),
+ commitStatus -> {
+ onDone.add(commitStatus);
+ commitProcessed.countDown();
+ });
+ }
+ }
+
+ // Verify that we sent the commits above in a request + the initial header.
+ verify(requestObserver, times(2)).onNext(any(Windmill.StreamingCommitWorkRequest.class));
+ // We won't get responses so we will have some pending requests.
+ assertTrue(commitWorkStream.hasPendingRequests());
+
+ commitWorkStream.shutdown();
+ commitProcessed.await();
+
+ assertThat(onDone).containsExactly(Windmill.CommitStatus.ABORTED);
+ }
+
+ @Test
+ public void testCommitWorkItem_afterShutdownFalse() {
+ int numCommits = 5;
+
+ CommitWorkStreamTestStub testStub =
+ new CommitWorkStreamTestStub(new TestCommitWorkStreamRequestObserver());
+ GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(testStub);
+
+ try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) {
+ for (int i = 0; i < numCommits; i++) {
+ assertTrue(batcher.commitWorkItem(COMPUTATION_ID, workItemCommitRequest(i), ignored -> {}));
+ }
+ }
+ commitWorkStream.shutdown();
+
+ Set<Windmill.CommitStatus> commitStatuses = new HashSet<>();
+ try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) {
+ for (int i = 0; i < numCommits; i++) {
+ assertFalse(
+ batcher.commitWorkItem(COMPUTATION_ID, workItemCommitRequest(i), commitStatuses::add));
+ assertThat(commitStatuses).containsExactly(Windmill.CommitStatus.ABORTED);
+ }
+ }
+ }
+
+ @Test
+ public void testSend_notCalledAfterShutdown() {
+ int numCommits = 5;
+ CountDownLatch commitProcessed = new CountDownLatch(numCommits);
+
+ TestCommitWorkStreamRequestObserver requestObserver =
+ spy(new TestCommitWorkStreamRequestObserver());
+ CommitWorkStreamTestStub testStub = new CommitWorkStreamTestStub(requestObserver);
+ GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(testStub);
+
+ try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) {
+ for (int i = 0; i < numCommits; i++) {
+ assertTrue(
+ batcher.commitWorkItem(
+ COMPUTATION_ID,
+ workItemCommitRequest(i),
+ commitStatus -> commitProcessed.countDown()));
+ }
+ commitWorkStream.shutdown();
+ }
+
+ // send() uses the requestObserver to send requests. We expect 1 send since startStream() sends
+ // the header, which happens before we shutdown.
+ verify(requestObserver, times(1)).onNext(any(Windmill.StreamingCommitWorkRequest.class));
Review Comment:
you can use InOrder to verify that onNext happens before onCompleted, and then no more calls:
https://stackoverflow.com/questions/21901368/mockito-verify-order-sequence-of-method-calls
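e.g.:
```
// The header onNext must precede onCompleted, with no observer calls after.
InOrder inOrder = inOrder(requestObserver);
inOrder.verify(requestObserver).onNext(any(Windmill.StreamingCommitWorkRequest.class));
inOrder.verify(requestObserver).onCompleted();
inOrder.verifyNoMoreInteractions();
```
(using org.mockito.InOrder with a static import of org.mockito.Mockito.inOrder)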
##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java:
##########
@@ -21,26 +21,30 @@
import java.io.PrintWriter;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
-import java.util.concurrent.Executor;
+import java.util.concurrent.ExecutorService;
Review Comment:
still need to get to this file, it was collapsed
##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java:
##########
@@ -117,6 +129,23 @@ public void setUp() throws IOException {
WindmillStreamPool.create(
1, Duration.standardMinutes(1), fakeWindmillServer::commitWorkStream)
::getCloseableStream;
+ Server server =
+ InProcessServerBuilder.forName(FAKE_SERVER_NAME)
+ .fallbackHandlerRegistry(serviceRegistry)
+ .directExecutor()
+ .build()
+ .start();
+ inProcessChannel =
Review Comment:
I don't see how this is being used if we're using the fake
##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java:
##########
@@ -74,69 +76,104 @@ public void onNext(T value) {
while (true) {
try {
synchronized (lock) {
+ // If we awaited previously and timed out, wait for the same phase. Otherwise we're
+ // careful to observe the phase before observing isReady.
+ if (awaitPhase < 0) {
+ awaitPhase = isReadyNotifier.getPhase();
+ // If getPhase() returns a value less than 0, the phaser has been terminated.
+ if (awaitPhase < 0) {
+ return;
+ }
+ }
+
// We only check isReady periodically to effectively allow for increasing the outbound
// buffer periodically. This reduces the overhead of blocking while still restricting
// memory because there is a limited # of streams, and we have a max message size of 2MB.
if (++messagesSinceReady <= messagesBetweenIsReadyChecks) {
outboundObserver.onNext(value);
return;
}
- // If we awaited previously and timed out, wait for the same phase. Otherwise we're
- // careful to observe the phase before observing isReady.
- if (awaitPhase < 0) {
- awaitPhase = phaser.getPhase();
- }
+
if (outboundObserver.isReady()) {
messagesSinceReady = 0;
outboundObserver.onNext(value);
return;
}
}
+
// A callback has been registered to advance the phaser whenever the observer
// transitions to is ready. Since we are waiting for a phase observed before the
// outboundObserver.isReady() returned false, we expect it to advance after the
// channel has become ready. This doesn't always seem to be the case (despite
// documentation stating otherwise) so we poll periodically and enforce an overall
// timeout related to the stream deadline.
- phaser.awaitAdvanceInterruptibly(awaitPhase, waitSeconds, TimeUnit.SECONDS);
+ int nextPhase =
+ isReadyNotifier.awaitAdvanceInterruptibly(awaitPhase, waitSeconds, TimeUnit.SECONDS);
+ // If nextPhase is a value less than 0, the phaser has been terminated.
+ if (nextPhase < 0) {
Review Comment:
see above; I think we want to check that the phaser isn't terminated within the synchronized block that calls onNext below
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]