Repository: beam Updated Branches: refs/heads/master 585440d22 -> 1cd87e325
[BEAM-1347] Implement a BeamFnStateClient which communicates over gRPC. Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/fb2d6b58 Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/fb2d6b58 Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/fb2d6b58 Branch: refs/heads/master Commit: fb2d6b58c065604daedf02a492457ce35bacfde2 Parents: 585440d Author: Luke Cwik <lc...@google.com> Authored: Tue Aug 29 18:31:39 2017 -0700 Committer: Luke Cwik <lc...@google.com> Committed: Wed Aug 30 16:10:25 2017 -0700 ---------------------------------------------------------------------- .../org/apache/beam/fn/harness/IdGenerator.java | 33 +++ .../state/BeamFnStateGrpcClientCache.java | 173 ++++++++++++++ .../apache/beam/fn/harness/IdGeneratorTest.java | 40 ++++ .../state/BeamFnStateGrpcClientCacheTest.java | 234 +++++++++++++++++++ 4 files changed, 480 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/fb2d6b58/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/IdGenerator.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/IdGenerator.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/IdGenerator.java new file mode 100644 index 0000000..1112f43 --- /dev/null +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/IdGenerator.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.fn.harness; + +import java.util.concurrent.atomic.AtomicLong; + +/** + * An id generator. + * + * <p>Ids are drawn from a single static counter which starts at {@code -1} and is decremented on + * every call, so each returned id differs from every other id returned by this class (barring + * wrap-around after 2^63 calls). + * + * <p>This encapsulation exists to prevent usage of the wrong method on a shared {@link AtomicLong}. + */ +public final class IdGenerator { + private static final AtomicLong idGenerator = new AtomicLong(-1); + + /** Static utility class; not instantiable. */ + private IdGenerator() {} + + /** + * Returns a new id. + * + * <p>Safe to call concurrently: the counter is advanced with a single atomic + * {@link AtomicLong#getAndDecrement()}. + */ + public static String generate() { + return Long.toString(idGenerator.getAndDecrement()); + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/fb2d6b58/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BeamFnStateGrpcClientCache.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BeamFnStateGrpcClientCache.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BeamFnStateGrpcClientCache.java new file mode 100644 index 0000000..316e3e6 --- /dev/null +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BeamFnStateGrpcClientCache.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.fn.harness.state; + +import io.grpc.ManagedChannel; +import io.grpc.stub.StreamObserver; +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.function.Supplier; +import org.apache.beam.fn.harness.data.BeamFnDataGrpcClient; +import org.apache.beam.fn.v1.BeamFnApi; +import org.apache.beam.fn.v1.BeamFnApi.ApiServiceDescriptor; +import org.apache.beam.fn.v1.BeamFnApi.StateRequest; +import org.apache.beam.fn.v1.BeamFnApi.StateResponse; +import org.apache.beam.fn.v1.BeamFnStateGrpc; +import org.apache.beam.sdk.options.PipelineOptions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A cache of {@link BeamFnStateClient}s which handle Beam Fn State requests using gRPC. + * + * <p>TODO: Add the ability to close which cancels any pending and stops any future requests. 
+ */ +public class BeamFnStateGrpcClientCache { + private static final Logger LOG = LoggerFactory.getLogger(BeamFnStateGrpcClientCache.class); + + private final ConcurrentMap<ApiServiceDescriptor, BeamFnStateClient> cache; + private final Function<ApiServiceDescriptor, ManagedChannel> channelFactory; + private final BiFunction<Function<StreamObserver<StateResponse>, + StreamObserver<StateRequest>>, + StreamObserver<StateResponse>, + StreamObserver<StateRequest>> streamObserverFactory; + private final PipelineOptions options; + private final Supplier<String> idGenerator; + + /** + * Creates an empty cache. + * + * @param options pipeline options; stored but not currently read by this class. + * @param idGenerator supplies the id assigned to each outgoing {@link StateRequest}. + * @param channelFactory creates the {@link ManagedChannel} used to reach a given + * {@link ApiServiceDescriptor}. + * @param streamObserverFactory given a function which opens the state stream for an inbound + * {@link StreamObserver} together with that inbound observer, returns the outbound observer + * used to send {@link StateRequest}s. + */ + public BeamFnStateGrpcClientCache( + PipelineOptions options, + Supplier<String> idGenerator, + Function<BeamFnApi.ApiServiceDescriptor, ManagedChannel> channelFactory, + BiFunction<Function<StreamObserver<StateResponse>, StreamObserver<StateRequest>>, + StreamObserver<StateResponse>, + StreamObserver<StateRequest>> streamObserverFactory) { + this.options = options; + this.idGenerator = idGenerator; + this.channelFactory = channelFactory; + this.streamObserverFactory = streamObserverFactory; + this.cache = new ConcurrentHashMap<>(); + } + + /** + * Creates or returns an existing {@link BeamFnStateClient} depending on whether the passed in + * {@link ApiServiceDescriptor} currently has a {@link BeamFnStateClient} bound to the same + * channel. + * + * <p>A client removes its descriptor's entry from this cache once the server terminates its + * stream, so a later call for the same descriptor creates a new client. + * + * @throws IOException declared by this method, but not thrown by the current implementation. + */ + public BeamFnStateClient forApiServiceDescriptor(ApiServiceDescriptor apiServiceDescriptor) + throws IOException { + return cache.computeIfAbsent(apiServiceDescriptor, this::createBeamFnStateClient); + } + + /** Builds a new client, and with it a new state stream; invoked only on a cache miss. */ + private BeamFnStateClient createBeamFnStateClient(ApiServiceDescriptor apiServiceDescriptor) { + return new GrpcStateClient(apiServiceDescriptor); + } + + /** + * A {@link BeamFnStateClient} for a given {@link ApiServiceDescriptor}.
+ */ + private class GrpcStateClient implements BeamFnStateClient { + private final ApiServiceDescriptor apiServiceDescriptor; + private final ConcurrentMap<String, CompletableFuture<StateResponse>> outstandingRequests; + private final StreamObserver<StateRequest> outboundObserver; + private final ManagedChannel channel; + // Set once, by closeAndCleanUp, to the cause with which pending and later requests are failed. + private volatile RuntimeException closed; + + private GrpcStateClient(ApiServiceDescriptor apiServiceDescriptor) { + this.apiServiceDescriptor = apiServiceDescriptor; + this.outstandingRequests = new ConcurrentHashMap<>(); + this.channel = channelFactory.apply(apiServiceDescriptor); + this.outboundObserver = streamObserverFactory.apply( + BeamFnStateGrpc.newStub(channel)::state, new InboundObserver()); + } + + /** + * Assigns the request a fresh id and registers {@code response} to be completed by the + * {@link StateResponse} carrying the same id. Then it sends the request. + * + * <p>If the server has already terminated the stream, {@code response} is completed + * exceptionally with the termination cause instead of being left pending forever. + */ + @Override + public void handle( + StateRequest.Builder requestBuilder, CompletableFuture<StateResponse> response) { + RuntimeException closeCause = closed; + if (closeCause != null) { + response.completeExceptionally(closeCause); + return; + } + + requestBuilder.setId(idGenerator.get()); + StateRequest request = requestBuilder.build(); + outstandingRequests.put(request.getId(), response); + + // closeAndCleanUp may have snapshotted and cleared outstandingRequests between the check + // above and the put, in which case this request was never failed. Fail it here; + // completeExceptionally is a no-op if closeAndCleanUp already completed the future. + closeCause = closed; + if (closeCause != null) { + outstandingRequests.remove(request.getId()); + response.completeExceptionally(closeCause); + return; + } + + // If the server closes, gRPC will throw an error if onNext is called. + LOG.debug("Sending StateRequest {}", request); + outboundObserver.onNext(request); + } + + /** + * Marks this client closed with {@code cause} and removes its descriptor's entry from the + * cache. It then fails every outstanding request with {@code cause}; when nothing is + * outstanding, the outbound stream is completed instead. Only the first invocation has any + * effect. + */ + private synchronized void closeAndCleanUp(RuntimeException cause) { + if (closed != null) { + return; + } + cache.remove(apiServiceDescriptor); + closed = cause; + + // Make a copy of the map to make the view of the outstanding requests consistent.
+ Map<String, CompletableFuture<StateResponse>> outstandingRequestsCopy = + new ConcurrentHashMap<>(outstandingRequests); + + if (outstandingRequestsCopy.isEmpty()) { + outboundObserver.onCompleted(); + return; + } + + outstandingRequests.clear(); + LOG.error("BeamFnState failed, clearing outstanding requests {}", outstandingRequestsCopy); + + for (CompletableFuture<StateResponse> entry : outstandingRequestsCopy.values()) { + entry.completeExceptionally(cause); + } + } + + /** + * A {@link StreamObserver} which propagates any server side state request responses by + * completing the outstanding response future. + * + * <p>Also propagates server side failures and closes completing any outstanding requests + * exceptionally. + */ + private class InboundObserver implements StreamObserver<StateResponse> { + @Override + public void onNext(StateResponse value) { + LOG.debug("Received StateResponse {}", value); + CompletableFuture<StateResponse> responseFuture = outstandingRequests.remove(value.getId()); + if (responseFuture != null) { + if (value.getError().isEmpty()) { + responseFuture.complete(value); + } else { + responseFuture.completeExceptionally(new IllegalStateException(value.getError())); + } + } + } + + @Override + public void onError(Throwable t) { + closeAndCleanUp(t instanceof RuntimeException + ? 
(RuntimeException) t + : new RuntimeException(t)); + } + + @Override + public void onCompleted() { + closeAndCleanUp(new RuntimeException("Server hanged up.")); + } + } + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/fb2d6b58/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/IdGeneratorTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/IdGeneratorTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/IdGeneratorTest.java new file mode 100644 index 0000000..10ce393 --- /dev/null +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/IdGeneratorTest.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.fn.harness; + +import static org.junit.Assert.assertEquals; + +import java.util.HashSet; +import java.util.Set; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link IdGenerator}. 
*/ +@RunWith(JUnit4.class) +public class IdGeneratorTest { + /** Generates many ids and verifies that none of them repeats. */ + @Test + public void testGenerationNeverMatches() { + final int idCount = 10000; + Set<String> seenIds = new HashSet<>(); + int remaining = idCount; + while (remaining-- > 0) { + seenIds.add(IdGenerator.generate()); + } + // A repeated id would collapse into an existing set entry and shrink the final size. + assertEquals(idCount, seenIds.size()); + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/fb2d6b58/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/BeamFnStateGrpcClientCacheTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/BeamFnStateGrpcClientCacheTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/BeamFnStateGrpcClientCacheTest.java new file mode 100644 index 0000000..f0e84c7 --- /dev/null +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/BeamFnStateGrpcClientCacheTest.java @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.fn.harness.state; + +import static org.hamcrest.Matchers.containsString; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.fail; + +import com.google.common.util.concurrent.Uninterruptibles; +import io.grpc.ManagedChannel; +import io.grpc.Server; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.stub.CallStreamObserver; +import io.grpc.stub.StreamObserver; +import java.util.UUID; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.function.Function; +import org.apache.beam.fn.harness.IdGenerator; +import org.apache.beam.fn.harness.test.TestStreams; +import org.apache.beam.fn.v1.BeamFnApi.ApiServiceDescriptor; +import org.apache.beam.fn.v1.BeamFnApi.StateRequest; +import org.apache.beam.fn.v1.BeamFnApi.StateResponse; +import org.apache.beam.fn.v1.BeamFnStateGrpc; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link BeamFnStateGrpcClientCache}. 
*/ +@RunWith(JUnit4.class) +public class BeamFnStateGrpcClientCacheTest { + private static final String SUCCESS = "SUCCESS"; + private static final String FAIL = "FAIL"; + private static final String TEST_ERROR = "TEST ERROR"; + private static final String SERVER_ERROR = "SERVER ERROR"; + + private ApiServiceDescriptor apiServiceDescriptor; + private ManagedChannel testChannel; + private Server testServer; + private BeamFnStateGrpcClientCache clientCache; + // Server-side response observers, one per state stream opened against the test server. + private BlockingQueue<StreamObserver<StateResponse>> outboundServerObservers; + // Every StateRequest received by the test server, in arrival order. + private BlockingQueue<StateRequest> values; + + // Starts an in-process server whose state() handler publishes the server-side response + // observer to outboundServerObservers and appends each received request to values. Then it + // builds a cache whose channel factory always returns a channel to that server. + @Before + public void setUp() throws Exception { + values = new LinkedBlockingQueue<>(); + outboundServerObservers = new LinkedBlockingQueue<>(); + CallStreamObserver<StateRequest> inboundServerObserver = + TestStreams.withOnNext(values::add).build(); + + apiServiceDescriptor = + ApiServiceDescriptor.newBuilder() + .setUrl(this.getClass().getName() + "-" + UUID.randomUUID().toString()) + .build(); + testServer = InProcessServerBuilder.forName(apiServiceDescriptor.getUrl()) + .addService(new BeamFnStateGrpc.BeamFnStateImplBase() { + @Override + public StreamObserver<StateRequest> state( + StreamObserver<StateResponse> outboundObserver) { + Uninterruptibles.putUninterruptibly(outboundServerObservers, outboundObserver); + return inboundServerObserver; + } + }) + .build(); + testServer.start(); + + testChannel = InProcessChannelBuilder.forName(apiServiceDescriptor.getUrl()).build(); + + clientCache = new BeamFnStateGrpcClientCache( + PipelineOptionsFactory.create(), + IdGenerator::generate, + (ApiServiceDescriptor descriptor) -> testChannel, + this::createStreamForTest); + } + + @After + public void tearDown() throws Exception { + testServer.shutdownNow(); + testChannel.shutdownNow(); + } + + @Test + public void testCachingOfClient() throws Exception { + // The same descriptor yields the same client; a different descriptor yields a different one. + assertSame(clientCache.forApiServiceDescriptor(apiServiceDescriptor), + clientCache.forApiServiceDescriptor(apiServiceDescriptor)); + 
assertNotSame(clientCache.forApiServiceDescriptor(apiServiceDescriptor), + clientCache.forApiServiceDescriptor( + ApiServiceDescriptor.newBuilder().setId("OTHER").build())); + } + + @Test + public void testRequestResponses() throws Exception { + BeamFnStateClient client = clientCache.forApiServiceDescriptor(apiServiceDescriptor); + + CompletableFuture<StateResponse> successfulResponse = new CompletableFuture<>(); + CompletableFuture<StateResponse> unsuccessfulResponse = new CompletableFuture<>(); + + client.handle( + StateRequest.newBuilder().setInstructionReference(SUCCESS), successfulResponse); + client.handle( + StateRequest.newBuilder().setInstructionReference(FAIL), unsuccessfulResponse); + + // Wait for the client to connect. + StreamObserver<StateResponse> outboundServerObserver = outboundServerObservers.take(); + // Ensure the client doesn't break when sent garbage. + outboundServerObserver.onNext(StateResponse.newBuilder().setId("UNKNOWN ID").build()); + + // We expect to receive and handle two requests + handleServerRequest(outboundServerObserver, values.take()); + handleServerRequest(outboundServerObserver, values.take()); + + // Ensure that the successful and unsuccessful responses were propagated. + assertNotNull(successfulResponse.get()); + try { + unsuccessfulResponse.get(); + fail("Expected unsuccessful response"); + } catch (ExecutionException e) { + assertThat(e.toString(), containsString(TEST_ERROR)); + } + } + + @Test + public void testServerErrorCausesPendingAndFutureCallsToFail() throws Exception { + BeamFnStateClient client = clientCache.forApiServiceDescriptor(apiServiceDescriptor); + + CompletableFuture<StateResponse> inflight = new CompletableFuture<>(); + client.handle(StateRequest.newBuilder().setInstructionReference(SUCCESS), inflight); + + // Wait for the client to connect. + StreamObserver<StateResponse> outboundServerObserver = outboundServerObservers.take(); + // Send an error from the server. 
+ outboundServerObserver.onError( + new StatusRuntimeException(Status.INTERNAL.withDescription(SERVER_ERROR))); + + try { + inflight.get(); + fail("Expected unsuccessful response due to server error"); + } catch (ExecutionException e) { + assertThat(e.toString(), containsString(SERVER_ERROR)); + } + + // Send a request after the client will have received an error. + CompletableFuture<StateResponse> late = new CompletableFuture<>(); + client.handle(StateRequest.newBuilder().setInstructionReference(SUCCESS), late); + + // NOTE(review): this re-checks `inflight`, so `late` is never asserted on. Presumably the + // intent was to verify that `late` also fails. Confirm the client fails requests issued after + // the stream has terminated before switching this to `late`, otherwise get() may block. + try { + inflight.get(); + fail("Expected unsuccessful response due to server error"); + } catch (ExecutionException e) { + assertThat(e.toString(), containsString(SERVER_ERROR)); + } + } + + @Test + public void testServerCompletionCausesPendingAndFutureCallsToFail() throws Exception { + BeamFnStateClient client = clientCache.forApiServiceDescriptor(apiServiceDescriptor); + + CompletableFuture<StateResponse> inflight = new CompletableFuture<>(); + client.handle(StateRequest.newBuilder().setInstructionReference(SUCCESS), inflight); + + // Wait for the client to connect. + StreamObserver<StateResponse> outboundServerObserver = outboundServerObservers.take(); + // Send that the server is done. + outboundServerObserver.onCompleted(); + + try { + inflight.get(); + fail("Expected unsuccessful response due to server completion"); + } catch (ExecutionException e) { + assertThat(e.toString(), containsString("Server hanged up")); + } + + // Send a request after the client will have observed the server completion. 
+ CompletableFuture<StateResponse> late = new CompletableFuture<>(); + client.handle(StateRequest.newBuilder().setInstructionReference(SUCCESS), late); + + // NOTE(review): this re-checks `inflight`, so `late` is never asserted on. Presumably the + // intent was to verify that `late` also fails. Confirm the client fails requests issued after + // the stream has terminated before switching this to `late`, otherwise get() may block. + try { + inflight.get(); + fail("Expected unsuccessful response due to server completion"); + } catch (ExecutionException e) { + assertThat(e.toString(), containsString("Server hanged up")); + } + } + + /** + * Replies on {@code outboundObserver} to {@code value}, using the same id. A {@code FAIL} + * request gets a response carrying {@link #TEST_ERROR}. Any other request gets a response with + * no error. + */ + private void handleServerRequest( + StreamObserver<StateResponse> outboundObserver, StateRequest value) { + switch (value.getInstructionReference()) { + case SUCCESS: + outboundObserver.onNext(StateResponse.newBuilder().setId(value.getId()).build()); + return; + case FAIL: + outboundObserver.onNext(StateResponse.newBuilder() + .setId(value.getId()) + .setError(TEST_ERROR) + .build()); + return; + default: + outboundObserver.onNext(StateResponse.newBuilder().setId(value.getId()).build()); + return; + } + } + + /** + * Stream observer factory handed to the cache. It opens the stream by applying + * {@code clientFactory} directly to {@code handler}, without any wrapping. + */ + private <ReqT, RespT> StreamObserver<RespT> createStreamForTest( + Function<StreamObserver<ReqT>, StreamObserver<RespT>> clientFactory, + StreamObserver<ReqT> handler) { + return clientFactory.apply(handler); + } +}