http://git-wip-us.apache.org/repos/asf/beam/blob/0b4b2bec/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataReadRunner.java
----------------------------------------------------------------------
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataReadRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataReadRunner.java
new file mode 100644
index 0000000..92042d0
--- /dev/null
+++ b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataReadRunner.java
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.core;
+
+import static com.google.common.base.Preconditions.checkState;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.common.collect.FluentIterable;
+import com.google.common.collect.ImmutableList;
+import com.google.protobuf.BytesValue;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.function.Supplier;
+import org.apache.beam.fn.harness.data.BeamFnDataClient;
+import org.apache.beam.fn.harness.fn.ThrowingConsumer;
+import org.apache.beam.fn.v1.BeamFnApi;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.util.Serializer;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.KV;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Registers as a consumer for data over the Beam Fn API. Multiplexes any received data
+ * to all consumers in the specified output map.
+ *
+ * <p>Can be re-used serially across {@link org.apache.beam.fn.v1.BeamFnApi.ProcessBundleRequest}s.
+ * For each request, call {@link #registerInputLocation()} to start and call
+ * {@link #blockTillReadFinishes()} to finish.
+ */
+public class BeamFnDataReadRunner<OutputT> {
+  private static final Logger LOGGER = LoggerFactory.getLogger(BeamFnDataReadRunner.class);
+  private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
+
+  private final BeamFnApi.ApiServiceDescriptor apiServiceDescriptor;
+  private final Collection<ThrowingConsumer<WindowedValue<OutputT>>> consumers;
+  private final Supplier<Long> processBundleInstructionIdSupplier;
+  private final BeamFnDataClient beamFnDataClientFactory;
+  private final Coder<WindowedValue<OutputT>> coder;
+  private final BeamFnApi.Target inputTarget;
+
+  /** Tracks the read for the current request; {@code null} until registered. */
+  private CompletableFuture<Void> readFuture;
+
+  /**
+   * Decodes the gRPC port and coder from the supplied specifications.
+   *
+   * @param functionSpec a function spec whose data is a packed {@link BeamFnApi.RemoteGrpcPort}
+   * @param processBundleInstructionIdSupplier supplies the id of the process bundle instruction
+   *     to read data for; consulted on each call to {@link #registerInputLocation()}
+   * @param inputTarget the target whose data is read
+   * @param coderSpec a coder spec whose data is a {@link BytesValue} holding the JSON
+   *     representation of the {@link WindowedValue} coder
+   * @param beamFnDataClientFactory the client used to register for inbound data
+   * @param outputMap every consumer in every value of this map receives each received element
+   * @throws IOException if the port or coder specification cannot be decoded
+   */
+  public BeamFnDataReadRunner(
+      BeamFnApi.FunctionSpec functionSpec,
+      Supplier<Long> processBundleInstructionIdSupplier,
+      BeamFnApi.Target inputTarget,
+      BeamFnApi.Coder coderSpec,
+      BeamFnDataClient beamFnDataClientFactory,
+      Map<String, Collection<ThrowingConsumer<WindowedValue<OutputT>>>> outputMap)
+      throws IOException {
+    this.apiServiceDescriptor = functionSpec.getData().unpack(BeamFnApi.RemoteGrpcPort.class)
+        .getApiServiceDescriptor();
+    this.inputTarget = inputTarget;
+    this.processBundleInstructionIdSupplier = processBundleInstructionIdSupplier;
+    this.beamFnDataClientFactory = beamFnDataClientFactory;
+    this.consumers = ImmutableList.copyOf(FluentIterable.concat(outputMap.values()));
+
+    @SuppressWarnings("unchecked")
+    Coder<WindowedValue<OutputT>> coder = Serializer.deserialize(
+        OBJECT_MAPPER.readValue(
+            coderSpec.getFunctionSpec().getData().unpack(BytesValue.class).getValue().newInput(),
+            Map.class),
+        Coder.class);
+    this.coder = coder;
+  }
+
+  /**
+   * Registers with the data client for the current process bundle instruction and input
+   * target. Each element received is forwarded to all consumers.
+   */
+  public void registerInputLocation() {
+    this.readFuture = beamFnDataClientFactory.forInboundConsumer(
+        apiServiceDescriptor,
+        KV.of(processBundleInstructionIdSupplier.get(), inputTarget),
+        coder,
+        this::multiplexToConsumers);
+  }
+
+  /**
+   * Blocks until the read registered by the most recent {@link #registerInputLocation()} call
+   * completes.
+   *
+   * @throws IllegalStateException if {@link #registerInputLocation()} was never called
+   * @throws Exception propagated from waiting on the read, for example an
+   *     {@link java.util.concurrent.ExecutionException} if the read completed exceptionally
+   */
+  public void blockTillReadFinishes() throws Exception {
+    checkState(readFuture != null,
+        "registerInputLocation() must be called before blockTillReadFinishes()");
+    LOGGER.debug("Waiting for process bundle instruction {} and target {} to close.",
+        processBundleInstructionIdSupplier.get(), inputTarget);
+    readFuture.get();
+  }
+
+  /** Forwards {@code value} to every consumer; the first exception thrown stops delivery. */
+  private void multiplexToConsumers(WindowedValue<OutputT> value) throws Exception {
+    for (ThrowingConsumer<WindowedValue<OutputT>> consumer : consumers) {
+      consumer.accept(value);
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/0b4b2bec/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataWriteRunner.java
----------------------------------------------------------------------
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataWriteRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataWriteRunner.java
new file mode 100644
index 0000000..596afe5
--- /dev/null
+++ b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataWriteRunner.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.core;
+
+import static com.google.common.base.Preconditions.checkState;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.protobuf.BytesValue;
+import java.io.IOException;
+import java.util.Map;
+import java.util.function.Supplier;
+import org.apache.beam.fn.harness.data.BeamFnDataClient;
+import org.apache.beam.fn.harness.fn.CloseableThrowingConsumer;
+import org.apache.beam.fn.v1.BeamFnApi;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.util.Serializer;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.KV;
+
+/**
+ * Registers as a consumer with the Beam Fn Data API. Propagates any elements consumed to
+ * the registered consumer.
+ *
+ * <p>Can be re-used serially across {@link org.apache.beam.fn.v1.BeamFnApi.ProcessBundleRequest}s.
+ * For each request, call {@link #registerForOutput()} to start and call {@link #close()} to finish.
+ */
+public class BeamFnDataWriteRunner<InputT> {
+  private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
+  private final BeamFnApi.ApiServiceDescriptor apiServiceDescriptor;
+  private final BeamFnApi.Target outputTarget;
+  private final Coder<WindowedValue<InputT>> coder;
+  private final BeamFnDataClient beamFnDataClientFactory;
+  private final Supplier<Long> processBundleInstructionIdSupplier;
+
+  /** The outbound consumer for the current request; {@code null} until registered. */
+  private CloseableThrowingConsumer<WindowedValue<InputT>> consumer;
+
+  /**
+   * Decodes the gRPC port and coder from the supplied specifications.
+   *
+   * @param functionSpec a function spec whose data is a packed {@link BeamFnApi.RemoteGrpcPort}
+   * @param processBundleInstructionIdSupplier supplies the id of the process bundle instruction
+   *     to write data for; consulted on each call to {@link #registerForOutput()}
+   * @param outputTarget the target the data is written to
+   * @param coderSpec a coder spec whose data is a {@link BytesValue} holding the JSON
+   *     representation of the {@link WindowedValue} coder
+   * @param beamFnDataClientFactory the client used to register for outbound data
+   * @throws IOException if the port or coder specification cannot be decoded
+   */
+  public BeamFnDataWriteRunner(
+      BeamFnApi.FunctionSpec functionSpec,
+      Supplier<Long> processBundleInstructionIdSupplier,
+      BeamFnApi.Target outputTarget,
+      BeamFnApi.Coder coderSpec,
+      BeamFnDataClient beamFnDataClientFactory)
+      throws IOException {
+    this.apiServiceDescriptor = functionSpec.getData().unpack(BeamFnApi.RemoteGrpcPort.class)
+        .getApiServiceDescriptor();
+    this.beamFnDataClientFactory = beamFnDataClientFactory;
+    this.processBundleInstructionIdSupplier = processBundleInstructionIdSupplier;
+    this.outputTarget = outputTarget;
+
+    @SuppressWarnings("unchecked")
+    Coder<WindowedValue<InputT>> coder = Serializer.deserialize(
+        OBJECT_MAPPER.readValue(
+            coderSpec.getFunctionSpec().getData().unpack(BytesValue.class).getValue().newInput(),
+            Map.class),
+        Coder.class);
+    this.coder = coder;
+  }
+
+  /**
+   * Registers with the data client for the current process bundle instruction and output
+   * target. Elements passed to {@link #consume} are then sent to the registered consumer.
+   */
+  public void registerForOutput() {
+    consumer = beamFnDataClientFactory.forOutboundConsumer(
+        apiServiceDescriptor,
+        KV.of(processBundleInstructionIdSupplier.get(), outputTarget),
+        coder);
+  }
+
+  /**
+   * Closes the consumer registered by the most recent {@link #registerForOutput()} call.
+   *
+   * @throws IllegalStateException if {@link #registerForOutput()} was never called
+   */
+  public void close() throws Exception {
+    checkState(consumer != null, "registerForOutput() must be called before close()");
+    consumer.close();
+  }
+
+  /**
+   * Sends {@code value} to the consumer registered by {@link #registerForOutput()}.
+   *
+   * @throws IllegalStateException if {@link #registerForOutput()} was never called
+   */
+  public void consume(WindowedValue<InputT> value) throws Exception {
+    checkState(consumer != null, "registerForOutput() must be called before consume()");
+    consumer.accept(value);
+  }
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/0b4b2bec/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BoundedSourceRunner.java
----------------------------------------------------------------------
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BoundedSourceRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BoundedSourceRunner.java
new file mode 100644
index 0000000..9d9c433
--- /dev/null
+++ b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BoundedSourceRunner.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.core;
+
+import com.google.common.collect.FluentIterable;
+import com.google.common.collect.ImmutableList;
+import com.google.protobuf.BytesValue;
+import com.google.protobuf.InvalidProtocolBufferException;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.Map;
+import org.apache.beam.fn.harness.fn.ThrowingConsumer;
+import org.apache.beam.fn.v1.BeamFnApi;
+import org.apache.beam.sdk.io.BoundedSource;
+import org.apache.beam.sdk.io.Source.Reader;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.util.SerializableUtils;
+import org.apache.beam.sdk.util.WindowedValue;
+
+/**
+ * A runner which creates {@link Reader}s for each {@link BoundedSource} and executes
+ * the {@link Reader}s read loop.
+ */
+public class BoundedSourceRunner<InputT extends BoundedSource<OutputT>, OutputT> {
+  private final PipelineOptions pipelineOptions;
+  private final BeamFnApi.FunctionSpec definition;
+  private final Collection<ThrowingConsumer<WindowedValue<OutputT>>> consumers;
+
+  /**
+   * Creates a runner which forwards every element read to all consumers in {@code outputMap}.
+   *
+   * @param pipelineOptions the options used to create each {@link Reader}
+   * @param definition the function specification read by {@link #start()}
+   * @param outputMap every consumer in every value of this map receives each element read
+   */
+  public BoundedSourceRunner(
+      PipelineOptions pipelineOptions,
+      BeamFnApi.FunctionSpec definition,
+      Map<String, Collection<ThrowingConsumer<WindowedValue<OutputT>>>> outputMap) {
+    this.pipelineOptions = pipelineOptions;
+    this.definition = definition;
+    this.consumers = ImmutableList.copyOf(FluentIterable.concat(outputMap.values()));
+  }
+
+  /**
+   * The runner harness is meant to send the source over the Beam Fn Data API which would be
+   * consumed by the {@link #runReadLoop}. Drop this method once the runner harness sends the
+   * source instead of unpacking it from the data block of the function specification.
+   */
+  @Deprecated
+  public void start() throws Exception {
+    // The representation here is defined as the java serialized representation of the
+    // bounded source object packed into a protobuf Any using a protobuf BytesValue wrapper.
+    // Only the unpacking is guarded so that an InvalidProtocolBufferException thrown later,
+    // by the reader or a consumer, is not reported as a failure to decode the specification.
+    byte[] bytes;
+    try {
+      bytes = definition.getData().unpack(BytesValue.class).getValue().toByteArray();
+    } catch (InvalidProtocolBufferException e) {
+      throw new IOException(
+          String.format("Failed to decode %s, expected %s",
+              definition.getData().getTypeUrl(), BytesValue.getDescriptor().getFullName()),
+          e);
+    }
+    @SuppressWarnings("unchecked")
+    InputT boundedSource =
+        (InputT) SerializableUtils.deserializeFromByteArray(bytes, definition.toString());
+    runReadLoop(WindowedValue.valueInGlobalWindow(boundedSource));
+  }
+
+  /**
+   * Creates a {@link Reader} for each {@link BoundedSource} and executes the {@link Reader}s
+   * read loop. See {@link Reader} for further details of the read loop.
+   *
+   * <p>Propagates any exceptions caused during reading or processing via a consumer to the
+   * caller.
+   */
+  public void runReadLoop(WindowedValue<InputT> value) throws Exception {
+    try (Reader<OutputT> reader = value.getValue().createReader(pipelineOptions)) {
+      if (!reader.start()) {
+        // Reader has no data, immediately return
+        return;
+      }
+      do {
+        // TODO: Should this use the input window as the window for all the outputs?
+        WindowedValue<OutputT> nextValue = WindowedValue.timestampedValueInGlobalWindow(
+            reader.getCurrent(), reader.getCurrentTimestamp());
+        for (ThrowingConsumer<WindowedValue<OutputT>> consumer : consumers) {
+          consumer.accept(nextValue);
+        }
+      } while (reader.advance());
+    }
+  }
+
+  @Override
+  public String toString() {
+    return definition.toString();
+  }
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/0b4b2bec/sdks/java/harness/src/main/java/org/apache/beam/runners/core/package-info.java
----------------------------------------------------------------------
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/package-info.java b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/package-info.java
new file mode 100644
index 0000000..d250a6a
--- /dev/null
+++ b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Provides utilities for Beam runner authors.
+ */
+package org.apache.beam.runners.core;
http://git-wip-us.apache.org/repos/asf/beam/blob/0b4b2bec/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnHarnessTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnHarnessTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnHarnessTest.java
new file mode 100644
index 0000000..ff05225
--- /dev/null
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnHarnessTest.java
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.fn.harness;
+
+import static org.hamcrest.Matchers.contains;
+import static org.junit.Assert.assertThat;
+
+import com.google.common.util.concurrent.Uninterruptibles;
+import io.grpc.Server;
+import io.grpc.ServerBuilder;
+import io.grpc.stub.StreamObserver;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.CountDownLatch;
+import java.util.function.Consumer;
+import org.apache.beam.fn.harness.test.TestStreams;
+import org.apache.beam.fn.v1.BeamFnApi;
+import org.apache.beam.fn.v1.BeamFnApi.InstructionRequest;
+import org.apache.beam.fn.v1.BeamFnApi.InstructionResponse;
+import org.apache.beam.fn.v1.BeamFnApi.LogControl;
+import org.apache.beam.fn.v1.BeamFnControlGrpc;
+import org.apache.beam.fn.v1.BeamFnLoggingGrpc;
+import org.apache.beam.sdk.options.GcsOptions;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for {@link FnHarness}. */
+@RunWith(JUnit4.class)
+public class FnHarnessTest {
+  private static final BeamFnApi.InstructionRequest INSTRUCTION_REQUEST =
+      BeamFnApi.InstructionRequest.newBuilder()
+          .setInstructionId(999L)
+          .setRegister(BeamFnApi.RegisterRequest.getDefaultInstance())
+          .build();
+  private static final BeamFnApi.InstructionResponse INSTRUCTION_RESPONSE =
+      BeamFnApi.InstructionResponse.newBuilder()
+          .setInstructionId(999L)
+          .setRegister(BeamFnApi.RegisterResponse.getDefaultInstance())
+          .build();
+
+  @Test
+  public void testLaunchFnHarnessAndTeardownCleanly() throws Exception {
+    PipelineOptions options = PipelineOptionsFactory.create();
+
+    List<BeamFnApi.LogEntry> logEntries = new ArrayList<>();
+    List<BeamFnApi.InstructionResponse> instructionResponses = new ArrayList<>();
+
+    BeamFnLoggingGrpc.BeamFnLoggingImplBase loggingService =
+        new BeamFnLoggingGrpc.BeamFnLoggingImplBase() {
+      @Override
+      public StreamObserver<BeamFnApi.LogEntry.List> logging(
+          StreamObserver<LogControl> responseObserver) {
+        return TestStreams.withOnNext(
+            (BeamFnApi.LogEntry.List entries) -> logEntries.addAll(entries.getLogEntriesList()))
+            .withOnCompleted(() -> responseObserver.onCompleted())
+            .build();
+      }
+    };
+
+    BeamFnControlGrpc.BeamFnControlImplBase controlService =
+        new BeamFnControlGrpc.BeamFnControlImplBase() {
+      @Override
+      public StreamObserver<InstructionResponse> control(
+          StreamObserver<InstructionRequest> responseObserver) {
+        CountDownLatch waitForResponses = new CountDownLatch(1 /* number of responses expected */);
+        options.as(GcsOptions.class).getExecutorService().submit(new Runnable() {
+          @Override
+          public void run() {
+            responseObserver.onNext(INSTRUCTION_REQUEST);
+            Uninterruptibles.awaitUninterruptibly(waitForResponses);
+            responseObserver.onCompleted();
+          }
+        });
+        return TestStreams.withOnNext(new Consumer<BeamFnApi.InstructionResponse>() {
+          @Override
+          public void accept(InstructionResponse t) {
+            instructionResponses.add(t);
+            waitForResponses.countDown();
+          }
+        }).withOnCompleted(waitForResponses::countDown).build();
+      }
+    };
+
+    Server loggingServer = ServerBuilder.forPort(0).addService(loggingService).build();
+    loggingServer.start();
+    try {
+      Server controlServer = ServerBuilder.forPort(0).addService(controlService).build();
+      controlServer.start();
+      try {
+        BeamFnApi.ApiServiceDescriptor loggingDescriptor = BeamFnApi.ApiServiceDescriptor
+            .newBuilder()
+            .setId(1L)
+            .setUrl("localhost:" + loggingServer.getPort())
+            .build();
+        BeamFnApi.ApiServiceDescriptor controlDescriptor = BeamFnApi.ApiServiceDescriptor
+            .newBuilder()
+            .setId(2L)
+            .setUrl("localhost:" + controlServer.getPort())
+            .build();
+
+        FnHarness.main(options, loggingDescriptor, controlDescriptor);
+        assertThat(instructionResponses, contains(INSTRUCTION_RESPONSE));
+      } finally {
+        controlServer.shutdownNow();
+      }
+    } finally {
+      loggingServer.shutdownNow();
+    }
+  }
+}
+
http://git-wip-us.apache.org/repos/asf/beam/blob/0b4b2bec/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/channel/ManagedChannelFactoryTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/channel/ManagedChannelFactoryTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/channel/ManagedChannelFactoryTest.java
new file mode 100644
index 0000000..9f634c9
--- /dev/null
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/channel/ManagedChannelFactoryTest.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.fn.harness.channel;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assume.assumeTrue;
+
+import io.grpc.ManagedChannel;
+import org.apache.beam.fn.v1.BeamFnApi.ApiServiceDescriptor;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for {@link ManagedChannelFactory}. */
+@RunWith(JUnit4.class)
+public class ManagedChannelFactoryTest {
+  @Rule public TemporaryFolder tmpFolder = new TemporaryFolder();
+
+  @Test
+  public void testDefaultChannel() {
+    ApiServiceDescriptor apiServiceDescriptor = ApiServiceDescriptor.newBuilder()
+        .setUrl("localhost:123")
+        .build();
+    ManagedChannel channel = ManagedChannelFactory.from(PipelineOptionsFactory.create())
+        .forDescriptor(apiServiceDescriptor);
+    try {
+      assertEquals("localhost:123", channel.authority());
+    } finally {
+      channel.shutdownNow();
+    }
+  }
+
+  @Test
+  public void testEpollHostPortChannel() {
+    assumeTrue(io.netty.channel.epoll.Epoll.isAvailable());
+    ApiServiceDescriptor apiServiceDescriptor = ApiServiceDescriptor.newBuilder()
+        .setUrl("localhost:123")
+        .build();
+    ManagedChannel channel = ManagedChannelFactory.from(
+        PipelineOptionsFactory.fromArgs(new String[]{ "--experiments=beam_fn_api_epoll" }).create())
+        .forDescriptor(apiServiceDescriptor);
+    try {
+      assertEquals("localhost:123", channel.authority());
+    } finally {
+      channel.shutdownNow();
+    }
+  }
+
+  @Test
+  public void testEpollDomainSocketChannel() throws Exception {
+    assumeTrue(io.netty.channel.epoll.Epoll.isAvailable());
+    ApiServiceDescriptor apiServiceDescriptor = ApiServiceDescriptor.newBuilder()
+        .setUrl("unix://" + tmpFolder.newFile().getAbsolutePath())
+        .build();
+    ManagedChannel channel = ManagedChannelFactory.from(
+        PipelineOptionsFactory.fromArgs(new String[]{ "--experiments=beam_fn_api_epoll" }).create())
+        .forDescriptor(apiServiceDescriptor);
+    try {
+      assertEquals(
+          apiServiceDescriptor.getUrl().substring("unix://".length()), channel.authority());
+    } finally {
+      channel.shutdownNow();
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/0b4b2bec/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/channel/SocketAddressFactoryTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/channel/SocketAddressFactoryTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/channel/SocketAddressFactoryTest.java
new file mode 100644
index 0000000..610a8ea
--- /dev/null
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/channel/SocketAddressFactoryTest.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.fn.harness.channel;
+
+import static org.hamcrest.Matchers.instanceOf;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThat;
+
+import io.netty.channel.unix.DomainSocketAddress;
+import java.io.File;
+import java.net.InetSocketAddress;
+import java.net.SocketAddress;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for {@link SocketAddressFactory}. */
+@RunWith(JUnit4.class)
+public class SocketAddressFactoryTest {
+  @Rule public TemporaryFolder tmpFolder = new TemporaryFolder();
+
+  @Test
+  public void testHostPortSocket() {
+    SocketAddress socketAddress = SocketAddressFactory.createFrom("localhost:123");
+    assertThat(socketAddress, instanceOf(InetSocketAddress.class));
+    assertEquals("localhost", ((InetSocketAddress) socketAddress).getHostString());
+    assertEquals(123, ((InetSocketAddress) socketAddress).getPort());
+  }
+
+  @Test
+  public void testDomainSocket() throws Exception {
+    File tmpFile = tmpFolder.newFile();
+    SocketAddress socketAddress = SocketAddressFactory.createFrom(
+        "unix://" + tmpFile.getAbsolutePath());
+    assertThat(socketAddress, instanceOf(DomainSocketAddress.class));
+    assertEquals(tmpFile.getAbsolutePath(), ((DomainSocketAddress) socketAddress).path());
+  }
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/0b4b2bec/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/BeamFnControlClientTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/BeamFnControlClientTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/BeamFnControlClientTest.java
new file mode 100644
index 0000000..fc3af49
--- /dev/null
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/BeamFnControlClientTest.java
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.fn.harness.control;
+
+import static com.google.common.base.Throwables.getStackTraceAsString;
+import static org.junit.Assert.assertEquals;
+
+import com.google.common.util.concurrent.Uninterruptibles;
+import io.grpc.ManagedChannel;
+import io.grpc.Server;
+import io.grpc.inprocess.InProcessChannelBuilder;
+import io.grpc.inprocess.InProcessServerBuilder;
+import io.grpc.stub.CallStreamObserver;
+import io.grpc.stub.StreamObserver;
+import java.util.EnumMap;
+import java.util.UUID;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.function.Function;
+import org.apache.beam.fn.harness.fn.ThrowingFunction;
+import org.apache.beam.fn.harness.test.TestStreams;
+import org.apache.beam.fn.v1.BeamFnApi;
+import org.apache.beam.fn.v1.BeamFnControlGrpc;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for {@link BeamFnControlClient}. */
+@RunWith(JUnit4.class)
+public class BeamFnControlClientTest {
+  private static final BeamFnApi.InstructionRequest SUCCESSFUL_REQUEST =
+      BeamFnApi.InstructionRequest.newBuilder()
+          .setInstructionId(1L)
+          .setProcessBundle(BeamFnApi.ProcessBundleRequest.getDefaultInstance())
+          .build();
+  private static final BeamFnApi.InstructionResponse SUCCESSFUL_RESPONSE =
+      BeamFnApi.InstructionResponse.newBuilder()
+          .setInstructionId(1L)
+          .setProcessBundle(BeamFnApi.ProcessBundleResponse.getDefaultInstance())
+          .build();
+  private static final BeamFnApi.InstructionRequest UNKNOWN_HANDLER_REQUEST =
+      BeamFnApi.InstructionRequest.newBuilder()
+          .setInstructionId(2L)
+          .build();
+  private static final BeamFnApi.InstructionResponse UNKNOWN_HANDLER_RESPONSE =
+      BeamFnApi.InstructionResponse.newBuilder()
+          .setInstructionId(2L)
+          .setError("Unknown InstructionRequest type "
+              + BeamFnApi.InstructionRequest.RequestCase.REQUEST_NOT_SET)
+          .build();
+  private static final RuntimeException FAILURE = new RuntimeException("TestFailure");
+  private static final BeamFnApi.InstructionRequest FAILURE_REQUEST =
+      BeamFnApi.InstructionRequest.newBuilder()
+          .setInstructionId(3L)
+          .setRegister(BeamFnApi.RegisterRequest.getDefaultInstance())
+          .build();
+  private static final BeamFnApi.InstructionResponse FAILURE_RESPONSE =
+      BeamFnApi.InstructionResponse.newBuilder()
+          .setInstructionId(3L)
+          .setError(getStackTraceAsString(FAILURE))
+          .build();
+
+  @Test
+  public void testDelegation() throws Exception {
+    AtomicBoolean clientClosedStream = new AtomicBoolean();
+    BlockingQueue<BeamFnApi.InstructionResponse> values = new LinkedBlockingQueue<>();
+    BlockingQueue<StreamObserver<BeamFnApi.InstructionRequest>> outboundServerObservers =
+        new LinkedBlockingQueue<>();
+    CallStreamObserver<BeamFnApi.InstructionResponse> inboundServerObserver =
+        TestStreams.withOnNext(values::add)
+            .withOnCompleted(() -> clientClosedStream.set(true)).build();
+
+    BeamFnApi.ApiServiceDescriptor apiServiceDescriptor =
+        BeamFnApi.ApiServiceDescriptor.newBuilder()
+            .setUrl(this.getClass().getName() + "-" + UUID.randomUUID().toString())
+            .build();
+    Server server = InProcessServerBuilder.forName(apiServiceDescriptor.getUrl())
+        .addService(new BeamFnControlGrpc.BeamFnControlImplBase() {
+          @Override
+          public StreamObserver<BeamFnApi.InstructionResponse> control(
+              StreamObserver<BeamFnApi.InstructionRequest> outboundObserver) {
+            Uninterruptibles.putUninterruptibly(outboundServerObservers, outboundObserver);
+            return inboundServerObserver;
+          }
+        })
+        .build();
+    server.start();
+    ManagedChannel channel =
+        InProcessChannelBuilder.forName(apiServiceDescriptor.getUrl()).build();
+    ExecutorService executor = Executors.newCachedThreadPool();
+    try {
+      EnumMap<BeamFnApi.InstructionRequest.RequestCase,
+              ThrowingFunction<BeamFnApi.InstructionRequest,
+                               BeamFnApi.InstructionResponse.Builder>> handlers =
+          new EnumMap<>(BeamFnApi.InstructionRequest.RequestCase.class);
+      handlers.put(BeamFnApi.InstructionRequest.RequestCase.PROCESS_BUNDLE,
+          new ThrowingFunction<BeamFnApi.InstructionRequest,
+                               BeamFnApi.InstructionResponse.Builder>() {
+            @Override
+            public BeamFnApi.InstructionResponse.Builder apply(BeamFnApi.InstructionRequest value)
+                throws Exception {
+              return BeamFnApi.InstructionResponse.newBuilder()
+                  .setProcessBundle(BeamFnApi.ProcessBundleResponse.getDefaultInstance());
+            }
+          });
+      handlers.put(BeamFnApi.InstructionRequest.RequestCase.REGISTER,
+          new ThrowingFunction<BeamFnApi.InstructionRequest,
+                               BeamFnApi.InstructionResponse.Builder>() {
+            @Override
+            public BeamFnApi.InstructionResponse.Builder apply(BeamFnApi.InstructionRequest value)
+                throws Exception {
+              throw FAILURE;
+            }
+          });
+
+      BeamFnControlClient client = new BeamFnControlClient(
+          apiServiceDescriptor,
+          (BeamFnApi.ApiServiceDescriptor descriptor) -> channel,
+          this::createStreamForTest,
+          handlers);
+
+      // Get the connected client and attempt to send and receive an instruction
+      StreamObserver<BeamFnApi.InstructionRequest> outboundServerObserver =
+          outboundServerObservers.take();
+
+      Future<Void> future = executor.submit(new Callable<Void>() {
+        @Override
+        public Void call() throws Exception {
+          client.processInstructionRequests(executor);
+          return null;
+        }
+      });
+
+      outboundServerObserver.onNext(SUCCESSFUL_REQUEST);
+      assertEquals(SUCCESSFUL_RESPONSE, values.take());
+
+      // Ensure that conversion of an unknown request type is properly converted to a
+      // failure response.
+      outboundServerObserver.onNext(UNKNOWN_HANDLER_REQUEST);
+      assertEquals(UNKNOWN_HANDLER_RESPONSE, values.take());
+
+      // Ensure that all exceptions are caught and translated to failures
+      outboundServerObserver.onNext(FAILURE_REQUEST);
+      assertEquals(FAILURE_RESPONSE, values.take());
+
+      // Ensure that the server completing the stream translates to the completable future
+      // being completed allowing for a successful shutdown of the client.
+      outboundServerObserver.onCompleted();
+      future.get();
+    } finally {
+      // Release the executor's threads and the channel even when an assertion above fails.
+      executor.shutdownNow();
+      channel.shutdownNow();
+      server.shutdownNow();
+    }
+  }
+
+  private <ReqT, RespT> StreamObserver<RespT> createStreamForTest(
+      Function<StreamObserver<ReqT>, StreamObserver<RespT>> clientFactory,
+      StreamObserver<ReqT> handler) {
+    return clientFactory.apply(handler);
+  }
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/0b4b2bec/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
new file mode 100644
index 0000000..1d451b5
--- /dev/null
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
@@ -0,0 +1,674 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.beam.fn.harness.control; + +import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.empty; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.when; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.base.Suppliers; +import com.google.common.collect.HashMultimap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableMultimap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; +import com.google.common.collect.Multimap; +import com.google.protobuf.Any; +import com.google.protobuf.ByteString; +import com.google.protobuf.BytesValue; +import com.google.protobuf.Message; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; +import org.apache.beam.fn.harness.data.BeamFnDataClient; +import org.apache.beam.fn.harness.fn.CloseableThrowingConsumer; +import org.apache.beam.fn.harness.fn.ThrowingConsumer; +import org.apache.beam.fn.harness.fn.ThrowingRunnable; +import org.apache.beam.fn.v1.BeamFnApi; +import org.apache.beam.runners.dataflow.util.DoFnInfo; +import 
org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarLongCoder; +import org.apache.beam.sdk.io.BoundedSource; +import org.apache.beam.sdk.io.CountingSource; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.util.SerializableUtils; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.util.WindowingStrategy; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.TupleTag; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Matchers; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +/** Tests for {@link ProcessBundleHandler}. 
*/ +@RunWith(JUnit4.class) +public class ProcessBundleHandlerTest { + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + private static final Coder<WindowedValue<String>> STRING_CODER = + WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE); + private static final long LONG_CODER_SPEC_ID = 998L; + private static final long STRING_CODER_SPEC_ID = 999L; + private static final BeamFnApi.RemoteGrpcPort REMOTE_PORT = BeamFnApi.RemoteGrpcPort.newBuilder() + .setApiServiceDescriptor(BeamFnApi.ApiServiceDescriptor.newBuilder() + .setId(58L) + .setUrl("TestUrl")) + .build(); + private static final BeamFnApi.Coder LONG_CODER_SPEC; + private static final BeamFnApi.Coder STRING_CODER_SPEC; + static { + try { + STRING_CODER_SPEC = + BeamFnApi.Coder.newBuilder().setFunctionSpec(BeamFnApi.FunctionSpec.newBuilder() + .setId(STRING_CODER_SPEC_ID) + .setData(Any.pack(BytesValue.newBuilder().setValue(ByteString.copyFrom( + OBJECT_MAPPER.writeValueAsBytes(STRING_CODER.asCloudObject()))).build()))) + .build(); + LONG_CODER_SPEC = + BeamFnApi.Coder.newBuilder().setFunctionSpec(BeamFnApi.FunctionSpec.newBuilder() + .setId(STRING_CODER_SPEC_ID) + .setData(Any.pack(BytesValue.newBuilder().setValue(ByteString.copyFrom( + OBJECT_MAPPER.writeValueAsBytes(WindowedValue.getFullCoder( + VarLongCoder.of(), GlobalWindow.Coder.INSTANCE).asCloudObject()))).build()))) + .build(); + } catch (IOException e) { + throw new ExceptionInInitializerError(e); + } + } + + private static final String DATA_INPUT_URN = "urn:org.apache.beam:source:runner:0.1"; + private static final String DATA_OUTPUT_URN = "urn:org.apache.beam:sink:runner:0.1"; + private static final String JAVA_DO_FN_URN = "urn:org.apache.beam:dofn:java:0.1"; + private static final String JAVA_SOURCE_URN = "urn:org.apache.beam:source:java:0.1"; + + @Rule public ExpectedException thrown = ExpectedException.none(); + + @Mock private BeamFnDataClient beamFnDataClient; + @Captor private 
ArgumentCaptor<ThrowingConsumer<WindowedValue<String>>> consumerCaptor; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + } + + @Test + public void testOrderOfStartAndFinishCalls() throws Exception { + BeamFnApi.ProcessBundleDescriptor processBundleDescriptor = + BeamFnApi.ProcessBundleDescriptor.newBuilder() + .addPrimitiveTransform(BeamFnApi.PrimitiveTransform.newBuilder().setId(2L)) + .addPrimitiveTransform(BeamFnApi.PrimitiveTransform.newBuilder().setId(3L)) + .build(); + Map<Long, Message> fnApiRegistry = ImmutableMap.of(1L, processBundleDescriptor); + + List<BeamFnApi.PrimitiveTransform> transformsProcessed = new ArrayList<>(); + List<String> orderOfOperations = new ArrayList<>(); + + ProcessBundleHandler handler = new ProcessBundleHandler( + PipelineOptionsFactory.create(), + fnApiRegistry::get, + beamFnDataClient) { + @Override + protected <InputT, OutputT> void createConsumersForPrimitiveTransform( + BeamFnApi.PrimitiveTransform primitiveTransform, + Supplier<Long> processBundleInstructionId, + Function<BeamFnApi.Target, + Collection<ThrowingConsumer<WindowedValue<OutputT>>>> consumers, + BiConsumer<BeamFnApi.Target, ThrowingConsumer<WindowedValue<InputT>>> addConsumer, + Consumer<ThrowingRunnable> addStartFunction, + Consumer<ThrowingRunnable> addFinishFunction) + throws IOException { + + assertEquals((Long) 999L, processBundleInstructionId.get()); + + transformsProcessed.add(primitiveTransform); + addStartFunction.accept( + () -> orderOfOperations.add("Start" + primitiveTransform.getId())); + addFinishFunction.accept( + () -> orderOfOperations.add("Finish" + primitiveTransform.getId())); + } + }; + handler.processBundle(BeamFnApi.InstructionRequest.newBuilder() + .setInstructionId(999L) + .setProcessBundle( + BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorReference(1L)) + .build()); + + // Processing of primitive transforms is performed in reverse order. 
+ assertThat(transformsProcessed, contains( + processBundleDescriptor.getPrimitiveTransform(1), + processBundleDescriptor.getPrimitiveTransform(0))); + // Start should occur in reverse order while finish calls should occur in forward order + assertThat(orderOfOperations, contains("Start3", "Start2", "Finish2", "Finish3")); + } + + @Test + public void testCreatingPrimitiveTransformExceptionsArePropagated() throws Exception { + BeamFnApi.ProcessBundleDescriptor processBundleDescriptor = + BeamFnApi.ProcessBundleDescriptor.newBuilder() + .addPrimitiveTransform(BeamFnApi.PrimitiveTransform.newBuilder().setId(2L)) + .addPrimitiveTransform(BeamFnApi.PrimitiveTransform.newBuilder().setId(3L)) + .build(); + Map<Long, Message> fnApiRegistry = ImmutableMap.of(1L, processBundleDescriptor); + + ProcessBundleHandler handler = new ProcessBundleHandler( + PipelineOptionsFactory.create(), + fnApiRegistry::get, + beamFnDataClient) { + @Override + protected <InputT, OutputT> void createConsumersForPrimitiveTransform( + BeamFnApi.PrimitiveTransform primitiveTransform, + Supplier<Long> processBundleInstructionId, + Function<BeamFnApi.Target, + Collection<ThrowingConsumer<WindowedValue<OutputT>>>> consumers, + BiConsumer<BeamFnApi.Target, ThrowingConsumer<WindowedValue<InputT>>> addConsumer, + Consumer<ThrowingRunnable> addStartFunction, + Consumer<ThrowingRunnable> addFinishFunction) + throws IOException { + thrown.expect(IllegalStateException.class); + thrown.expectMessage("TestException"); + throw new IllegalStateException("TestException"); + } + }; + handler.processBundle( + BeamFnApi.InstructionRequest.newBuilder().setProcessBundle( + BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorReference(1L)) + .build()); + } + + @Test + public void testPrimitiveTransformStartExceptionsArePropagated() throws Exception { + BeamFnApi.ProcessBundleDescriptor processBundleDescriptor = + BeamFnApi.ProcessBundleDescriptor.newBuilder() + 
.addPrimitiveTransform(BeamFnApi.PrimitiveTransform.newBuilder().setId(2L)) + .addPrimitiveTransform(BeamFnApi.PrimitiveTransform.newBuilder().setId(3L)) + .build(); + Map<Long, Message> fnApiRegistry = ImmutableMap.of(1L, processBundleDescriptor); + + ProcessBundleHandler handler = new ProcessBundleHandler( + PipelineOptionsFactory.create(), + fnApiRegistry::get, + beamFnDataClient) { + @Override + protected <InputT, OutputT> void createConsumersForPrimitiveTransform( + BeamFnApi.PrimitiveTransform primitiveTransform, + Supplier<Long> processBundleInstructionId, + Function<BeamFnApi.Target, + Collection<ThrowingConsumer<WindowedValue<OutputT>>>> consumers, + BiConsumer<BeamFnApi.Target, ThrowingConsumer<WindowedValue<InputT>>> addConsumer, + Consumer<ThrowingRunnable> addStartFunction, + Consumer<ThrowingRunnable> addFinishFunction) + throws IOException { + thrown.expect(IllegalStateException.class); + thrown.expectMessage("TestException"); + addStartFunction.accept(this::throwException); + } + + private void throwException() { + throw new IllegalStateException("TestException"); + } + }; + handler.processBundle( + BeamFnApi.InstructionRequest.newBuilder().setProcessBundle( + BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorReference(1L)) + .build()); + } + + @Test + public void testPrimitiveTransformFinishExceptionsArePropagated() throws Exception { + BeamFnApi.ProcessBundleDescriptor processBundleDescriptor = + BeamFnApi.ProcessBundleDescriptor.newBuilder() + .addPrimitiveTransform(BeamFnApi.PrimitiveTransform.newBuilder().setId(2L)) + .addPrimitiveTransform(BeamFnApi.PrimitiveTransform.newBuilder().setId(3L)) + .build(); + Map<Long, Message> fnApiRegistry = ImmutableMap.of(1L, processBundleDescriptor); + + ProcessBundleHandler handler = new ProcessBundleHandler( + PipelineOptionsFactory.create(), + fnApiRegistry::get, + beamFnDataClient) { + @Override + protected <InputT, OutputT> void createConsumersForPrimitiveTransform( + 
BeamFnApi.PrimitiveTransform primitiveTransform, + Supplier<Long> processBundleInstructionId, + Function<BeamFnApi.Target, + Collection<ThrowingConsumer<WindowedValue<OutputT>>>> consumers, + BiConsumer<BeamFnApi.Target, ThrowingConsumer<WindowedValue<InputT>>> addConsumer, + Consumer<ThrowingRunnable> addStartFunction, + Consumer<ThrowingRunnable> addFinishFunction) + throws IOException { + thrown.expect(IllegalStateException.class); + thrown.expectMessage("TestException"); + addFinishFunction.accept(this::throwException); + } + + private void throwException() { + throw new IllegalStateException("TestException"); + } + }; + handler.processBundle( + BeamFnApi.InstructionRequest.newBuilder().setProcessBundle( + BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorReference(1L)) + .build()); + } + + private static class TestDoFn extends DoFn<String, String> { + private static final TupleTag<String> mainOutput = new TupleTag<>("mainOutput"); + private static final TupleTag<String> sideOutput = new TupleTag<>("sideOutput"); + + @StartBundle + public void startBundle(Context context) { + context.output("StartBundle"); + } + + @ProcessElement + public void processElement(ProcessContext context) { + context.output("MainOutput" + context.element()); + context.sideOutput(sideOutput, "SideOutput" + context.element()); + } + + @FinishBundle + public void finishBundle(Context context) { + context.output("FinishBundle"); + } + } + + /** + * Create a DoFn that has 3 inputs (inputATarget1, inputATarget2, inputBTarget) and 2 outputs + * (mainOutput, sideOutput). Validate that inputs are fed to the {@link DoFn} and that outputs + * are directed to the correct consumers. 
+ */ + @Test + public void testCreatingAndProcessingDoFn() throws Exception { + Map<Long, Message> fnApiRegistry = ImmutableMap.of(STRING_CODER_SPEC_ID, STRING_CODER_SPEC); + long primitiveTransformId = 100L; + long mainOutputId = 101L; + long sideOutputId = 102L; + + DoFnInfo<?, ?> doFnInfo = DoFnInfo.forFn( + new TestDoFn(), + WindowingStrategy.globalDefault(), + ImmutableList.of(), + STRING_CODER, + mainOutputId, + ImmutableMap.of( + mainOutputId, TestDoFn.mainOutput, + sideOutputId, TestDoFn.sideOutput)); + BeamFnApi.FunctionSpec functionSpec = BeamFnApi.FunctionSpec.newBuilder() + .setId(1L) + .setUrn(JAVA_DO_FN_URN) + .setData(Any.pack(BytesValue.newBuilder() + .setValue(ByteString.copyFrom(SerializableUtils.serializeToByteArray(doFnInfo))) + .build())) + .build(); + BeamFnApi.Target inputATarget1 = BeamFnApi.Target.newBuilder() + .setPrimitiveTransformReference(1000L) + .setName("inputATarget1") + .build(); + BeamFnApi.Target inputATarget2 = BeamFnApi.Target.newBuilder() + .setPrimitiveTransformReference(1001L) + .setName("inputATarget1") + .build(); + BeamFnApi.Target inputBTarget = BeamFnApi.Target.newBuilder() + .setPrimitiveTransformReference(1002L) + .setName("inputBTarget") + .build(); + BeamFnApi.PrimitiveTransform primitiveTransform = BeamFnApi.PrimitiveTransform.newBuilder() + .setId(primitiveTransformId) + .setFunctionSpec(functionSpec) + .putInputs("inputA", BeamFnApi.Target.List.newBuilder() + .addTarget(inputATarget1) + .addTarget(inputATarget2) + .build()) + .putInputs("inputB", BeamFnApi.Target.List.newBuilder() + .addTarget(inputBTarget) + .build()) + .putOutputs(Long.toString(mainOutputId), BeamFnApi.PCollection.newBuilder() + .setCoderReference(STRING_CODER_SPEC_ID) + .build()) + .putOutputs(Long.toString(sideOutputId), BeamFnApi.PCollection.newBuilder() + .setCoderReference(STRING_CODER_SPEC_ID) + .build()) + .build(); + + List<WindowedValue<String>> mainOutputValues = new ArrayList<>(); + List<WindowedValue<String>> sideOutputValues = new 
ArrayList<>(); + BeamFnApi.Target mainOutputTarget = BeamFnApi.Target.newBuilder() + .setPrimitiveTransformReference(primitiveTransformId) + .setName(Long.toString(mainOutputId)) + .build(); + BeamFnApi.Target sideOutputTarget = BeamFnApi.Target.newBuilder() + .setPrimitiveTransformReference(primitiveTransformId) + .setName(Long.toString(sideOutputId)) + .build(); + Multimap<BeamFnApi.Target, ThrowingConsumer<WindowedValue<String>>> existingConsumers = + ImmutableMultimap.of( + mainOutputTarget, mainOutputValues::add, + sideOutputTarget, sideOutputValues::add); + Multimap<BeamFnApi.Target, ThrowingConsumer<WindowedValue<String>>> newConsumers = + HashMultimap.create(); + List<ThrowingRunnable> startFunctions = new ArrayList<>(); + List<ThrowingRunnable> finishFunctions = new ArrayList<>(); + + ProcessBundleHandler handler = new ProcessBundleHandler( + PipelineOptionsFactory.create(), + fnApiRegistry::get, + beamFnDataClient); + handler.createConsumersForPrimitiveTransform( + primitiveTransform, + Suppliers.ofInstance(57L)::get, + existingConsumers::get, + newConsumers::put, + startFunctions::add, + finishFunctions::add); + + Iterables.getOnlyElement(startFunctions).run(); + assertThat(mainOutputValues, contains(valueInGlobalWindow("StartBundle"))); + mainOutputValues.clear(); + + assertEquals(newConsumers.keySet(), + ImmutableSet.of(inputATarget1, inputATarget2, inputBTarget)); + + Iterables.getOnlyElement(newConsumers.get(inputATarget1)).accept(valueInGlobalWindow("A1")); + Iterables.getOnlyElement(newConsumers.get(inputATarget1)).accept(valueInGlobalWindow("A2")); + Iterables.getOnlyElement(newConsumers.get(inputATarget1)).accept(valueInGlobalWindow("B")); + assertThat(mainOutputValues, contains( + valueInGlobalWindow("MainOutputA1"), + valueInGlobalWindow("MainOutputA2"), + valueInGlobalWindow("MainOutputB"))); + assertThat(sideOutputValues, contains( + valueInGlobalWindow("SideOutputA1"), + valueInGlobalWindow("SideOutputA2"), + 
valueInGlobalWindow("SideOutputB"))); + mainOutputValues.clear(); + sideOutputValues.clear(); + + Iterables.getOnlyElement(finishFunctions).run(); + assertThat(mainOutputValues, contains(valueInGlobalWindow("FinishBundle"))); + mainOutputValues.clear(); + } + + @Test + public void testCreatingAndProcessingSource() throws Exception { + Map<Long, Message> fnApiRegistry = ImmutableMap.of(LONG_CODER_SPEC_ID, LONG_CODER_SPEC); + long primitiveTransformId = 100L; + long outputId = 101L; + + BeamFnApi.Target inputTarget = BeamFnApi.Target.newBuilder() + .setPrimitiveTransformReference(1000L) + .setName("inputTarget") + .build(); + + List<WindowedValue<String>> outputValues = new ArrayList<>(); + BeamFnApi.Target outputTarget = BeamFnApi.Target.newBuilder() + .setPrimitiveTransformReference(primitiveTransformId) + .setName(Long.toString(outputId)) + .build(); + + Multimap<BeamFnApi.Target, ThrowingConsumer<WindowedValue<String>>> existingConsumers = + ImmutableMultimap.of(outputTarget, outputValues::add); + Multimap<BeamFnApi.Target, + ThrowingConsumer<WindowedValue<BoundedSource<Long>>>> newConsumers = + HashMultimap.create(); + List<ThrowingRunnable> startFunctions = new ArrayList<>(); + List<ThrowingRunnable> finishFunctions = new ArrayList<>(); + + BeamFnApi.FunctionSpec functionSpec = BeamFnApi.FunctionSpec.newBuilder() + .setId(1L) + .setUrn(JAVA_SOURCE_URN) + .setData(Any.pack(BytesValue.newBuilder() + .setValue(ByteString.copyFrom( + SerializableUtils.serializeToByteArray(CountingSource.upTo(3)))) + .build())) + .build(); + + BeamFnApi.PrimitiveTransform primitiveTransform = BeamFnApi.PrimitiveTransform.newBuilder() + .setId(primitiveTransformId) + .setFunctionSpec(functionSpec) + .putInputs("input", + BeamFnApi.Target.List.newBuilder().addTarget(inputTarget).build()) + .putOutputs(Long.toString(outputId), + BeamFnApi.PCollection.newBuilder().setCoderReference(LONG_CODER_SPEC_ID).build()) + .build(); + + ProcessBundleHandler handler = new ProcessBundleHandler( + 
PipelineOptionsFactory.create(), + fnApiRegistry::get, + beamFnDataClient); + + handler.createConsumersForPrimitiveTransform( + primitiveTransform, + Suppliers.ofInstance(57L)::get, + existingConsumers::get, + newConsumers::put, + startFunctions::add, + finishFunctions::add); + + // This is testing a deprecated way of running sources and should be removed + // once all source definitions are instead propagated along the input edge. + Iterables.getOnlyElement(startFunctions).run(); + assertThat(outputValues, contains( + valueInGlobalWindow(0L), + valueInGlobalWindow(1L), + valueInGlobalWindow(2L))); + outputValues.clear(); + + // Check that when passing a source along as an input, the source is processed. + assertEquals(newConsumers.keySet(), ImmutableSet.of(inputTarget)); + Iterables.getOnlyElement(newConsumers.get(inputTarget)).accept( + valueInGlobalWindow(CountingSource.upTo(2))); + assertThat(outputValues, contains( + valueInGlobalWindow(0L), + valueInGlobalWindow(1L))); + + assertThat(finishFunctions, empty()); + } + + @Test + public void testCreatingAndProcessingBeamFnDataReadRunner() throws Exception { + Map<Long, Message> fnApiRegistry = ImmutableMap.of(STRING_CODER_SPEC_ID, STRING_CODER_SPEC); + long bundleId = 57L; + long primitiveTransformId = 100L; + long outputId = 101L; + + List<WindowedValue<String>> outputValues = new ArrayList<>(); + BeamFnApi.Target outputTarget = BeamFnApi.Target.newBuilder() + .setPrimitiveTransformReference(primitiveTransformId) + .setName(Long.toString(outputId)) + .build(); + + Multimap<BeamFnApi.Target, ThrowingConsumer<WindowedValue<String>>> existingConsumers = + ImmutableMultimap.of(outputTarget, outputValues::add); + Multimap<BeamFnApi.Target, ThrowingConsumer<WindowedValue<String>>> newConsumers = + HashMultimap.create(); + List<ThrowingRunnable> startFunctions = new ArrayList<>(); + List<ThrowingRunnable> finishFunctions = new ArrayList<>(); + + BeamFnApi.FunctionSpec functionSpec = BeamFnApi.FunctionSpec.newBuilder() 
+ .setId(1L) + .setUrn(DATA_INPUT_URN) + .setData(Any.pack(REMOTE_PORT)) + .build(); + + BeamFnApi.PrimitiveTransform primitiveTransform = BeamFnApi.PrimitiveTransform.newBuilder() + .setId(primitiveTransformId) + .setFunctionSpec(functionSpec) + .putInputs("input", BeamFnApi.Target.List.getDefaultInstance()) + .putOutputs(Long.toString(outputId), + BeamFnApi.PCollection.newBuilder().setCoderReference(STRING_CODER_SPEC_ID).build()) + .build(); + + ProcessBundleHandler handler = new ProcessBundleHandler( + PipelineOptionsFactory.create(), + fnApiRegistry::get, + beamFnDataClient); + + handler.createConsumersForPrimitiveTransform( + primitiveTransform, + Suppliers.ofInstance(bundleId)::get, + existingConsumers::get, + newConsumers::put, + startFunctions::add, + finishFunctions::add); + + verifyZeroInteractions(beamFnDataClient); + + CompletableFuture<Void> completionFuture = new CompletableFuture<>(); + when(beamFnDataClient.forInboundConsumer(any(), any(), any(), any())) + .thenReturn(completionFuture); + Iterables.getOnlyElement(startFunctions).run(); + verify(beamFnDataClient).forInboundConsumer( + eq(REMOTE_PORT.getApiServiceDescriptor()), + eq(KV.of(bundleId, BeamFnApi.Target.newBuilder() + .setPrimitiveTransformReference(primitiveTransformId) + .setName("input") + .build())), + eq(STRING_CODER), + consumerCaptor.capture()); + + consumerCaptor.getValue().accept(valueInGlobalWindow("TestValue")); + assertThat(outputValues, contains(valueInGlobalWindow("TestValue"))); + outputValues.clear(); + + assertThat(newConsumers.keySet(), empty()); + + completionFuture.complete(null); + Iterables.getOnlyElement(finishFunctions).run(); + + verifyNoMoreInteractions(beamFnDataClient); + } + + @Test + public void testCreatingAndProcessingBeamFnDataWriteRunner() throws Exception { + Map<Long, Message> fnApiRegistry = ImmutableMap.of(STRING_CODER_SPEC_ID, STRING_CODER_SPEC); + long bundleId = 57L; + long primitiveTransformId = 100L; + long outputId = 101L; + + BeamFnApi.Target 
inputTarget = BeamFnApi.Target.newBuilder() + .setPrimitiveTransformReference(1000L) + .setName("inputTarget") + .build(); + + Multimap<BeamFnApi.Target, ThrowingConsumer<WindowedValue<String>>> existingConsumers = + ImmutableMultimap.of(); + Multimap<BeamFnApi.Target, ThrowingConsumer<WindowedValue<String>>> newConsumers = + HashMultimap.create(); + List<ThrowingRunnable> startFunctions = new ArrayList<>(); + List<ThrowingRunnable> finishFunctions = new ArrayList<>(); + + BeamFnApi.FunctionSpec functionSpec = BeamFnApi.FunctionSpec.newBuilder() + .setId(1L) + .setUrn(DATA_OUTPUT_URN) + .setData(Any.pack(REMOTE_PORT)) + .build(); + + BeamFnApi.PrimitiveTransform primitiveTransform = BeamFnApi.PrimitiveTransform.newBuilder() + .setId(primitiveTransformId) + .setFunctionSpec(functionSpec) + .putInputs("input", BeamFnApi.Target.List.newBuilder().addTarget(inputTarget).build()) + .putOutputs(Long.toString(outputId), + BeamFnApi.PCollection.newBuilder().setCoderReference(STRING_CODER_SPEC_ID).build()) + .build(); + + ProcessBundleHandler handler = new ProcessBundleHandler( + PipelineOptionsFactory.create(), + fnApiRegistry::get, + beamFnDataClient); + + handler.createConsumersForPrimitiveTransform( + primitiveTransform, + Suppliers.ofInstance(bundleId)::get, + existingConsumers::get, + newConsumers::put, + startFunctions::add, + finishFunctions::add); + + verifyZeroInteractions(beamFnDataClient); + + List<WindowedValue<String>> outputValues = new ArrayList<>(); + AtomicBoolean wasCloseCalled = new AtomicBoolean(); + CloseableThrowingConsumer<WindowedValue<String>> outputConsumer = + new CloseableThrowingConsumer<WindowedValue<String>>(){ + @Override + public void close() throws Exception { + wasCloseCalled.set(true); + } + + @Override + public void accept(WindowedValue<String> t) throws Exception { + outputValues.add(t); + } + }; + + when(beamFnDataClient.forOutboundConsumer( + any(), + any(), + Matchers.<Coder<WindowedValue<String>>>any())).thenReturn(outputConsumer); 
+ Iterables.getOnlyElement(startFunctions).run(); + verify(beamFnDataClient).forOutboundConsumer( + eq(REMOTE_PORT.getApiServiceDescriptor()), + eq(KV.of(bundleId, BeamFnApi.Target.newBuilder() + .setPrimitiveTransformReference(primitiveTransformId) + .setName(Long.toString(outputId)) + .build())), + eq(STRING_CODER)); + + assertEquals(newConsumers.keySet(), ImmutableSet.of(inputTarget)); + Iterables.getOnlyElement(newConsumers.get(inputTarget)).accept( + valueInGlobalWindow("TestValue")); + assertThat(outputValues, contains(valueInGlobalWindow("TestValue"))); + outputValues.clear(); + + assertFalse(wasCloseCalled.get()); + Iterables.getOnlyElement(finishFunctions).run(); + assertTrue(wasCloseCalled.get()); + + verifyNoMoreInteractions(beamFnDataClient); + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/0b4b2bec/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/RegisterHandlerTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/RegisterHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/RegisterHandlerTest.java new file mode 100644 index 0000000..7b07a08 --- /dev/null +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/RegisterHandlerTest.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.fn.harness.control; + +import static org.junit.Assert.assertEquals; + +import java.util.concurrent.Callable; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import org.apache.beam.fn.harness.test.TestExecutors; +import org.apache.beam.fn.harness.test.TestExecutors.TestExecutorService; +import org.apache.beam.fn.v1.BeamFnApi; +import org.apache.beam.fn.v1.BeamFnApi.RegisterResponse; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link RegisterHandler}. 
*/ +@RunWith(JUnit4.class) +public class RegisterHandlerTest { + @Rule public TestExecutorService executor = TestExecutors.from(Executors::newCachedThreadPool); + + private static final BeamFnApi.InstructionRequest REGISTER_REQUEST = + BeamFnApi.InstructionRequest.newBuilder() + .setInstructionId(1L) + .setRegister(BeamFnApi.RegisterRequest.newBuilder() + .addProcessBundleDescriptor(BeamFnApi.ProcessBundleDescriptor.newBuilder().setId(1L) + .addCoders(BeamFnApi.Coder.newBuilder().setFunctionSpec( + BeamFnApi.FunctionSpec.newBuilder().setId(10L)).build())) + .addProcessBundleDescriptor(BeamFnApi.ProcessBundleDescriptor.newBuilder().setId(2L) + .addCoders(BeamFnApi.Coder.newBuilder().setFunctionSpec( + BeamFnApi.FunctionSpec.newBuilder().setId(20L)).build())) + .build()) + .build(); + private static final BeamFnApi.InstructionResponse REGISTER_RESPONSE = + BeamFnApi.InstructionResponse.newBuilder() + .setRegister(RegisterResponse.getDefaultInstance()) + .build(); + + @Test + public void testRegistration() throws Exception { + RegisterHandler handler = new RegisterHandler(); + Future<BeamFnApi.InstructionResponse> responseFuture = + executor.submit(new Callable<BeamFnApi.InstructionResponse>() { + @Override + public BeamFnApi.InstructionResponse call() throws Exception { + // Purposefully wait a small amount of time making it likely that + // a downstream caller needs to block. 
+ Thread.sleep(100); + return handler.register(REGISTER_REQUEST).build(); + } + }); + assertEquals(REGISTER_REQUEST.getRegister().getProcessBundleDescriptor(0), + handler.getById(1L)); + assertEquals(REGISTER_REQUEST.getRegister().getProcessBundleDescriptor(1), + handler.getById(2L)); + assertEquals(REGISTER_REQUEST.getRegister().getProcessBundleDescriptor(0).getCoders(0), + handler.getById(10L)); + assertEquals(REGISTER_REQUEST.getRegister().getProcessBundleDescriptor(1).getCoders(0), + handler.getById(20L)); + assertEquals(REGISTER_RESPONSE, responseFuture.get()); + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/0b4b2bec/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataBufferingOutboundObserverTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataBufferingOutboundObserverTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataBufferingOutboundObserverTest.java new file mode 100644 index 0000000..64a0e11 --- /dev/null +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataBufferingOutboundObserverTest.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.fn.harness.data; + +import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow; +import static org.hamcrest.Matchers.empty; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +import com.google.common.collect.Iterables; +import com.google.protobuf.ByteString; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.concurrent.atomic.AtomicBoolean; +import org.apache.beam.fn.harness.fn.CloseableThrowingConsumer; +import org.apache.beam.fn.harness.test.TestStreams; +import org.apache.beam.fn.v1.BeamFnApi; +import org.apache.beam.sdk.coders.ByteArrayCoder; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.Coder.Context; +import org.apache.beam.sdk.coders.LengthPrefixCoder; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.KV; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link BeamFnDataBufferingOutboundObserver}. 
*/ +@RunWith(JUnit4.class) +public class BeamFnDataBufferingOutboundObserverTest { + /** Buffer limit this test assumes the observer uses when no experiment overrides it. NOTE(review): confirm it matches BeamFnDataBufferingOutboundObserver's default. */ + private static final int DEFAULT_BUFFER_LIMIT = 1_000_000; + /** Instruction id (key) and target (value) that every emitted Data entry is expected to carry. */ + private static final KV<Long, BeamFnApi.Target> OUTPUT_LOCATION = KV.of(777L, + BeamFnApi.Target.newBuilder().setPrimitiveTransformReference(555L).setName("Test").build()); + /** Length-prefixed, value-only windowed byte[] coder; also used by messageWithData to build expected payloads. */ + private static final Coder<WindowedValue<byte[]>> CODER = + LengthPrefixCoder.of(WindowedValue.getValueOnlyCoder(ByteArrayCoder.of())); + + /** With default options, data is flushed once the buffered elements cross DEFAULT_BUFFER_LIMIT, buffering starts over after each flush, and close() on an empty buffer emits a separate end-of-stream message. */ + @Test + public void testWithDefaultBuffer() throws Exception { + Collection<BeamFnApi.Elements> values = new ArrayList<>(); + AtomicBoolean onCompletedWasCalled = new AtomicBoolean(); + /* NOTE(review): onCompletedWasCalled is recorded but never asserted in this test. */ + CloseableThrowingConsumer<WindowedValue<byte[]>> consumer = + new BeamFnDataBufferingOutboundObserver<>( + PipelineOptionsFactory.create(), + OUTPUT_LOCATION, + CODER, + TestStreams.withOnNext(values::add) + .withOnCompleted(() -> onCompletedWasCalled.set(true)) + .build()); + + // Test that nothing is emitted till the default buffer size is surpassed. + consumer.accept(valueInGlobalWindow(new byte[DEFAULT_BUFFER_LIMIT - 50])); + assertThat(values, empty()); + + // Test that when we cross the buffer, we emit. 
+ consumer.accept(valueInGlobalWindow(new byte[50])); + assertEquals( + messageWithData(new byte[DEFAULT_BUFFER_LIMIT - 50], new byte[50]), + Iterables.get(values, 0)); + + // Test that nothing is emitted till the default buffer size is surpassed after a reset + consumer.accept(valueInGlobalWindow(new byte[DEFAULT_BUFFER_LIMIT - 50])); + assertEquals(1, values.size()); + + // Test that when we cross the buffer, we emit.
+ consumer.accept(valueInGlobalWindow(new byte[50])); + assertEquals( + messageWithData(new byte[DEFAULT_BUFFER_LIMIT - 50], new byte[50]), + Iterables.get(values, 1)); + + // Test that closing with an empty buffer emits a single end-of-stream message + consumer.close(); + assertEquals(messageWithData(), + Iterables.get(values, 2)); + } + + /** The beam_fn_api_data_buffer_limit experiment overrides the buffer limit, and close() flushes any buffered data together with the end-of-stream marker in a single message. */ + @Test + public void testExperimentConfiguresBufferLimit() throws Exception { + Collection<BeamFnApi.Elements> values = new ArrayList<>(); + AtomicBoolean onCompletedWasCalled = new AtomicBoolean(); + /* NOTE(review): onCompletedWasCalled is recorded but never asserted in this test. */ + CloseableThrowingConsumer<WindowedValue<byte[]>> consumer = + new BeamFnDataBufferingOutboundObserver<>( + PipelineOptionsFactory.fromArgs( + new String[] { "--experiments=beam_fn_api_data_buffer_limit=100" }).create(), + OUTPUT_LOCATION, + CODER, + TestStreams.withOnNext(values::add) + .withOnCompleted(() -> onCompletedWasCalled.set(true)) + .build()); + + // Test that nothing is emitted till the configured buffer limit (100) is surpassed. + consumer.accept(valueInGlobalWindow(new byte[51])); + assertThat(values, empty()); + + // Test that when we cross the buffer, we emit. 
+ consumer.accept(valueInGlobalWindow(new byte[49])); + assertEquals( + messageWithData(new byte[51], new byte[49]), + Iterables.get(values, 0)); + + // Test that close() flushes the buffered value and appends the stream terminator as part + // of the same message + consumer.accept(valueInGlobalWindow(new byte[1])); + consumer.close(); + assertEquals( + BeamFnApi.Elements.newBuilder(messageWithData(new byte[1])) + .addData(BeamFnApi.Elements.Data.newBuilder() + .setInstructionReference(OUTPUT_LOCATION.getKey()) + .setTarget(OUTPUT_LOCATION.getValue())) + .build(), + Iterables.get(values, 1)); + } + + /** Builds the Elements message expected for OUTPUT_LOCATION: a single Data entry whose payload is each datum encoded with CODER in the NESTED context, concatenated in order. With no arguments the payload is empty, which is what testWithDefaultBuffer expects close() to emit. */ + private static BeamFnApi.Elements messageWithData(byte[] ...
datum) throws IOException { + ByteString.Output output = ByteString.newOutput(); + for (byte[] data : datum) { + CODER.encode(valueInGlobalWindow(data), output, Context.NESTED); + } + return BeamFnApi.Elements.newBuilder() + .addData(BeamFnApi.Elements.Data.newBuilder() + .setInstructionReference(OUTPUT_LOCATION.getKey()) + .setTarget(OUTPUT_LOCATION.getValue()) + .setData(output.toByteString())) + .build(); + } +}
