http://git-wip-us.apache.org/repos/asf/beam/blob/0b4b2bec/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClientTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClientTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClientTest.java new file mode 100644 index 0000000..20566ea --- /dev/null +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClientTest.java @@ -0,0 +1,309 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.beam.fn.harness.data; + +import static org.apache.beam.sdk.util.CoderUtils.encodeToByteArray; +import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.collection.IsEmptyCollection.empty; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.fail; + +import com.google.protobuf.ByteString; +import io.grpc.ManagedChannel; +import io.grpc.Server; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.stub.CallStreamObserver; +import io.grpc.stub.StreamObserver; +import java.util.Collection; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; +import org.apache.beam.fn.harness.fn.CloseableThrowingConsumer; +import org.apache.beam.fn.harness.fn.ThrowingConsumer; +import org.apache.beam.fn.harness.test.TestStreams; +import org.apache.beam.fn.v1.BeamFnApi; +import org.apache.beam.fn.v1.BeamFnDataGrpc; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.LengthPrefixCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.KV; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link BeamFnDataGrpcClient}. 
*/ +@RunWith(JUnit4.class) +public class BeamFnDataGrpcClientTest { + private static final Coder<WindowedValue<String>> CODER = + LengthPrefixCoder.of( + WindowedValue.getFullCoder(StringUtf8Coder.of(), + GlobalWindow.Coder.INSTANCE)); + private static final KV<Long, BeamFnApi.Target> KEY_A = KV.of( + 12L, + BeamFnApi.Target.newBuilder().setPrimitiveTransformReference(34L).setName("targetA").build()); + private static final KV<Long, BeamFnApi.Target> KEY_B = KV.of( + 56L, + BeamFnApi.Target.newBuilder().setPrimitiveTransformReference(78L).setName("targetB").build()); + + private static final BeamFnApi.Elements ELEMENTS_A_1; + private static final BeamFnApi.Elements ELEMENTS_A_2; + private static final BeamFnApi.Elements ELEMENTS_B_1; + static { + try { + ELEMENTS_A_1 = BeamFnApi.Elements.newBuilder() + .addData(BeamFnApi.Elements.Data.newBuilder() + .setInstructionReference(KEY_A.getKey()) + .setTarget(KEY_A.getValue()) + .setData(ByteString.copyFrom(encodeToByteArray(CODER, valueInGlobalWindow("ABC"))) + .concat(ByteString.copyFrom(encodeToByteArray(CODER, valueInGlobalWindow("DEF")))))) + .build(); + ELEMENTS_A_2 = BeamFnApi.Elements.newBuilder() + .addData(BeamFnApi.Elements.Data.newBuilder() + .setInstructionReference(KEY_A.getKey()) + .setTarget(KEY_A.getValue()) + .setData(ByteString.copyFrom(encodeToByteArray(CODER, valueInGlobalWindow("GHI"))))) + .addData(BeamFnApi.Elements.Data.newBuilder() + .setInstructionReference(KEY_A.getKey()) + .setTarget(KEY_A.getValue())) + .build(); + ELEMENTS_B_1 = BeamFnApi.Elements.newBuilder() + .addData(BeamFnApi.Elements.Data.newBuilder() + .setInstructionReference(KEY_B.getKey()) + .setTarget(KEY_B.getValue()) + .setData(ByteString.copyFrom(encodeToByteArray(CODER, valueInGlobalWindow("JKL"))) + .concat(ByteString.copyFrom(encodeToByteArray(CODER, valueInGlobalWindow("MNO")))))) + .addData(BeamFnApi.Elements.Data.newBuilder() + .setInstructionReference(KEY_B.getKey()) + .setTarget(KEY_B.getValue())) + .build(); + } catch 
(Exception e) { + throw new ExceptionInInitializerError(e); + } + } + + @Test + public void testForInboundConsumer() throws Exception { + CountDownLatch waitForClientToConnect = new CountDownLatch(1); + Collection<WindowedValue<String>> inboundValuesA = new ConcurrentLinkedQueue<>(); + Collection<WindowedValue<String>> inboundValuesB = new ConcurrentLinkedQueue<>(); + Collection<BeamFnApi.Elements> inboundServerValues = new ConcurrentLinkedQueue<>(); + AtomicReference<StreamObserver<BeamFnApi.Elements>> outboundServerObserver = + new AtomicReference<>(); + CallStreamObserver<BeamFnApi.Elements> inboundServerObserver = + TestStreams.withOnNext(inboundServerValues::add).build(); + + BeamFnApi.ApiServiceDescriptor apiServiceDescriptor = + BeamFnApi.ApiServiceDescriptor.newBuilder() + .setUrl(this.getClass().getName() + "-" + UUID.randomUUID().toString()) + .build(); + Server server = InProcessServerBuilder.forName(apiServiceDescriptor.getUrl()) + .addService(new BeamFnDataGrpc.BeamFnDataImplBase() { + @Override + public StreamObserver<BeamFnApi.Elements> data( + StreamObserver<BeamFnApi.Elements> outboundObserver) { + outboundServerObserver.set(outboundObserver); + waitForClientToConnect.countDown(); + return inboundServerObserver; + } + }) + .build(); + server.start(); + try { + ManagedChannel channel = + InProcessChannelBuilder.forName(apiServiceDescriptor.getUrl()).build(); + + BeamFnDataGrpcClient clientFactory = new BeamFnDataGrpcClient( + PipelineOptionsFactory.create(), + (BeamFnApi.ApiServiceDescriptor descriptor) -> channel, + this::createStreamForTest); + + CompletableFuture<Void> readFutureA = clientFactory.forInboundConsumer( + apiServiceDescriptor, + KEY_A, + CODER, + inboundValuesA::add); + + waitForClientToConnect.await(); + outboundServerObserver.get().onNext(ELEMENTS_A_1); + // Purposefully transmit some data before the consumer for B is bound showing that + // data is not lost + outboundServerObserver.get().onNext(ELEMENTS_B_1); + Thread.sleep(100); 
+ + CompletableFuture<Void> readFutureB = clientFactory.forInboundConsumer( + apiServiceDescriptor, + KEY_B, + CODER, + inboundValuesB::add); + + // Show that out of order stream completion can occur. + readFutureB.get(); + assertThat(inboundValuesB, contains( + valueInGlobalWindow("JKL"), valueInGlobalWindow("MNO"))); + + outboundServerObserver.get().onNext(ELEMENTS_A_2); + readFutureA.get(); + assertThat(inboundValuesA, contains( + valueInGlobalWindow("ABC"), valueInGlobalWindow("DEF"), valueInGlobalWindow("GHI"))); + } finally { + server.shutdownNow(); + } + } + + @Test + public void testForInboundConsumerThatThrows() throws Exception { + CountDownLatch waitForClientToConnect = new CountDownLatch(1); + AtomicInteger consumerInvoked = new AtomicInteger(); + Collection<BeamFnApi.Elements> inboundServerValues = new ConcurrentLinkedQueue<>(); + AtomicReference<StreamObserver<BeamFnApi.Elements>> outboundServerObserver = + new AtomicReference<>(); + CallStreamObserver<BeamFnApi.Elements> inboundServerObserver = + TestStreams.withOnNext(inboundServerValues::add).build(); + + BeamFnApi.ApiServiceDescriptor apiServiceDescriptor = + BeamFnApi.ApiServiceDescriptor.newBuilder() + .setUrl(this.getClass().getName() + "-" + UUID.randomUUID().toString()) + .build(); + Server server = InProcessServerBuilder.forName(apiServiceDescriptor.getUrl()) + .addService(new BeamFnDataGrpc.BeamFnDataImplBase() { + @Override + public StreamObserver<BeamFnApi.Elements> data( + StreamObserver<BeamFnApi.Elements> outboundObserver) { + outboundServerObserver.set(outboundObserver); + waitForClientToConnect.countDown(); + return inboundServerObserver; + } + }) + .build(); + server.start(); + RuntimeException exceptionToThrow = new RuntimeException("TestFailure"); + try { + ManagedChannel channel = + InProcessChannelBuilder.forName(apiServiceDescriptor.getUrl()).build(); + + BeamFnDataGrpcClient clientFactory = new BeamFnDataGrpcClient( + PipelineOptionsFactory.create(), + 
(BeamFnApi.ApiServiceDescriptor descriptor) -> channel, + this::createStreamForTest); + + CompletableFuture<Void> readFuture = clientFactory.forInboundConsumer( + apiServiceDescriptor, + KEY_A, + CODER, + new ThrowingConsumer<WindowedValue<String>>() { + @Override + public void accept(WindowedValue<String> t) throws Exception { + consumerInvoked.incrementAndGet(); + throw exceptionToThrow; + } + }); + + waitForClientToConnect.await(); + + // This first message should cause a failure afterwards all other messages are dropped. + outboundServerObserver.get().onNext(ELEMENTS_A_1); + outboundServerObserver.get().onNext(ELEMENTS_A_2); + + try { + readFuture.get(); + fail("Expected channel to fail"); + } catch (ExecutionException e) { + assertEquals(exceptionToThrow, e.getCause()); + } + // The server should not have received any values + assertThat(inboundServerValues, empty()); + // The consumer should have only been invoked once + assertEquals(1, consumerInvoked.get()); + } finally { + server.shutdownNow(); + } + } + + @Test + public void testForOutboundConsumer() throws Exception { + CountDownLatch waitForInboundServerValuesCompletion = new CountDownLatch(2); + Collection<BeamFnApi.Elements> inboundServerValues = new ConcurrentLinkedQueue<>(); + CallStreamObserver<BeamFnApi.Elements> inboundServerObserver = + TestStreams.withOnNext( + new Consumer<BeamFnApi.Elements>() { + @Override + public void accept(BeamFnApi.Elements t) { + inboundServerValues.add(t); + waitForInboundServerValuesCompletion.countDown(); + } + } + ).build(); + + BeamFnApi.ApiServiceDescriptor apiServiceDescriptor = + BeamFnApi.ApiServiceDescriptor.newBuilder() + .setUrl(this.getClass().getName() + "-" + UUID.randomUUID().toString()) + .build(); + Server server = InProcessServerBuilder.forName(apiServiceDescriptor.getUrl()) + .addService(new BeamFnDataGrpc.BeamFnDataImplBase() { + @Override + public StreamObserver<BeamFnApi.Elements> data( + StreamObserver<BeamFnApi.Elements> outboundObserver) { + 
return inboundServerObserver; + } + }) + .build(); + server.start(); + try { + ManagedChannel channel = + InProcessChannelBuilder.forName(apiServiceDescriptor.getUrl()).build(); + + BeamFnDataGrpcClient clientFactory = new BeamFnDataGrpcClient( + PipelineOptionsFactory.fromArgs( + new String[]{ "--experiments=beam_fn_api_data_buffer_limit=20" }).create(), + (BeamFnApi.ApiServiceDescriptor descriptor) -> channel, + this::createStreamForTest); + + try (CloseableThrowingConsumer<WindowedValue<String>> consumer = + clientFactory.forOutboundConsumer(apiServiceDescriptor, KEY_A, CODER)) { + consumer.accept(valueInGlobalWindow("ABC")); + consumer.accept(valueInGlobalWindow("DEF")); + consumer.accept(valueInGlobalWindow("GHI")); + } + + waitForInboundServerValuesCompletion.await(); + + assertThat(inboundServerValues, contains(ELEMENTS_A_1, ELEMENTS_A_2)); + } finally { + server.shutdownNow(); + } + } + + private <ReqT, RespT> StreamObserver<RespT> createStreamForTest( + Function<StreamObserver<ReqT>, StreamObserver<RespT>> clientFactory, + StreamObserver<ReqT> handler) { + return clientFactory.apply(handler); + } +}
http://git-wip-us.apache.org/repos/asf/beam/blob/0b4b2bec/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcMultiplexerTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcMultiplexerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcMultiplexerTest.java new file mode 100644 index 0000000..38d9e2c --- /dev/null +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcMultiplexerTest.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.fn.harness.data; + +import static org.hamcrest.Matchers.contains; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + +import com.google.common.util.concurrent.Uninterruptibles; +import com.google.protobuf.ByteString; +import io.grpc.stub.StreamObserver; +import java.util.ArrayList; +import java.util.Collection; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import org.apache.beam.fn.harness.test.TestStreams; +import org.apache.beam.fn.v1.BeamFnApi; +import org.apache.beam.sdk.values.KV; +import org.junit.Test; + +/** Tests for {@link BeamFnDataGrpcMultiplexer}. */ +public class BeamFnDataGrpcMultiplexerTest { + private static final BeamFnApi.ApiServiceDescriptor DESCRIPTOR = + BeamFnApi.ApiServiceDescriptor.newBuilder().setUrl("test").build(); + private static final KV<Long, BeamFnApi.Target> OUTPUT_LOCATION = KV.of(777L, + BeamFnApi.Target.newBuilder() + .setName("name") + .setPrimitiveTransformReference(888L) + .build()); + private static final BeamFnApi.Elements ELEMENTS = BeamFnApi.Elements.newBuilder() + .addData(BeamFnApi.Elements.Data.newBuilder() + .setInstructionReference(OUTPUT_LOCATION.getKey()) + .setTarget(OUTPUT_LOCATION.getValue()) + .setData(ByteString.copyFrom(new byte[1]))) + .build(); + private static final BeamFnApi.Elements TERMINAL_ELEMENTS = BeamFnApi.Elements.newBuilder() + .addData(BeamFnApi.Elements.Data.newBuilder() + .setInstructionReference(OUTPUT_LOCATION.getKey()) + .setTarget(OUTPUT_LOCATION.getValue())) + .build(); + + @Test + public void testOutboundObserver() { + Collection<BeamFnApi.Elements> values = new ArrayList<>(); + BeamFnDataGrpcMultiplexer multiplexer = new BeamFnDataGrpcMultiplexer( + DESCRIPTOR, + (StreamObserver<BeamFnApi.Elements> inboundObserver) + -> TestStreams.withOnNext(values::add).build()); + 
multiplexer.getOutboundObserver().onNext(ELEMENTS); + assertThat(values, contains(ELEMENTS)); + } + + @Test + public void testInboundObserverBlocksTillConsumerConnects() throws Exception { + Collection<BeamFnApi.Elements> outboundValues = new ArrayList<>(); + Collection<BeamFnApi.Elements.Data> inboundValues = new ArrayList<>(); + BeamFnDataGrpcMultiplexer multiplexer = new BeamFnDataGrpcMultiplexer( + DESCRIPTOR, + (StreamObserver<BeamFnApi.Elements> inboundObserver) + -> TestStreams.withOnNext(outboundValues::add).build()); + ExecutorService executorService = Executors.newCachedThreadPool(); + executorService.submit(new Runnable() { + @Override + public void run() { + // Purposefully sleep to simulate a delay in a consumer connecting. + Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS); + multiplexer.futureForKey(OUTPUT_LOCATION).complete(inboundValues::add); + } + }); + multiplexer.getInboundObserver().onNext(ELEMENTS); + assertTrue(multiplexer.consumers.containsKey(OUTPUT_LOCATION)); + // Ensure that when we see a terminal Elements object, we remove the consumer + multiplexer.getInboundObserver().onNext(TERMINAL_ELEMENTS); + assertFalse(multiplexer.consumers.containsKey(OUTPUT_LOCATION)); + + // Assert that normal and terminal Elements are passed to the consumer + assertThat(inboundValues, contains(ELEMENTS.getData(0), TERMINAL_ELEMENTS.getData(0))); + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/0b4b2bec/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataInboundObserverTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataInboundObserverTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataInboundObserverTest.java new file mode 100644 index 0000000..ff0e083 --- /dev/null +++ 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataInboundObserverTest.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.fn.harness.data; + +import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.instanceOf; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import com.google.protobuf.ByteString; +import java.util.ArrayList; +import java.util.Collection; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import org.apache.beam.fn.v1.BeamFnApi; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.Coder.Context; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.util.WindowedValue; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + 
+/** Tests for {@link BeamFnDataInboundObserver}. */ +@RunWith(JUnit4.class) +public class BeamFnDataInboundObserverTest { + private static final Coder<WindowedValue<String>> CODER = + WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE); + + @Test + public void testDecodingElements() throws Exception { + Collection<WindowedValue<String>> values = new ArrayList<>(); + CompletableFuture<Void> readFuture = new CompletableFuture<>(); + BeamFnDataInboundObserver<String> observer = new BeamFnDataInboundObserver<>( + CODER, + values::add, + readFuture); + + // Test decoding multiple messages + observer.accept(dataWith("ABC", "DEF", "GHI")); + assertThat(values, contains( + valueInGlobalWindow("ABC"), valueInGlobalWindow("DEF"), valueInGlobalWindow("GHI"))); + values.clear(); + + // Test empty message signaling end of stream + assertFalse(readFuture.isDone()); + observer.accept(dataWith()); + assertTrue(readFuture.isDone()); + + // Test messages after stream is finished are discarded + observer.accept(dataWith("ABC", "DEF", "GHI")); + assertThat(values, empty()); + } + + @Test + public void testConsumptionFailureCompletesReadFutureAndDiscardsMessages() throws Exception { + CompletableFuture<Void> readFuture = new CompletableFuture<>(); + BeamFnDataInboundObserver<String> observer = new BeamFnDataInboundObserver<>( + CODER, + this::throwOnDefValue, + readFuture); + + assertFalse(readFuture.isDone()); + observer.accept(dataWith("ABC", "DEF", "GHI")); + assertTrue(readFuture.isCompletedExceptionally()); + + try { + readFuture.get(); + fail("Expected failure"); + } catch (ExecutionException e) { + assertThat(e.getCause(), instanceOf(RuntimeException.class)); + assertEquals("Failure", e.getCause().getMessage()); + } + } + + private void throwOnDefValue(WindowedValue<String> value) { + if ("DEF".equals(value.getValue())) { + throw new RuntimeException("Failure"); + } + } + + private BeamFnApi.Elements.Data dataWith(String ... 
values) throws Exception { + BeamFnApi.Elements.Data.Builder builder = BeamFnApi.Elements.Data.newBuilder() + .setInstructionReference(777L) + .setTarget(BeamFnApi.Target.newBuilder() + .setPrimitiveTransformReference(999L) + .setName("Test")); + ByteString.Output output = ByteString.newOutput(); + for (String value : values) { + CODER.encode(valueInGlobalWindow(value), output, Context.NESTED); + } + builder.setData(output.toByteString()); + return builder.build(); + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/0b4b2bec/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/logging/BeamFnLoggingClientTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/logging/BeamFnLoggingClientTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/logging/BeamFnLoggingClientTest.java new file mode 100644 index 0000000..bb6a501 --- /dev/null +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/logging/BeamFnLoggingClientTest.java @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.beam.fn.harness.logging; + +import static com.google.common.base.Throwables.getStackTraceAsString; +import static org.hamcrest.Matchers.contains; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + +import com.google.protobuf.Timestamp; +import io.grpc.ManagedChannel; +import io.grpc.Server; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.stub.CallStreamObserver; +import io.grpc.stub.StreamObserver; +import java.util.Collection; +import java.util.UUID; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; +import java.util.logging.Level; +import java.util.logging.LogManager; +import java.util.logging.LogRecord; +import org.apache.beam.fn.harness.test.TestStreams; +import org.apache.beam.fn.v1.BeamFnApi; +import org.apache.beam.fn.v1.BeamFnLoggingGrpc; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link BeamFnLoggingClient}. 
*/ +@RunWith(JUnit4.class) +public class BeamFnLoggingClientTest { + + private static final LogRecord FILTERED_RECORD; + private static final LogRecord TEST_RECORD; + private static final LogRecord TEST_RECORD_WITH_EXCEPTION; + + static { + FILTERED_RECORD = new LogRecord(Level.SEVERE, "FilteredMessage"); + + TEST_RECORD = new LogRecord(Level.FINE, "Message"); + TEST_RECORD.setLoggerName("LoggerName"); + TEST_RECORD.setMillis(1234567890L); + TEST_RECORD.setThreadID(12345); + + TEST_RECORD_WITH_EXCEPTION = new LogRecord(Level.WARNING, "MessageWithException"); + TEST_RECORD_WITH_EXCEPTION.setLoggerName("LoggerName"); + TEST_RECORD_WITH_EXCEPTION.setMillis(1234567890L); + TEST_RECORD_WITH_EXCEPTION.setThreadID(12345); + TEST_RECORD_WITH_EXCEPTION.setThrown(new RuntimeException("ExceptionMessage")); + } + + private static final BeamFnApi.LogEntry TEST_ENTRY = + BeamFnApi.LogEntry.newBuilder() + .setSeverity(BeamFnApi.LogEntry.Severity.DEBUG) + .setMessage("Message") + .setThread("12345") + .setTimestamp(Timestamp.newBuilder().setSeconds(1234567).setNanos(890000000).build()) + .setLogLocation("LoggerName") + .build(); + private static final BeamFnApi.LogEntry TEST_ENTRY_WITH_EXCEPTION = + BeamFnApi.LogEntry.newBuilder() + .setSeverity(BeamFnApi.LogEntry.Severity.WARN) + .setMessage("MessageWithException") + .setTrace(getStackTraceAsString(TEST_RECORD_WITH_EXCEPTION.getThrown())) + .setThread("12345") + .setTimestamp(Timestamp.newBuilder().setSeconds(1234567).setNanos(890000000).build()) + .setLogLocation("LoggerName") + .build(); + + @Test + public void testLogging() throws Exception { + AtomicBoolean clientClosedStream = new AtomicBoolean(); + Collection<BeamFnApi.LogEntry> values = new ConcurrentLinkedQueue<>(); + AtomicReference<StreamObserver<BeamFnApi.LogControl>> outboundServerObserver = + new AtomicReference<>(); + CallStreamObserver<BeamFnApi.LogEntry.List> inboundServerObserver = + TestStreams.withOnNext( + (BeamFnApi.LogEntry.List logEntries) -> 
values.addAll(logEntries.getLogEntriesList())) + .withOnCompleted(new Runnable() { + @Override + public void run() { + // Remember that the client told us that this stream completed + clientClosedStream.set(true); + outboundServerObserver.get().onCompleted(); + } + }).build(); + + BeamFnApi.ApiServiceDescriptor apiServiceDescriptor = + BeamFnApi.ApiServiceDescriptor.newBuilder() + .setUrl(this.getClass().getName() + "-" + UUID.randomUUID().toString()) + .build(); + Server server = InProcessServerBuilder.forName(apiServiceDescriptor.getUrl()) + .addService(new BeamFnLoggingGrpc.BeamFnLoggingImplBase() { + @Override + public StreamObserver<BeamFnApi.LogEntry.List> logging( + StreamObserver<BeamFnApi.LogControl> outboundObserver) { + outboundServerObserver.set(outboundObserver); + return inboundServerObserver; + } + }) + .build(); + server.start(); + try { + ManagedChannel channel = + InProcessChannelBuilder.forName(apiServiceDescriptor.getUrl()).build(); + + BeamFnLoggingClient client = new BeamFnLoggingClient( + PipelineOptionsFactory.fromArgs(new String[] { + "--defaultWorkerLogLevel=OFF", + "--workerLogLevelOverrides={\"ConfiguredLogger\": \"DEBUG\"}" + }).create(), + apiServiceDescriptor, + (BeamFnApi.ApiServiceDescriptor descriptor) -> channel, + this::createStreamForTest); + + // Ensure that log levels were correctly set. 
+ assertEquals(Level.OFF, + LogManager.getLogManager().getLogger("").getLevel()); + assertEquals(Level.FINE, + LogManager.getLogManager().getLogger("ConfiguredLogger").getLevel()); + + // Should be filtered because the default log level override is OFF + LogManager.getLogManager().getLogger("").log(FILTERED_RECORD); + // Should not be filtered because the default log level override for ConfiguredLogger is DEBUG + LogManager.getLogManager().getLogger("ConfiguredLogger").log(TEST_RECORD); + LogManager.getLogManager().getLogger("ConfiguredLogger").log(TEST_RECORD_WITH_EXCEPTION); + client.close(); + + // Verify that after close, log levels are reset. + assertEquals(Level.INFO, LogManager.getLogManager().getLogger("").getLevel()); + assertNull(LogManager.getLogManager().getLogger("ConfiguredLogger").getLevel()); + + assertTrue(clientClosedStream.get()); + assertTrue(channel.isShutdown()); + assertThat(values, contains(TEST_ENTRY, TEST_ENTRY_WITH_EXCEPTION)); + } finally { + server.shutdownNow(); + } + } + + private <ReqT, RespT> StreamObserver<RespT> createStreamForTest( + Function<StreamObserver<ReqT>, StreamObserver<RespT>> clientFactory, + StreamObserver<ReqT> handler) { + return clientFactory.apply(handler); + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/0b4b2bec/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/stream/AdvancingPhaserTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/stream/AdvancingPhaserTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/stream/AdvancingPhaserTest.java new file mode 100644 index 0000000..3dd1b42 --- /dev/null +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/stream/AdvancingPhaserTest.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.fn.harness.stream; + +import static org.hamcrest.collection.IsEmptyCollection.empty; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link AdvancingPhaser}. 
*/ +@RunWith(JUnit4.class) +public class AdvancingPhaserTest { + @Test + public void testAdvancement() throws Exception { + AdvancingPhaser phaser = new AdvancingPhaser(1); + int currentPhase = phaser.getPhase(); + ExecutorService service = Executors.newSingleThreadExecutor(); + service.submit(phaser::arrive); + phaser.awaitAdvance(currentPhase); + assertFalse(phaser.isTerminated()); + service.shutdown(); + if (!service.awaitTermination(10, TimeUnit.SECONDS)) { + assertThat(service.shutdownNow(), empty()); + } + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/0b4b2bec/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/stream/BufferingStreamObserverTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/stream/BufferingStreamObserverTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/stream/BufferingStreamObserverTest.java new file mode 100644 index 0000000..76b7ef0 --- /dev/null +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/stream/BufferingStreamObserverTest.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.beam.fn.harness.stream; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.Uninterruptibles; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Consumer; +import org.apache.beam.fn.harness.test.TestExecutors; +import org.apache.beam.fn.harness.test.TestExecutors.TestExecutorService; +import org.apache.beam.fn.harness.test.TestStreams; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link BufferingStreamObserver}. */ +@RunWith(JUnit4.class) +public class BufferingStreamObserverTest { + @Rule public TestExecutorService executor = TestExecutors.from(Executors::newCachedThreadPool); + + @Test + public void testThreadSafety() throws Exception { + List<String> onNextValues = new ArrayList<>(); + AdvancingPhaser phaser = new AdvancingPhaser(1); + final AtomicBoolean isCriticalSectionShared = new AtomicBoolean(); + final BufferingStreamObserver<String> streamObserver = + new BufferingStreamObserver<>( + phaser, + TestStreams.withOnNext( + new Consumer<String>() { + @Override + public void accept(String t) { + // Use the atomic boolean to detect if multiple threads are in this + // critical section. Any thread that enters purposefully blocks by sleeping + // to increase the contention between threads artificially. 
+ assertFalse(isCriticalSectionShared.getAndSet(true)); + Uninterruptibles.sleepUninterruptibly(50, TimeUnit.MILLISECONDS); + onNextValues.add(t); + assertTrue(isCriticalSectionShared.getAndSet(false)); + } + }).build(), + executor, + 3); + + List<String> prefixes = ImmutableList.of("0", "1", "2", "3", "4"); + List<Callable<String>> tasks = new ArrayList<>(); + for (String prefix : prefixes) { + tasks.add( + new Callable<String>() { + @Override + public String call() throws Exception { + for (int i = 0; i < 10; i++) { + streamObserver.onNext(prefix + i); + } + return prefix; + } + }); + } + List<Future<String>> results = executor.invokeAll(tasks); + for (Future<String> result : results) { + result.get(); + } + streamObserver.onCompleted(); + + // Check that order was maintained. + int[] prefixesIndex = new int[prefixes.size()]; + assertEquals(50, onNextValues.size()); + for (String onNextValue : onNextValues) { + int prefix = Integer.parseInt(onNextValue.substring(0, 1)); + int suffix = Integer.parseInt(onNextValue.substring(1, 2)); + assertEquals(prefixesIndex[prefix], suffix); + prefixesIndex[prefix] += 1; + } + } + + @Test + public void testIsReadyIsHonored() throws Exception { + AdvancingPhaser phaser = new AdvancingPhaser(1); + final AtomicBoolean elementsAllowed = new AtomicBoolean(); + final BufferingStreamObserver<String> streamObserver = + new BufferingStreamObserver<>( + phaser, + TestStreams.withOnNext( + new Consumer<String>() { + @Override + public void accept(String t) { + assertTrue(elementsAllowed.get()); + } + }).withIsReady(elementsAllowed::get).build(), + executor, + 3); + + // Start all the tasks + List<Future<String>> results = new ArrayList<>(); + for (String prefix : ImmutableList.of("0", "1", "2", "3", "4")) { + results.add( + executor.submit( + new Callable<String>() { + @Override + public String call() throws Exception { + for (int i = 0; i < 10; i++) { + streamObserver.onNext(prefix + i); + } + return prefix; + } + })); + } + + // Have 
them wait and then flip that we do allow elements and wake up those awaiting + Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS); + elementsAllowed.set(true); + phaser.arrive(); + + for (Future<String> result : results) { + result.get(); + } + streamObserver.onCompleted(); + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/0b4b2bec/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/stream/DirectStreamObserverTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/stream/DirectStreamObserverTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/stream/DirectStreamObserverTest.java new file mode 100644 index 0000000..b5d3ec1 --- /dev/null +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/stream/DirectStreamObserverTest.java @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.beam.fn.harness.stream; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.Uninterruptibles; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Consumer; +import org.apache.beam.fn.harness.test.TestExecutors; +import org.apache.beam.fn.harness.test.TestExecutors.TestExecutorService; +import org.apache.beam.fn.harness.test.TestStreams; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link DirectStreamObserver}. */ +@RunWith(JUnit4.class) +public class DirectStreamObserverTest { + @Rule public TestExecutorService executor = TestExecutors.from(Executors::newCachedThreadPool); + + @Test + public void testThreadSafety() throws Exception { + List<String> onNextValues = new ArrayList<>(); + AdvancingPhaser phaser = new AdvancingPhaser(1); + final AtomicBoolean isCriticalSectionShared = new AtomicBoolean(); + final DirectStreamObserver<String> streamObserver = + new DirectStreamObserver<>( + phaser, + TestStreams.withOnNext( + new Consumer<String>() { + @Override + public void accept(String t) { + // Use the atomic boolean to detect if multiple threads are in this + // critical section. Any thread that enters purposefully blocks by sleeping + // to increase the contention between threads artificially. 
+ assertFalse(isCriticalSectionShared.getAndSet(true)); + Uninterruptibles.sleepUninterruptibly(50, TimeUnit.MILLISECONDS); + onNextValues.add(t); + assertTrue(isCriticalSectionShared.getAndSet(false)); + } + }).build()); + + List<String> prefixes = ImmutableList.of("0", "1", "2", "3", "4"); + List<Callable<String>> tasks = new ArrayList<>(); + for (String prefix : prefixes) { + tasks.add( + new Callable<String>() { + @Override + public String call() throws Exception { + for (int i = 0; i < 10; i++) { + streamObserver.onNext(prefix + i); + } + return prefix; + } + }); + } + executor.invokeAll(tasks); + streamObserver.onCompleted(); + + // Check that order was maintained. + int[] prefixesIndex = new int[prefixes.size()]; + assertEquals(50, onNextValues.size()); + for (String onNextValue : onNextValues) { + int prefix = Integer.parseInt(onNextValue.substring(0, 1)); + int suffix = Integer.parseInt(onNextValue.substring(1, 2)); + assertEquals(prefixesIndex[prefix], suffix); + prefixesIndex[prefix] += 1; + } + } + + @Test + public void testIsReadyIsHonored() throws Exception { + AdvancingPhaser phaser = new AdvancingPhaser(1); + final AtomicBoolean elementsAllowed = new AtomicBoolean(); + final DirectStreamObserver<String> streamObserver = + new DirectStreamObserver<>( + phaser, + TestStreams.withOnNext( + new Consumer<String>() { + @Override + public void accept(String t) { + assertTrue(elementsAllowed.get()); + } + }).withIsReady(elementsAllowed::get).build()); + + // Start all the tasks + List<Future<String>> results = new ArrayList<>(); + for (String prefix : ImmutableList.of("0", "1", "2", "3", "4")) { + results.add( + executor.submit( + new Callable<String>() { + @Override + public String call() throws Exception { + for (int i = 0; i < 10; i++) { + streamObserver.onNext(prefix + i); + } + return prefix; + } + })); + } + + // Have them wait and then flip that we do allow elements and wake up those awaiting + Uninterruptibles.sleepUninterruptibly(100, 
TimeUnit.MILLISECONDS); + elementsAllowed.set(true); + phaser.arrive(); + + for (Future<String> result : results) { + result.get(); + } + streamObserver.onCompleted(); + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/0b4b2bec/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/stream/ForwardingClientResponseObserverTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/stream/ForwardingClientResponseObserverTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/stream/ForwardingClientResponseObserverTest.java new file mode 100644 index 0000000..598d7f3 --- /dev/null +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/stream/ForwardingClientResponseObserverTest.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.beam.fn.harness.stream; + +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +import io.grpc.stub.ClientCallStreamObserver; +import io.grpc.stub.ClientResponseObserver; +import io.grpc.stub.StreamObserver; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mockito; + +/** Tests for {@link ForwardingClientResponseObserver}. */ +@RunWith(JUnit4.class) +public class ForwardingClientResponseObserverTest { + @Test + public void testCallsAreForwardedAndOnReadyHandlerBound() { + @SuppressWarnings("unchecked") + StreamObserver<Object> delegateObserver = Mockito.mock(StreamObserver.class); + @SuppressWarnings("unchecked") + ClientCallStreamObserver<Object> callStreamObserver = + Mockito.mock(ClientCallStreamObserver.class); + Runnable onReadyHandler = new Runnable() { + @Override + public void run() { + } + }; + ClientResponseObserver<Object, Object> observer = + new ForwardingClientResponseObserver<>(delegateObserver, onReadyHandler); + observer.onNext("A"); + verify(delegateObserver).onNext("A"); + Throwable t = new RuntimeException(); + observer.onError(t); + verify(delegateObserver).onError(t); + observer.onCompleted(); + verify(delegateObserver).onCompleted(); + observer.beforeStart(callStreamObserver); + verify(callStreamObserver).setOnReadyHandler(onReadyHandler); + verifyNoMoreInteractions(delegateObserver, callStreamObserver); + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/0b4b2bec/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/stream/StreamObserverFactoryTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/stream/StreamObserverFactoryTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/stream/StreamObserverFactoryTest.java new file mode 100644 index 0000000..9331079 --- /dev/null +++ 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/stream/StreamObserverFactoryTest.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.fn.harness.stream; + +import static org.hamcrest.Matchers.instanceOf; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +import io.grpc.stub.CallStreamObserver; +import io.grpc.stub.StreamObserver; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +/** Tests for {@link StreamObserverFactory}. 
*/ +@RunWith(JUnit4.class) +public class StreamObserverFactoryTest { + @Mock private StreamObserver<Integer> mockRequestObserver; + @Mock private CallStreamObserver<String> mockResponseObserver; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + } + + @Test + public void testDefaultInstantiation() { + StreamObserver<String> observer = + StreamObserverFactory.fromOptions(PipelineOptionsFactory.create()) + .from(this::fakeFactory, mockRequestObserver); + assertThat(observer, instanceOf(DirectStreamObserver.class)); + } + + @Test + public void testBufferedStreamInstantiation() { + StreamObserver<String> observer = + StreamObserverFactory.fromOptions( + PipelineOptionsFactory.fromArgs( + new String[] {"--experiments=beam_fn_api_buffered_stream"}) + .create()) + .from(this::fakeFactory, mockRequestObserver); + assertThat(observer, instanceOf(BufferingStreamObserver.class)); + } + + @Test + public void testBufferedStreamWithLimitInstantiation() { + StreamObserver<String> observer = + StreamObserverFactory.fromOptions( + PipelineOptionsFactory.fromArgs( + new String[] { + "--experiments=beam_fn_api_buffered_stream," + + "beam_fn_api_buffered_stream_buffer_size=1" + }) + .create()) + .from(this::fakeFactory, mockRequestObserver); + assertThat(observer, instanceOf(BufferingStreamObserver.class)); + assertEquals(1, ((BufferingStreamObserver<String>) observer).getBufferSize()); + } + + private StreamObserver<String> fakeFactory(StreamObserver<Integer> observer) { + assertThat(observer, instanceOf(ForwardingClientResponseObserver.class)); + return mockResponseObserver; + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/0b4b2bec/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/test/TestExecutors.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/test/TestExecutors.java 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/test/TestExecutors.java new file mode 100644 index 0000000..f846466 --- /dev/null +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/test/TestExecutors.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.fn.harness.test; + +import com.google.common.util.concurrent.ForwardingExecutorService; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; +import org.junit.rules.TestRule; +import org.junit.runner.Description; +import org.junit.runners.model.Statement; + +/** + * A {@link TestRule} that validates that all submitted tasks finished and were completed. This + * allows for testing that tasks have exercised the appropriate shutdown logic. + */ +public class TestExecutors { + public static TestExecutorService from(Supplier<ExecutorService> executorServiceSuppler) { + return new FromSupplier(executorServiceSuppler); + } + + /** A union of the {@link ExecutorService} and {@link TestRule} interfaces. 
*/ + public interface TestExecutorService extends ExecutorService, TestRule {} + + private static class FromSupplier extends ForwardingExecutorService + implements TestExecutorService { + private final Supplier<ExecutorService> executorServiceSupplier; + private ExecutorService delegate; + + private FromSupplier(Supplier<ExecutorService> executorServiceSupplier) { + this.executorServiceSupplier = executorServiceSupplier; + } + + @Override + public Statement apply(Statement statement, Description arg1) { + return new Statement() { + @Override + public void evaluate() throws Throwable { + Throwable thrown = null; + delegate = executorServiceSupplier.get(); + try { + statement.evaluate(); + } catch (Throwable t) { + thrown = t; + } + shutdown(); + if (!awaitTermination(5, TimeUnit.SECONDS)) { + shutdownNow(); + IllegalStateException e = + new IllegalStateException("Test executor failed to shutdown cleanly."); + if (thrown != null) { + thrown.addSuppressed(e); + } else { + thrown = e; + } + } + if (thrown != null) { + throw thrown; + } + } + }; + } + + @Override + protected ExecutorService delegate() { + return delegate; + } + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/0b4b2bec/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/test/TestExecutorsTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/test/TestExecutorsTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/test/TestExecutorsTest.java new file mode 100644 index 0000000..85c64d0 --- /dev/null +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/test/TestExecutorsTest.java @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.fn.harness.test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertSame;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import org.apache.beam.fn.harness.test.TestExecutors.TestExecutorService;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.junit.runners.model.Statement;
+
+/** Tests for {@link TestExecutors}.
*/ +@RunWith(JUnit4.class) +public class TestExecutorsTest { + @Test + public void testSuccessfulTermination() throws Throwable { + ExecutorService service = Executors.newSingleThreadExecutor(); + final TestExecutorService testService = TestExecutors.from(() -> service); + final AtomicBoolean taskRan = new AtomicBoolean(); + testService + .apply( + new Statement() { + @Override + public void evaluate() throws Throwable { + testService.submit(() -> taskRan.set(true)); + } + }, + null) + .evaluate(); + assertTrue(service.isTerminated()); + assertTrue(taskRan.get()); + } + + @Test + public void testTaskBlocksForeverCausesFailure() throws Throwable { + ExecutorService service = Executors.newSingleThreadExecutor(); + final TestExecutorService testService = TestExecutors.from(() -> service); + final AtomicBoolean taskStarted = new AtomicBoolean(); + final AtomicBoolean taskWasInterrupted = new AtomicBoolean(); + try { + testService + .apply( + new Statement() { + @Override + public void evaluate() throws Throwable { + testService.submit(this::taskToRun); + } + + private void taskToRun() { + taskStarted.set(true); + try { + while (true) { + Thread.sleep(10000); + } + } catch (InterruptedException e) { + taskWasInterrupted.set(true); + return; + } + } + }, + null) + .evaluate(); + fail(); + } catch (IllegalStateException e) { + assertEquals(IllegalStateException.class, e.getClass()); + assertEquals("Test executor failed to shutdown cleanly.", e.getMessage()); + } + assertTrue(service.isShutdown()); + } + + @Test + public void testStatementFailurePropagatedCleanly() throws Throwable { + ExecutorService service = Executors.newSingleThreadExecutor(); + final TestExecutorService testService = TestExecutors.from(() -> service); + final RuntimeException exceptionToThrow = new RuntimeException(); + try { + testService + .apply( + new Statement() { + @Override + public void evaluate() throws Throwable { + throw exceptionToThrow; + } + }, + null) + .evaluate(); + fail(); + } catch 
(RuntimeException thrownException) { + assertSame(exceptionToThrow, thrownException); + } + assertTrue(service.isShutdown()); + } + + @Test + public void testStatementFailurePropagatedWhenExecutorServiceFailingToTerminate() + throws Throwable { + ExecutorService service = Executors.newSingleThreadExecutor(); + final TestExecutorService testService = TestExecutors.from(() -> service); + final AtomicBoolean taskStarted = new AtomicBoolean(); + final AtomicBoolean taskWasInterrupted = new AtomicBoolean(); + final RuntimeException exceptionToThrow = new RuntimeException(); + try { + testService + .apply( + new Statement() { + @Override + public void evaluate() throws Throwable { + testService.submit(this::taskToRun); + throw exceptionToThrow; + } + + private void taskToRun() { + taskStarted.set(true); + try { + while (true) { + Thread.sleep(10000); + } + } catch (InterruptedException e) { + taskWasInterrupted.set(true); + return; + } + } + }, + null) + .evaluate(); + fail(); + } catch (RuntimeException thrownException) { + assertSame(exceptionToThrow, thrownException); + assertEquals(1, exceptionToThrow.getSuppressed().length); + assertEquals(IllegalStateException.class, exceptionToThrow.getSuppressed()[0].getClass()); + assertEquals( + "Test executor failed to shutdown cleanly.", + exceptionToThrow.getSuppressed()[0].getMessage()); + } + assertTrue(service.isShutdown()); + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/0b4b2bec/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/test/TestStreams.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/test/TestStreams.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/test/TestStreams.java new file mode 100644 index 0000000..f398286 --- /dev/null +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/test/TestStreams.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation 
(ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.fn.harness.test; + +import io.grpc.stub.CallStreamObserver; +import io.grpc.stub.StreamObserver; +import java.util.function.Consumer; +import java.util.function.Supplier; + +/** Utility methods which enable testing of {@link StreamObserver}s. */ +public class TestStreams { + /** + * Creates a test {@link CallStreamObserver} {@link Builder} that forwards + * {@link StreamObserver#onNext} calls to the supplied {@link Consumer}. + */ + public static <T> Builder<T> withOnNext(Consumer<T> onNext) { + return new Builder<>(new ForwardingCallStreamObserver<>( + onNext, + TestStreams::noop, + TestStreams::noop, + TestStreams::returnTrue)); + } + + /** A builder for a test {@link CallStreamObserver} that performs various callbacks. */ + public static class Builder<T> { + private final ForwardingCallStreamObserver<T> observer; + private Builder(ForwardingCallStreamObserver<T> observer) { + this.observer = observer; + } + + /** + * Returns a new {@link Builder} like this one with the specified + * {@link CallStreamObserver#isReady} callback. 
+ */ + public Builder<T> withIsReady(Supplier<Boolean> isReady) { + return new Builder<>(new ForwardingCallStreamObserver<>( + observer.onNext, + observer.onError, + observer.onCompleted, + isReady)); + } + + /** + * Returns a new {@link Builder} like this one with the specified + * {@link StreamObserver#onCompleted} callback. + */ + public Builder<T> withOnCompleted(Runnable onCompleted) { + return new Builder<>(new ForwardingCallStreamObserver<>( + observer.onNext, + observer.onError, + onCompleted, + observer.isReady)); + } + + /** + * Returns a new {@link Builder} like this one with the specified + * {@link StreamObserver#onError} callback. + */ + public Builder<T> withOnError(Runnable onError) { + return new Builder<>(new ForwardingCallStreamObserver<>( + observer.onNext, + new Consumer<Throwable>() { + @Override + public void accept(Throwable t) { + onError.run(); + } + }, + observer.onCompleted, + observer.isReady)); + } + + /** + * Returns a new {@link Builder} like this one with the specified + * {@link StreamObserver#onError} consumer. + */ + public Builder<T> withOnError(Consumer<Throwable> onError) { + return new Builder<>(new ForwardingCallStreamObserver<>( + observer.onNext, onError, observer.onCompleted, observer.isReady)); + } + + public CallStreamObserver<T> build() { + return observer; + } + } + + private static void noop() { + } + + private static void noop(Throwable t) { + } + + private static boolean returnTrue() { + return true; + } + + /** A {@link CallStreamObserver} which executes the supplied callbacks. 
*/ + private static class ForwardingCallStreamObserver<T> extends CallStreamObserver<T> { + private final Consumer<T> onNext; + private final Supplier<Boolean> isReady; + private final Consumer<Throwable> onError; + private final Runnable onCompleted; + + public ForwardingCallStreamObserver( + Consumer<T> onNext, + Consumer<Throwable> onError, + Runnable onCompleted, + Supplier<Boolean> isReady) { + this.onNext = onNext; + this.onError = onError; + this.onCompleted = onCompleted; + this.isReady = isReady; + } + + @Override + public void onNext(T value) { + onNext.accept(value); + } + + @Override + public void onError(Throwable t) { + onError.accept(t); + } + + @Override + public void onCompleted() { + onCompleted.run(); + } + + @Override + public boolean isReady() { + return isReady.get(); + } + + @Override + public void setOnReadyHandler(Runnable onReadyHandler) {} + + @Override + public void disableAutoInboundFlowControl() {} + + @Override + public void request(int count) {} + + @Override + public void setMessageCompression(boolean enable) {} + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/0b4b2bec/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/test/TestStreamsTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/test/TestStreamsTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/test/TestStreamsTest.java new file mode 100644 index 0000000..b684c90 --- /dev/null +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/test/TestStreamsTest.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.fn.harness.test; + +import static org.hamcrest.Matchers.contains; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.concurrent.atomic.AtomicBoolean; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link TestStreams}. 
*/ +@RunWith(JUnit4.class) +public class TestStreamsTest { + @Test + public void testOnNextIsCalled() { + AtomicBoolean onNextWasCalled = new AtomicBoolean(); + TestStreams.withOnNext(onNextWasCalled::set).build().onNext(true); + assertTrue(onNextWasCalled.get()); + } + + @Test + public void testIsReadyIsCalled() { + final AtomicBoolean isReadyWasCalled = new AtomicBoolean(); + assertFalse(TestStreams.withOnNext(null) + .withIsReady(() -> isReadyWasCalled.getAndSet(true)) + .build() + .isReady()); + assertTrue(isReadyWasCalled.get()); + } + + @Test + public void testOnCompletedIsCalled() { + AtomicBoolean onCompletedWasCalled = new AtomicBoolean(); + TestStreams.withOnNext(null) + .withOnCompleted(() -> onCompletedWasCalled.set(true)) + .build() + .onCompleted(); + assertTrue(onCompletedWasCalled.get()); + } + + @Test + public void testOnErrorRunnableIsCalled() { + RuntimeException throwable = new RuntimeException(); + AtomicBoolean onErrorWasCalled = new AtomicBoolean(); + TestStreams.withOnNext(null) + .withOnError(() -> onErrorWasCalled.set(true)) + .build() + .onError(throwable); + assertTrue(onErrorWasCalled.get()); + } + + @Test + public void testOnErrorConsumerIsCalled() { + RuntimeException throwable = new RuntimeException(); + Collection<Throwable> onErrorWasCalled = new ArrayList<>(); + TestStreams.withOnNext(null) + .withOnError(onErrorWasCalled::add) + .build() + .onError(throwable); + assertThat(onErrorWasCalled, contains(throwable)); + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/0b4b2bec/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataReadRunnerTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataReadRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataReadRunnerTest.java new file mode 100644 index 0000000..511cc3f --- /dev/null +++ 
b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataReadRunnerTest.java @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.core; + +import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow; +import static org.hamcrest.Matchers.contains; +import static org.junit.Assert.assertThat; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.util.concurrent.Uninterruptibles; +import com.google.protobuf.Any; +import com.google.protobuf.ByteString; +import com.google.protobuf.BytesValue; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; 
+import org.apache.beam.fn.harness.data.BeamFnDataClient; +import org.apache.beam.fn.harness.fn.ThrowingConsumer; +import org.apache.beam.fn.harness.test.TestExecutors; +import org.apache.beam.fn.harness.test.TestExecutors.TestExecutorService; +import org.apache.beam.fn.v1.BeamFnApi; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.KV; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +/** Tests for {@link BeamFnDataReadRunner}. */ +@RunWith(JUnit4.class) +public class BeamFnDataReadRunnerTest { + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + private static final BeamFnApi.RemoteGrpcPort PORT_SPEC = BeamFnApi.RemoteGrpcPort.newBuilder() + .setApiServiceDescriptor(BeamFnApi.ApiServiceDescriptor.getDefaultInstance()).build(); + private static final BeamFnApi.FunctionSpec FUNCTION_SPEC = BeamFnApi.FunctionSpec.newBuilder() + .setData(Any.pack(PORT_SPEC)).build(); + private static final Coder<WindowedValue<String>> CODER = + WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE); + private static final BeamFnApi.Coder CODER_SPEC; + static { + try { + CODER_SPEC = BeamFnApi.Coder.newBuilder().setFunctionSpec(BeamFnApi.FunctionSpec.newBuilder() + .setData(Any.pack(BytesValue.newBuilder().setValue(ByteString.copyFrom( + OBJECT_MAPPER.writeValueAsBytes(CODER.asCloudObject()))).build()))) + .build(); + } catch (IOException e) { + throw new ExceptionInInitializerError(e); + } + } + private static final BeamFnApi.Target INPUT_TARGET = BeamFnApi.Target.newBuilder() + .setPrimitiveTransformReference(1) + .setName("out") + 
.build(); + + @Rule public TestExecutorService executor = TestExecutors.from(Executors::newCachedThreadPool); + @Mock private BeamFnDataClient mockBeamFnDataClientFactory; + @Captor private ArgumentCaptor<ThrowingConsumer<WindowedValue<String>>> consumerCaptor; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + } + + @Test + public void testReuseForMultipleBundles() throws Exception { + CompletableFuture<Void> bundle1Future = new CompletableFuture<>(); + CompletableFuture<Void> bundle2Future = new CompletableFuture<>(); + when(mockBeamFnDataClientFactory.forInboundConsumer( + any(), + any(), + any(), + any())).thenReturn(bundle1Future).thenReturn(bundle2Future); + List<WindowedValue<String>> valuesA = new ArrayList<>(); + List<WindowedValue<String>> valuesB = new ArrayList<>(); + Map<String, Collection<ThrowingConsumer<WindowedValue<String>>>> outputMap = ImmutableMap.of( + "outA", ImmutableList.of(valuesA::add), + "outB", ImmutableList.of(valuesB::add)); + AtomicLong bundleId = new AtomicLong(0); + BeamFnDataReadRunner<String> readRunner = new BeamFnDataReadRunner<>( + FUNCTION_SPEC, + bundleId::get, + INPUT_TARGET, + CODER_SPEC, + mockBeamFnDataClientFactory, + outputMap); + + // Process for bundle id 0 + readRunner.registerInputLocation(); + + verify(mockBeamFnDataClientFactory).forInboundConsumer( + eq(PORT_SPEC.getApiServiceDescriptor()), + eq(KV.of(bundleId.get(), INPUT_TARGET)), + eq(CODER), + consumerCaptor.capture()); + + executor.submit(new Runnable() { + @Override + public void run() { + // Sleep for some small amount of time simulating the parent blocking + Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS); + try { + consumerCaptor.getValue().accept(valueInGlobalWindow("ABC")); + consumerCaptor.getValue().accept(valueInGlobalWindow("DEF")); + } catch (Exception e) { + bundle1Future.completeExceptionally(e); + } finally { + bundle1Future.complete(null); + } + } + }); + + readRunner.blockTillReadFinishes(); + 
assertThat(valuesA, contains(valueInGlobalWindow("ABC"), valueInGlobalWindow("DEF"))); + assertThat(valuesB, contains(valueInGlobalWindow("ABC"), valueInGlobalWindow("DEF"))); + + // Process for bundle id 1 + bundleId.incrementAndGet(); + valuesA.clear(); + valuesB.clear(); + readRunner.registerInputLocation(); + + verify(mockBeamFnDataClientFactory).forInboundConsumer( + eq(PORT_SPEC.getApiServiceDescriptor()), + eq(KV.of(bundleId.get(), INPUT_TARGET)), + eq(CODER), + consumerCaptor.capture()); + + executor.submit(new Runnable() { + @Override + public void run() { + // Sleep for some small amount of time simulating the parent blocking + Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS); + try { + consumerCaptor.getValue().accept(valueInGlobalWindow("GHI")); + consumerCaptor.getValue().accept(valueInGlobalWindow("JKL")); + } catch (Exception e) { + bundle2Future.completeExceptionally(e); + } finally { + bundle2Future.complete(null); + } + } + }); + + readRunner.blockTillReadFinishes(); + assertThat(valuesA, contains(valueInGlobalWindow("GHI"), valueInGlobalWindow("JKL"))); + assertThat(valuesB, contains(valueInGlobalWindow("GHI"), valueInGlobalWindow("JKL"))); + + verifyNoMoreInteractions(mockBeamFnDataClientFactory); + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/0b4b2bec/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataWriteRunnerTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataWriteRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataWriteRunnerTest.java new file mode 100644 index 0000000..ed67b14 --- /dev/null +++ b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataWriteRunnerTest.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.core; + +import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow; +import static org.hamcrest.Matchers.contains; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.protobuf.Any; +import com.google.protobuf.ByteString; +import com.google.protobuf.BytesValue; +import java.io.IOException; +import java.util.ArrayList; +import java.util.concurrent.atomic.AtomicLong; +import org.apache.beam.fn.harness.data.BeamFnDataClient; +import org.apache.beam.fn.harness.fn.CloseableThrowingConsumer; +import org.apache.beam.fn.v1.BeamFnApi; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.KV; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import 
org.mockito.Matchers;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

/** Tests for {@link BeamFnDataWriteRunner}. */
@RunWith(JUnit4.class)
public class BeamFnDataWriteRunnerTest {
  private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
  private static final BeamFnApi.RemoteGrpcPort PORT_SPEC =
      BeamFnApi.RemoteGrpcPort.newBuilder()
          .setApiServiceDescriptor(BeamFnApi.ApiServiceDescriptor.getDefaultInstance())
          .build();
  private static final BeamFnApi.FunctionSpec FUNCTION_SPEC =
      BeamFnApi.FunctionSpec.newBuilder().setData(Any.pack(PORT_SPEC)).build();
  private static final Coder<WindowedValue<String>> CODER =
      WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE);
  // Must stay declared after OBJECT_MAPPER and CODER: the helper reads both during class init.
  private static final BeamFnApi.Coder CODER_SPEC = buildCoderSpec();
  private static final BeamFnApi.Target OUTPUT_TARGET =
      BeamFnApi.Target.newBuilder().setPrimitiveTransformReference(1).setName("out").build();

  @Mock private BeamFnDataClient mockBeamFnDataClientFactory;

  /**
   * Serializes {@link #CODER}'s cloud object to JSON with {@link #OBJECT_MAPPER} and wraps the
   * bytes, as a packed {@code BytesValue}, in a {@link BeamFnApi.Coder}.
   *
   * @throws ExceptionInInitializerError if JSON serialization fails
   */
  private static BeamFnApi.Coder buildCoderSpec() {
    try {
      byte[] cloudObjectJson = OBJECT_MAPPER.writeValueAsBytes(CODER.asCloudObject());
      BytesValue wrappedJson =
          BytesValue.newBuilder().setValue(ByteString.copyFrom(cloudObjectJson)).build();
      return BeamFnApi.Coder.newBuilder()
          .setFunctionSpec(BeamFnApi.FunctionSpec.newBuilder().setData(Any.pack(wrappedJson)))
          .build();
    } catch (IOException e) {
      throw new ExceptionInInitializerError(e);
    }
  }

  @Before
  public void setUp() {
    MockitoAnnotations.initMocks(this);
  }

  @Test
  public void testReuseForMultipleBundles() throws Exception {
    RecordingConsumer<WindowedValue<String>> firstBundleSink = new RecordingConsumer<>();
    RecordingConsumer<WindowedValue<String>> secondBundleSink = new RecordingConsumer<>();
    when(mockBeamFnDataClientFactory.forOutboundConsumer(
            any(), any(), Matchers.<Coder<WindowedValue<String>>>any()))
        .thenReturn(firstBundleSink)
        .thenReturn(secondBundleSink);
    AtomicLong bundleId = new AtomicLong(0);
    BeamFnDataWriteRunner<String> writeRunner =
        new BeamFnDataWriteRunner<>(
            FUNCTION_SPEC, bundleId::get, OUTPUT_TARGET, CODER_SPEC, mockBeamFnDataClientFactory);

    // Bundle id 0: output goes to the first consumer returned by the mocked client.
    writeRunner.registerForOutput();
    verifyOutboundConsumerRequested(bundleId.get());
    consumeAllThenClose(writeRunner, "ABC", "DEF");
    assertTrue(firstBundleSink.closed);
    assertThat(firstBundleSink, contains(valueInGlobalWindow("ABC"), valueInGlobalWindow("DEF")));

    // Bundle id 1: re-registering must request a new consumer for the new bundle id.
    bundleId.incrementAndGet();
    firstBundleSink.clear();
    secondBundleSink.clear();
    writeRunner.registerForOutput();
    verifyOutboundConsumerRequested(bundleId.get());
    consumeAllThenClose(writeRunner, "GHI", "JKL");
    assertTrue(secondBundleSink.closed);
    assertThat(secondBundleSink, contains(valueInGlobalWindow("GHI"), valueInGlobalWindow("JKL")));
    verifyNoMoreInteractions(mockBeamFnDataClientFactory);
  }

  /** Verifies one outbound consumer was requested for {@code forBundle} and the output target. */
  private void verifyOutboundConsumerRequested(long forBundle) {
    verify(mockBeamFnDataClientFactory).forOutboundConsumer(
        eq(PORT_SPEC.getApiServiceDescriptor()),
        eq(KV.of(forBundle, OUTPUT_TARGET)),
        eq(CODER));
  }

  /** Feeds each element, in the global window, to {@code runner} and then closes it. */
  private static void consumeAllThenClose(
      BeamFnDataWriteRunner<String> runner, String... elements) throws Exception {
    for (String element : elements) {
      runner.consume(valueInGlobalWindow(element));
    }
    runner.close();
  }

  /**
   * A consumer backed by an {@link ArrayList}. It stores every accepted value and, once closed,
   * rejects further input with an {@link IllegalStateException}.
   */
  private static class RecordingConsumer<T> extends ArrayList<T>
      implements CloseableThrowingConsumer<T> {
    private boolean closed;

    @Override
    public void accept(T t) throws Exception {
      if (!closed) {
        add(t);
        return;
      }
      throw new IllegalStateException("Consumer is closed but attempting to consume " + t);
    }

    @Override
    public void close() throws Exception {
      closed = true;
    }
  }
}