http://git-wip-us.apache.org/repos/asf/beam/blob/f1b4700f/sdks/java/harness/src/main/java/org/apache/beam/runners/core/FnApiDoFnRunner.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/FnApiDoFnRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/FnApiDoFnRunner.java deleted file mode 100644 index b3cf3a7..0000000 --- a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/FnApiDoFnRunner.java +++ /dev/null @@ -1,547 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.core; - -import static com.google.common.base.Preconditions.checkArgument; - -import com.google.auto.service.AutoService; -import com.google.common.collect.Collections2; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableMultimap; -import com.google.common.collect.Multimap; -import com.google.protobuf.ByteString; -import com.google.protobuf.BytesValue; -import com.google.protobuf.InvalidProtocolBufferException; -import java.util.Collection; -import java.util.HashSet; -import java.util.Iterator; -import java.util.Map; -import java.util.Objects; -import java.util.function.Consumer; -import java.util.function.Supplier; -import org.apache.beam.fn.harness.data.BeamFnDataClient; -import org.apache.beam.fn.harness.fn.ThrowingConsumer; -import org.apache.beam.fn.harness.fn.ThrowingRunnable; -import org.apache.beam.runners.core.construction.ParDoTranslation; -import org.apache.beam.runners.dataflow.util.DoFnInfo; -import org.apache.beam.sdk.common.runner.v1.RunnerApi; -import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.state.State; -import org.apache.beam.sdk.state.TimeDomain; -import org.apache.beam.sdk.state.Timer; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.DoFn.OnTimerContext; -import org.apache.beam.sdk.transforms.DoFn.ProcessContext; -import org.apache.beam.sdk.transforms.reflect.DoFnInvoker; -import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; -import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; -import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.transforms.windowing.PaneInfo; -import org.apache.beam.sdk.util.SerializableUtils; -import org.apache.beam.sdk.util.UserCodeException; -import org.apache.beam.sdk.util.WindowedValue; -import org.apache.beam.sdk.values.PCollectionView; -import org.apache.beam.sdk.values.TupleTag; -import org.apache.beam.sdk.values.WindowingStrategy; -import org.joda.time.Instant; - -/** - * A {@link DoFnRunner} specific to integrating with the Fn Api. This is to remove the layers - * of abstraction caused by StateInternals/TimerInternals since they model state and timer - * concepts differently. - */ -public class FnApiDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, OutputT> { - /** - * A registrar which provides a factory to handle Java {@link DoFn}s. - */ - @AutoService(PTransformRunnerFactory.Registrar.class) - public static class Registrar implements - PTransformRunnerFactory.Registrar { - - @Override - public Map<String, PTransformRunnerFactory> getPTransformRunnerFactories() { - return ImmutableMap.of(ParDoTranslation.CUSTOM_JAVA_DO_FN_URN, new Factory()); - } - } - - /** A factory for {@link FnApiDoFnRunner}. */ - static class Factory<InputT, OutputT> - implements PTransformRunnerFactory<DoFnRunner<InputT, OutputT>> { - - @Override - public DoFnRunner<InputT, OutputT> createRunnerForPTransform( - PipelineOptions pipelineOptions, - BeamFnDataClient beamFnDataClient, - String pTransformId, - RunnerApi.PTransform pTransform, - Supplier<String> processBundleInstructionId, - Map<String, RunnerApi.PCollection> pCollections, - Map<String, RunnerApi.Coder> coders, - Multimap<String, ThrowingConsumer<WindowedValue<?>>> pCollectionIdsToConsumers, - Consumer<ThrowingRunnable> addStartFunction, - Consumer<ThrowingRunnable> addFinishFunction) { - - // For every output PCollection, create a map from output name to Consumer - ImmutableMap.Builder<String, Collection<ThrowingConsumer<WindowedValue<?>>>> - outputMapBuilder = ImmutableMap.builder(); - for (Map.Entry<String, String> entry : pTransform.getOutputsMap().entrySet()) { - outputMapBuilder.put( - entry.getKey(), - pCollectionIdsToConsumers.get(entry.getValue())); - } - ImmutableMap<String, Collection<ThrowingConsumer<WindowedValue<?>>>> outputMap = - outputMapBuilder.build(); - - // Get the DoFnInfo from the serialized blob. - ByteString serializedFn; - try { - serializedFn = pTransform.getSpec().getParameter().unpack(BytesValue.class).getValue(); - } catch (InvalidProtocolBufferException e) { - throw new IllegalArgumentException( - String.format("Unable to unwrap DoFn %s", pTransform.getSpec()), e); - } - @SuppressWarnings({"unchecked", "rawtypes"}) - DoFnInfo<InputT, OutputT> doFnInfo = (DoFnInfo) SerializableUtils.deserializeFromByteArray( - serializedFn.toByteArray(), "DoFnInfo"); - - // Verify that the DoFnInfo tag to output map matches the output map on the PTransform. - checkArgument( - Objects.equals( - new HashSet<>(Collections2.transform(outputMap.keySet(), Long::parseLong)), - doFnInfo.getOutputMap().keySet()), - "Unexpected mismatch between transform output map %s and DoFnInfo output map %s.", - outputMap.keySet(), - doFnInfo.getOutputMap()); - - ImmutableMultimap.Builder<TupleTag<?>, - ThrowingConsumer<WindowedValue<?>>> tagToOutputMapBuilder = - ImmutableMultimap.builder(); - for (Map.Entry<Long, TupleTag<?>> entry : doFnInfo.getOutputMap().entrySet()) { - @SuppressWarnings({"unchecked", "rawtypes"}) - Collection<ThrowingConsumer<WindowedValue<?>>> consumers = - outputMap.get(Long.toString(entry.getKey())); - tagToOutputMapBuilder.putAll(entry.getValue(), consumers); - } - - ImmutableMultimap<TupleTag<?>, ThrowingConsumer<WindowedValue<?>>> tagToOutputMap = - tagToOutputMapBuilder.build(); - - @SuppressWarnings({"unchecked", "rawtypes"}) - DoFnRunner<InputT, OutputT> runner = new FnApiDoFnRunner<>( - pipelineOptions, - doFnInfo.getDoFn(), - (Collection<ThrowingConsumer<WindowedValue<OutputT>>>) (Collection) - tagToOutputMap.get(doFnInfo.getOutputMap().get(doFnInfo.getMainOutput())), - tagToOutputMap, - doFnInfo.getWindowingStrategy()); - - // Register the appropriate handlers. - addStartFunction.accept(runner::startBundle); - for (String pcollectionId : pTransform.getInputsMap().values()) { - pCollectionIdsToConsumers.put( - pcollectionId, - (ThrowingConsumer) (ThrowingConsumer<WindowedValue<InputT>>) runner::processElement); - } - addFinishFunction.accept(runner::finishBundle); - return runner; - } - } - - ////////////////////////////////////////////////////////////////////////////////////////////////// - - private final PipelineOptions pipelineOptions; - private final DoFn<InputT, OutputT> doFn; - private final Collection<ThrowingConsumer<WindowedValue<OutputT>>> mainOutputConsumers; - private final Multimap<TupleTag<?>, ThrowingConsumer<WindowedValue<?>>> outputMap; - private final DoFnInvoker<InputT, OutputT> doFnInvoker; - private final StartBundleContext startBundleContext; - private final ProcessBundleContext processBundleContext; - private final FinishBundleContext finishBundleContext; - - /** - * The lifetime of this member is only valid during {@link #processElement(WindowedValue)}. - */ - private WindowedValue<InputT> currentElement; - - /** - * The lifetime of this member is only valid during {@link #processElement(WindowedValue)}. - */ - private BoundedWindow currentWindow; - - FnApiDoFnRunner( - PipelineOptions pipelineOptions, - DoFn<InputT, OutputT> doFn, - Collection<ThrowingConsumer<WindowedValue<OutputT>>> mainOutputConsumers, - Multimap<TupleTag<?>, ThrowingConsumer<WindowedValue<?>>> outputMap, - WindowingStrategy windowingStrategy) { - this.pipelineOptions = pipelineOptions; - this.doFn = doFn; - this.mainOutputConsumers = mainOutputConsumers; - this.outputMap = outputMap; - this.doFnInvoker = DoFnInvokers.invokerFor(doFn); - this.startBundleContext = new StartBundleContext(); - this.processBundleContext = new ProcessBundleContext(); - this.finishBundleContext = new FinishBundleContext(); - } - - @Override - public void startBundle() { - doFnInvoker.invokeStartBundle(startBundleContext); - } - - @Override - public void processElement(WindowedValue<InputT> elem) { - currentElement = elem; - try { - Iterator<BoundedWindow> windowIterator = - (Iterator<BoundedWindow>) elem.getWindows().iterator(); - while (windowIterator.hasNext()) { - currentWindow = windowIterator.next(); - doFnInvoker.invokeProcessElement(processBundleContext); - } - } finally { - currentElement = null; - currentWindow = null; - } - } - - @Override - public void onTimer( - String timerId, - BoundedWindow window, - Instant timestamp, - TimeDomain timeDomain) { - throw new UnsupportedOperationException("TODO: Add support for timers"); - } - - @Override - public void finishBundle() { - doFnInvoker.invokeFinishBundle(finishBundleContext); - } - - /** - * Outputs the given element to the specified set of consumers wrapping any exceptions. - */ - private <T> void outputTo( - Collection<ThrowingConsumer<WindowedValue<T>>> consumers, - WindowedValue<T> output) { - Iterator<ThrowingConsumer<WindowedValue<T>>> consumerIterator; - try { - for (ThrowingConsumer<WindowedValue<T>> consumer : consumers) { - consumer.accept(output); - } - } catch (Throwable t) { - throw UserCodeException.wrap(t); - } - } - - /** - * Provides arguments for a {@link DoFnInvoker} for {@link DoFn.StartBundle @StartBundle}. - */ - private class StartBundleContext - extends DoFn<InputT, OutputT>.StartBundleContext - implements DoFnInvoker.ArgumentProvider<InputT, OutputT> { - - private StartBundleContext() { - doFn.super(); - } - - @Override - public PipelineOptions getPipelineOptions() { - return pipelineOptions; - } - - @Override - public PipelineOptions pipelineOptions() { - return pipelineOptions; - } - - @Override - public BoundedWindow window() { - throw new UnsupportedOperationException( - "Cannot access window outside of @ProcessElement and @OnTimer methods."); - } - - @Override - public DoFn<InputT, OutputT>.StartBundleContext startBundleContext( - DoFn<InputT, OutputT> doFn) { - return this; - } - - @Override - public DoFn<InputT, OutputT>.FinishBundleContext finishBundleContext( - DoFn<InputT, OutputT> doFn) { - throw new UnsupportedOperationException( - "Cannot access FinishBundleContext outside of @FinishBundle method."); - } - - @Override - public DoFn<InputT, OutputT>.ProcessContext processContext(DoFn<InputT, OutputT> doFn) { - throw new UnsupportedOperationException( - "Cannot access ProcessContext outside of @ProcessElement method."); - } - - @Override - public DoFn<InputT, OutputT>.OnTimerContext onTimerContext(DoFn<InputT, OutputT> doFn) { - throw new UnsupportedOperationException( - "Cannot access OnTimerContext outside of @OnTimer methods."); - } - - @Override - public RestrictionTracker<?> restrictionTracker() { - throw new UnsupportedOperationException( - "Cannot access RestrictionTracker outside of @ProcessElement method."); - } - - @Override - public State state(String stateId) { - throw new UnsupportedOperationException( - "Cannot access state outside of @ProcessElement and @OnTimer methods."); - } - - @Override - public Timer timer(String timerId) { - throw new UnsupportedOperationException( - "Cannot access timers outside of @ProcessElement and @OnTimer methods."); - } - } - - /** - * Provides arguments for a {@link DoFnInvoker} for {@link DoFn.ProcessElement @ProcessElement}. - */ - private class ProcessBundleContext - extends DoFn<InputT, OutputT>.ProcessContext - implements DoFnInvoker.ArgumentProvider<InputT, OutputT> { - - private ProcessBundleContext() { - doFn.super(); - } - - @Override - public BoundedWindow window() { - return currentWindow; - } - - @Override - public DoFn.StartBundleContext startBundleContext(DoFn<InputT, OutputT> doFn) { - throw new UnsupportedOperationException( - "Cannot access StartBundleContext outside of @StartBundle method."); - } - - @Override - public DoFn.FinishBundleContext finishBundleContext(DoFn<InputT, OutputT> doFn) { - throw new UnsupportedOperationException( - "Cannot access FinishBundleContext outside of @FinishBundle method."); - } - - @Override - public ProcessContext processContext(DoFn<InputT, OutputT> doFn) { - return this; - } - - @Override - public OnTimerContext onTimerContext(DoFn<InputT, OutputT> doFn) { - throw new UnsupportedOperationException("TODO: Add support for timers"); - } - - @Override - public RestrictionTracker<?> restrictionTracker() { - throw new UnsupportedOperationException("TODO: Add support for SplittableDoFn"); - } - - @Override - public State state(String stateId) { - throw new UnsupportedOperationException("TODO: Add support for state"); - } - - @Override - public Timer timer(String timerId) { - throw new UnsupportedOperationException("TODO: Add support for timers"); - } - - @Override - public PipelineOptions getPipelineOptions() { - return pipelineOptions; - } - - @Override - public PipelineOptions pipelineOptions() { - return pipelineOptions; - } - - @Override - public void output(OutputT output) { - outputTo(mainOutputConsumers, - WindowedValue.of( - output, - currentElement.getTimestamp(), - currentWindow, - currentElement.getPane())); - } - - @Override - public void outputWithTimestamp(OutputT output, Instant timestamp) { - outputTo(mainOutputConsumers, - WindowedValue.of( - output, - timestamp, - currentWindow, - currentElement.getPane())); - } - - @Override - public <T> void output(TupleTag<T> tag, T output) { - Collection<ThrowingConsumer<WindowedValue<T>>> consumers = (Collection) outputMap.get(tag); - if (consumers == null) { - throw new IllegalArgumentException(String.format("Unknown output tag %s", tag)); - } - outputTo(consumers, - WindowedValue.of( - output, - currentElement.getTimestamp(), - currentWindow, - currentElement.getPane())); - } - - @Override - public <T> void outputWithTimestamp(TupleTag<T> tag, T output, Instant timestamp) { - Collection<ThrowingConsumer<WindowedValue<T>>> consumers = (Collection) outputMap.get(tag); - if (consumers == null) { - throw new IllegalArgumentException(String.format("Unknown output tag %s", tag)); - } - outputTo(consumers, - WindowedValue.of( - output, - timestamp, - currentWindow, - currentElement.getPane())); - } - - @Override - public InputT element() { - return currentElement.getValue(); - } - - @Override - public <T> T sideInput(PCollectionView<T> view) { - throw new UnsupportedOperationException("TODO: Support side inputs"); - } - - @Override - public Instant timestamp() { - return currentElement.getTimestamp(); - } - - @Override - public PaneInfo pane() { - return currentElement.getPane(); - } - - @Override - public void updateWatermark(Instant watermark) { - throw new UnsupportedOperationException("TODO: Add support for SplittableDoFn"); - } - } - - /** - * Provides arguments for a {@link DoFnInvoker} for {@link DoFn.FinishBundle @FinishBundle}. - */ - private class FinishBundleContext - extends DoFn<InputT, OutputT>.FinishBundleContext - implements DoFnInvoker.ArgumentProvider<InputT, OutputT> { - - private FinishBundleContext() { - doFn.super(); - } - - @Override - public PipelineOptions getPipelineOptions() { - return pipelineOptions; - } - - @Override - public PipelineOptions pipelineOptions() { - return pipelineOptions; - } - - @Override - public BoundedWindow window() { - throw new UnsupportedOperationException( - "Cannot access window outside of @ProcessElement and @OnTimer methods."); - } - - @Override - public DoFn<InputT, OutputT>.StartBundleContext startBundleContext( - DoFn<InputT, OutputT> doFn) { - throw new UnsupportedOperationException( - "Cannot access StartBundleContext outside of @StartBundle method."); - } - - @Override - public DoFn<InputT, OutputT>.FinishBundleContext finishBundleContext( - DoFn<InputT, OutputT> doFn) { - return this; - } - - @Override - public DoFn<InputT, OutputT>.ProcessContext processContext(DoFn<InputT, OutputT> doFn) { - throw new UnsupportedOperationException( - "Cannot access ProcessContext outside of @ProcessElement method."); - } - - @Override - public DoFn<InputT, OutputT>.OnTimerContext onTimerContext(DoFn<InputT, OutputT> doFn) { - throw new UnsupportedOperationException( - "Cannot access OnTimerContext outside of @OnTimer methods."); - } - - @Override - public RestrictionTracker<?> restrictionTracker() { - throw new UnsupportedOperationException( - "Cannot access RestrictionTracker outside of @ProcessElement method."); - } - - @Override - public State state(String stateId) { - throw new UnsupportedOperationException( - "Cannot access state outside of @ProcessElement and @OnTimer methods."); - } - - @Override - public Timer timer(String timerId) { - throw new UnsupportedOperationException( - "Cannot access timers outside of @ProcessElement and @OnTimer methods."); - } - - @Override - public void output(OutputT output, Instant timestamp, BoundedWindow window) { - outputTo(mainOutputConsumers, - WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING)); - } - - @Override - public <T> void output(TupleTag<T> tag, T output, Instant timestamp, BoundedWindow window) { - Collection<ThrowingConsumer<WindowedValue<T>>> consumers = (Collection) outputMap.get(tag); - if (consumers == null) { - throw new IllegalArgumentException(String.format("Unknown output tag %s", tag)); - } - outputTo(consumers, - WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING)); - } - } -}
http://git-wip-us.apache.org/repos/asf/beam/blob/f1b4700f/sdks/java/harness/src/main/java/org/apache/beam/runners/core/PTransformRunnerFactory.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/PTransformRunnerFactory.java b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/PTransformRunnerFactory.java deleted file mode 100644 index b325db4..0000000 --- a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/PTransformRunnerFactory.java +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.core; - -import com.google.common.collect.Multimap; -import java.io.IOException; -import java.util.Map; -import java.util.function.Consumer; -import java.util.function.Supplier; -import org.apache.beam.fn.harness.data.BeamFnDataClient; -import org.apache.beam.fn.harness.fn.ThrowingConsumer; -import org.apache.beam.fn.harness.fn.ThrowingRunnable; -import org.apache.beam.sdk.common.runner.v1.RunnerApi; -import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.util.WindowedValue; - -/** - * A factory able to instantiate an appropriate handler for a given PTransform. - */ -public interface PTransformRunnerFactory<T> { - - /** - * Creates and returns a handler for a given PTransform. Note that the handler must support - * processing multiple bundles. The handler will be discarded if an error is thrown during - * element processing, or during execution of start/finish. - * - * @param pipelineOptions Pipeline options - * @param beamFnDataClient - * @param pTransformId The id of the PTransform. - * @param pTransform The PTransform definition. - * @param processBundleInstructionId A supplier containing the active process bundle instruction - * id. - * @param pCollections A mapping from PCollection id to PCollection definition. - * @param coders A mapping from coder id to coder definition. - * @param pCollectionIdsToConsumers A mapping from PCollection id to a collection of consumers. - * Note that if this handler is a consumer, it should register itself within this multimap under - * the appropriate PCollection ids. Also note that all output consumers needed by this PTransform - * (based on the values of the {@link RunnerApi.PTransform#getOutputsMap()} will have already - * registered within this multimap. - * @param addStartFunction A consumer to register a start bundle handler with. - * @param addFinishFunction A consumer to register a finish bundle handler with. - */ - T createRunnerForPTransform( - PipelineOptions pipelineOptions, - BeamFnDataClient beamFnDataClient, - String pTransformId, - RunnerApi.PTransform pTransform, - Supplier<String> processBundleInstructionId, - Map<String, RunnerApi.PCollection> pCollections, - Map<String, RunnerApi.Coder> coders, - Multimap<String, ThrowingConsumer<WindowedValue<?>>> pCollectionIdsToConsumers, - Consumer<ThrowingRunnable> addStartFunction, - Consumer<ThrowingRunnable> addFinishFunction) throws IOException; - - /** - * A registrar which can return a mapping from {@link RunnerApi.FunctionSpec#getUrn()} to - * a factory capable of instantiating an appropriate handler. - */ - interface Registrar { - /** - * Returns a mapping from {@link RunnerApi.FunctionSpec#getUrn()} to a factory capable of - * instantiating an appropriate handler. - */ - Map<String, PTransformRunnerFactory> getPTransformRunnerFactories(); - } -} http://git-wip-us.apache.org/repos/asf/beam/blob/f1b4700f/sdks/java/harness/src/main/java/org/apache/beam/runners/core/package-info.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/package-info.java b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/package-info.java deleted file mode 100644 index d250a6a..0000000 --- a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/package-info.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * Provides utilities for Beam runner authors. - */ -package org.apache.beam.runners.core; http://git-wip-us.apache.org/repos/asf/beam/blob/f1b4700f/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java new file mode 100644 index 0000000..a7c6666 --- /dev/null +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java @@ -0,0 +1,281 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.fn.harness; + +import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.when; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.base.Suppliers; +import com.google.common.collect.HashMultimap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; +import com.google.common.collect.Multimap; +import com.google.common.util.concurrent.Uninterruptibles; +import com.google.protobuf.Any; +import com.google.protobuf.ByteString; +import com.google.protobuf.BytesValue; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.ServiceLoader; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import org.apache.beam.fn.harness.PTransformRunnerFactory.Registrar; +import org.apache.beam.fn.harness.data.BeamFnDataClient; +import org.apache.beam.fn.harness.fn.ThrowingConsumer; +import org.apache.beam.fn.harness.fn.ThrowingRunnable; +import org.apache.beam.fn.harness.test.TestExecutors; +import org.apache.beam.fn.harness.test.TestExecutors.TestExecutorService; +import org.apache.beam.fn.v1.BeamFnApi; +import org.apache.beam.runners.dataflow.util.CloudObjects; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.common.runner.v1.RunnerApi; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.KV; +import org.hamcrest.collection.IsMapContaining; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +/** Tests for {@link BeamFnDataReadRunner}. */ +@RunWith(JUnit4.class) +public class BeamFnDataReadRunnerTest { + + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + private static final BeamFnApi.RemoteGrpcPort PORT_SPEC = BeamFnApi.RemoteGrpcPort.newBuilder() + .setApiServiceDescriptor(BeamFnApi.ApiServiceDescriptor.getDefaultInstance()).build(); + private static final RunnerApi.FunctionSpec FUNCTION_SPEC = RunnerApi.FunctionSpec.newBuilder() + .setParameter(Any.pack(PORT_SPEC)).build(); + private static final Coder<WindowedValue<String>> CODER = + WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE); + private static final String CODER_SPEC_ID = "string-coder-id"; + private static final RunnerApi.Coder CODER_SPEC; + private static final String URN = "urn:org.apache.beam:source:runner:0.1"; + + static { + try { + CODER_SPEC = RunnerApi.Coder.newBuilder().setSpec( + RunnerApi.SdkFunctionSpec.newBuilder().setSpec( + RunnerApi.FunctionSpec.newBuilder().setParameter( + Any.pack(BytesValue.newBuilder().setValue(ByteString.copyFrom( + OBJECT_MAPPER.writeValueAsBytes(CloudObjects.asCloudObject(CODER)))) + .build())) + .build()) + .build()) + .build(); + } catch (IOException e) { + throw new ExceptionInInitializerError(e); + } + } + private static final BeamFnApi.Target INPUT_TARGET = BeamFnApi.Target.newBuilder() + .setPrimitiveTransformReference("1") + .setName("out") + .build(); + + @Rule public TestExecutorService executor = TestExecutors.from(Executors::newCachedThreadPool); + @Mock private BeamFnDataClient mockBeamFnDataClient; + @Captor private ArgumentCaptor<ThrowingConsumer<WindowedValue<String>>> consumerCaptor; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + } + + @Test + public void testCreatingAndProcessingBeamFnDataReadRunner() throws Exception { + String bundleId = "57"; + String outputId = "101"; + + List<WindowedValue<String>> outputValues = new ArrayList<>(); + + Multimap<String, ThrowingConsumer<WindowedValue<?>>> consumers = HashMultimap.create(); + consumers.put("outputPC", + (ThrowingConsumer) (ThrowingConsumer<WindowedValue<String>>) outputValues::add); + List<ThrowingRunnable> startFunctions = new ArrayList<>(); + List<ThrowingRunnable> finishFunctions = new ArrayList<>(); + + RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder() + .setUrn("urn:org.apache.beam:source:runner:0.1") + .setParameter(Any.pack(PORT_SPEC)) + .build(); + + RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder() + .setSpec(functionSpec) + .putOutputs(outputId, "outputPC") + .build(); + + new BeamFnDataReadRunner.Factory<String>().createRunnerForPTransform( + PipelineOptionsFactory.create(), + mockBeamFnDataClient, + "pTransformId", + pTransform, + Suppliers.ofInstance(bundleId)::get, + ImmutableMap.of("outputPC", + RunnerApi.PCollection.newBuilder().setCoderId(CODER_SPEC_ID).build()), + ImmutableMap.of(CODER_SPEC_ID, CODER_SPEC), + consumers, + startFunctions::add, + finishFunctions::add); + + verifyZeroInteractions(mockBeamFnDataClient); + + CompletableFuture<Void> completionFuture = new CompletableFuture<>(); + when(mockBeamFnDataClient.forInboundConsumer(any(), any(), any(), any())) + .thenReturn(completionFuture); + Iterables.getOnlyElement(startFunctions).run(); + verify(mockBeamFnDataClient).forInboundConsumer( + eq(PORT_SPEC.getApiServiceDescriptor()), + eq(KV.of(bundleId, BeamFnApi.Target.newBuilder() + .setPrimitiveTransformReference("pTransformId") + .setName(outputId) + .build())), + eq(CODER), + consumerCaptor.capture()); + + consumerCaptor.getValue().accept(valueInGlobalWindow("TestValue")); + assertThat(outputValues, contains(valueInGlobalWindow("TestValue"))); + outputValues.clear(); + + assertThat(consumers.keySet(), containsInAnyOrder("outputPC")); + + completionFuture.complete(null); + Iterables.getOnlyElement(finishFunctions).run(); + + verifyNoMoreInteractions(mockBeamFnDataClient); + } + + @Test + public void testReuseForMultipleBundles() throws Exception { + CompletableFuture<Void> bundle1Future = new CompletableFuture<>(); + CompletableFuture<Void> bundle2Future = new CompletableFuture<>(); + when(mockBeamFnDataClient.forInboundConsumer( + any(), + any(), + any(), + any())).thenReturn(bundle1Future).thenReturn(bundle2Future); + List<WindowedValue<String>> valuesA = new ArrayList<>(); + List<WindowedValue<String>> valuesB = new ArrayList<>(); + + AtomicReference<String> bundleId = new AtomicReference<>("0"); + BeamFnDataReadRunner<String> readRunner = new BeamFnDataReadRunner<>( + FUNCTION_SPEC, + bundleId::get, + INPUT_TARGET, + CODER_SPEC, + mockBeamFnDataClient, + ImmutableList.of(valuesA::add, valuesB::add)); + + // Process for bundle id 0 + readRunner.registerInputLocation(); + + verify(mockBeamFnDataClient).forInboundConsumer( + eq(PORT_SPEC.getApiServiceDescriptor()), + eq(KV.of(bundleId.get(), INPUT_TARGET)), + eq(CODER), + consumerCaptor.capture()); + + executor.submit(new Runnable() { + @Override + public void run() { + // Sleep for some small amount of time simulating the parent blocking + Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS); + try { + consumerCaptor.getValue().accept(valueInGlobalWindow("ABC")); + consumerCaptor.getValue().accept(valueInGlobalWindow("DEF")); + } catch (Exception e) { + bundle1Future.completeExceptionally(e); + } finally { + bundle1Future.complete(null); + } + } + }); + + readRunner.blockTillReadFinishes(); + assertThat(valuesA, contains(valueInGlobalWindow("ABC"), valueInGlobalWindow("DEF"))); + assertThat(valuesB, contains(valueInGlobalWindow("ABC"), valueInGlobalWindow("DEF"))); + + // Process for bundle id 1 + bundleId.set("1"); + valuesA.clear(); + valuesB.clear(); + readRunner.registerInputLocation(); + + verify(mockBeamFnDataClient).forInboundConsumer( + eq(PORT_SPEC.getApiServiceDescriptor()), + eq(KV.of(bundleId.get(), INPUT_TARGET)), + eq(CODER), + consumerCaptor.capture()); + + executor.submit(new Runnable() { + @Override + public void run() { + // Sleep for some small amount of time simulating the parent blocking + Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS); + try { + consumerCaptor.getValue().accept(valueInGlobalWindow("GHI")); + consumerCaptor.getValue().accept(valueInGlobalWindow("JKL")); + } catch (Exception e) { + bundle2Future.completeExceptionally(e); + } finally { + bundle2Future.complete(null); + } + } + }); + + readRunner.blockTillReadFinishes(); + assertThat(valuesA, contains(valueInGlobalWindow("GHI"), valueInGlobalWindow("JKL"))); + assertThat(valuesB, contains(valueInGlobalWindow("GHI"), valueInGlobalWindow("JKL"))); + + verifyNoMoreInteractions(mockBeamFnDataClient); + } + + @Test + public void testRegistration() { + for (Registrar registrar : + ServiceLoader.load(Registrar.class)) { + if (registrar instanceof BeamFnDataReadRunner.Registrar) { + assertThat(registrar.getPTransformRunnerFactories(), IsMapContaining.hasKey(URN)); + return; + } + } + fail("Expected registrar not found."); + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/f1b4700f/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java new file mode 100644 index 0000000..28838b1 --- /dev/null +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java @@ -0,0 +1,269 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.fn.harness; + +import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.when; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.base.Suppliers; +import com.google.common.collect.HashMultimap; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; +import com.google.common.collect.Multimap; +import com.google.protobuf.Any; +import com.google.protobuf.ByteString; +import com.google.protobuf.BytesValue; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.ServiceLoader; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import org.apache.beam.fn.harness.PTransformRunnerFactory.Registrar; +import org.apache.beam.fn.harness.data.BeamFnDataClient; +import org.apache.beam.fn.harness.fn.CloseableThrowingConsumer; +import org.apache.beam.fn.harness.fn.ThrowingConsumer; +import org.apache.beam.fn.harness.fn.ThrowingRunnable; +import org.apache.beam.fn.v1.BeamFnApi; +import org.apache.beam.runners.dataflow.util.CloudObjects; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.common.runner.v1.RunnerApi; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.KV; +import org.hamcrest.collection.IsMapContaining; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Matchers; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +/** Tests for {@link BeamFnDataWriteRunner}. */ +@RunWith(JUnit4.class) +public class BeamFnDataWriteRunnerTest { + + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + private static final BeamFnApi.RemoteGrpcPort PORT_SPEC = BeamFnApi.RemoteGrpcPort.newBuilder() + .setApiServiceDescriptor(BeamFnApi.ApiServiceDescriptor.getDefaultInstance()).build(); + private static final RunnerApi.FunctionSpec FUNCTION_SPEC = RunnerApi.FunctionSpec.newBuilder() + .setParameter(Any.pack(PORT_SPEC)).build(); + private static final String CODER_ID = "string-coder-id"; + private static final Coder<WindowedValue<String>> CODER = + WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE); + private static final RunnerApi.Coder CODER_SPEC; + private static final String URN = "urn:org.apache.beam:sink:runner:0.1"; + + static { + try { + CODER_SPEC = RunnerApi.Coder.newBuilder().setSpec( + RunnerApi.SdkFunctionSpec.newBuilder().setSpec( + RunnerApi.FunctionSpec.newBuilder().setParameter( + Any.pack(BytesValue.newBuilder().setValue(ByteString.copyFrom( + OBJECT_MAPPER.writeValueAsBytes(CloudObjects.asCloudObject(CODER)))) + .build())) + .build()) + .build()) + .build(); + } catch (IOException e) { + throw new ExceptionInInitializerError(e); + } + } + private static final BeamFnApi.Target OUTPUT_TARGET = BeamFnApi.Target.newBuilder() + .setPrimitiveTransformReference("1") + .setName("out") + .build(); + + @Mock private BeamFnDataClient mockBeamFnDataClient; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + } + + + @Test + public void testCreatingAndProcessingBeamFnDataWriteRunner() throws Exception { + String bundleId = "57L"; + String inputId = "100L"; + + Multimap<String, ThrowingConsumer<WindowedValue<?>>> consumers = HashMultimap.create(); + List<ThrowingRunnable> startFunctions = new ArrayList<>(); + List<ThrowingRunnable> finishFunctions = new ArrayList<>(); + + RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder() + .setUrn("urn:org.apache.beam:sink:runner:0.1") + .setParameter(Any.pack(PORT_SPEC)) + .build(); + + RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder() + .setSpec(functionSpec) + .putInputs(inputId, "inputPC") + .build(); + + new BeamFnDataWriteRunner.Factory<String>().createRunnerForPTransform( + PipelineOptionsFactory.create(), + mockBeamFnDataClient, + "ptransformId", + pTransform, + Suppliers.ofInstance(bundleId)::get, + ImmutableMap.of("inputPC", + RunnerApi.PCollection.newBuilder().setCoderId(CODER_ID).build()), + ImmutableMap.of(CODER_ID, CODER_SPEC), + consumers, + startFunctions::add, + finishFunctions::add); + + verifyZeroInteractions(mockBeamFnDataClient); + + List<WindowedValue<String>> outputValues = new ArrayList<>(); + AtomicBoolean wasCloseCalled = new AtomicBoolean(); + CloseableThrowingConsumer<WindowedValue<String>> outputConsumer = + new CloseableThrowingConsumer<WindowedValue<String>>(){ + @Override + public void close() throws Exception { + wasCloseCalled.set(true); + } + + @Override + public void accept(WindowedValue<String> t) throws Exception { + outputValues.add(t); + } + }; + + when(mockBeamFnDataClient.forOutboundConsumer( + any(), + any(), + Matchers.<Coder<WindowedValue<String>>>any())).thenReturn(outputConsumer); + Iterables.getOnlyElement(startFunctions).run(); + verify(mockBeamFnDataClient).forOutboundConsumer( + eq(PORT_SPEC.getApiServiceDescriptor()), + eq(KV.of(bundleId, BeamFnApi.Target.newBuilder() + .setPrimitiveTransformReference("ptransformId") + .setName(inputId) + .build())), + eq(CODER)); + + assertThat(consumers.keySet(), containsInAnyOrder("inputPC")); + Iterables.getOnlyElement(consumers.get("inputPC")).accept(valueInGlobalWindow("TestValue")); + assertThat(outputValues, contains(valueInGlobalWindow("TestValue"))); + outputValues.clear(); + + assertFalse(wasCloseCalled.get()); + Iterables.getOnlyElement(finishFunctions).run(); + assertTrue(wasCloseCalled.get()); + + verifyNoMoreInteractions(mockBeamFnDataClient); + } + + @Test + public void testReuseForMultipleBundles() throws Exception { + RecordingConsumer<WindowedValue<String>> valuesA = new RecordingConsumer<>(); + RecordingConsumer<WindowedValue<String>> valuesB = new RecordingConsumer<>(); + when(mockBeamFnDataClient.forOutboundConsumer( + any(), + any(), + Matchers.<Coder<WindowedValue<String>>>any())).thenReturn(valuesA).thenReturn(valuesB); + AtomicReference<String> bundleId = new AtomicReference<>("0"); + BeamFnDataWriteRunner<String> writeRunner = new BeamFnDataWriteRunner<>( + FUNCTION_SPEC, + bundleId::get, + OUTPUT_TARGET, + CODER_SPEC, + mockBeamFnDataClient); + + // Process for bundle id 0 + writeRunner.registerForOutput(); + + verify(mockBeamFnDataClient).forOutboundConsumer( + eq(PORT_SPEC.getApiServiceDescriptor()), + eq(KV.of(bundleId.get(), OUTPUT_TARGET)), + eq(CODER)); + + writeRunner.consume(valueInGlobalWindow("ABC")); + writeRunner.consume(valueInGlobalWindow("DEF")); + writeRunner.close(); + + assertTrue(valuesA.closed); + assertThat(valuesA, contains(valueInGlobalWindow("ABC"), valueInGlobalWindow("DEF"))); + + // Process for bundle id 1 + bundleId.set("1"); + valuesA.clear(); + valuesB.clear(); + writeRunner.registerForOutput(); + + verify(mockBeamFnDataClient).forOutboundConsumer( + eq(PORT_SPEC.getApiServiceDescriptor()), + eq(KV.of(bundleId.get(), OUTPUT_TARGET)), + eq(CODER)); + + writeRunner.consume(valueInGlobalWindow("GHI")); + writeRunner.consume(valueInGlobalWindow("JKL")); + writeRunner.close(); + + assertTrue(valuesB.closed); + assertThat(valuesB, contains(valueInGlobalWindow("GHI"), valueInGlobalWindow("JKL"))); + verifyNoMoreInteractions(mockBeamFnDataClient); + } + + private static class RecordingConsumer<T> extends ArrayList<T> + implements CloseableThrowingConsumer<T> { + private boolean closed; + @Override + public void close() throws Exception { + closed = true; + } + + @Override + public void accept(T t) throws Exception { + if (closed) { + throw new IllegalStateException("Consumer is closed but attempting to consume " + t); + } + add(t); + } + } + + @Test + public void testRegistration() { + for (Registrar registrar : + ServiceLoader.load(Registrar.class)) { + if (registrar instanceof BeamFnDataWriteRunner.Registrar) { + assertThat(registrar.getPTransformRunnerFactories(), IsMapContaining.hasKey(URN)); + return; + } + } + fail("Expected registrar not found."); + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/f1b4700f/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BoundedSourceRunnerTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BoundedSourceRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BoundedSourceRunnerTest.java new file mode 100644 index 0000000..7aec161 --- /dev/null +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BoundedSourceRunnerTest.java @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.fn.harness; + +import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.collection.IsEmptyCollection.empty; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.fail; + +import com.google.common.base.Suppliers; +import com.google.common.collect.HashMultimap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; +import com.google.common.collect.Multimap; +import com.google.protobuf.Any; +import com.google.protobuf.ByteString; +import com.google.protobuf.BytesValue; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.ServiceLoader; +import org.apache.beam.fn.harness.PTransformRunnerFactory.Registrar; +import org.apache.beam.fn.harness.fn.ThrowingConsumer; +import org.apache.beam.fn.harness.fn.ThrowingRunnable; +import org.apache.beam.sdk.common.runner.v1.RunnerApi; +import org.apache.beam.sdk.io.BoundedSource; +import org.apache.beam.sdk.io.CountingSource; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.util.SerializableUtils; +import org.apache.beam.sdk.util.WindowedValue; +import org.hamcrest.Matchers; +import org.hamcrest.collection.IsMapContaining; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link BoundedSourceRunner}. */ +@RunWith(JUnit4.class) +public class BoundedSourceRunnerTest { + + public static final String URN = "urn:org.apache.beam:source:java:0.1"; + + @Test + public void testRunReadLoopWithMultipleSources() throws Exception { + List<WindowedValue<Long>> out1Values = new ArrayList<>(); + List<WindowedValue<Long>> out2Values = new ArrayList<>(); + Collection<ThrowingConsumer<WindowedValue<Long>>> consumers = + ImmutableList.of(out1Values::add, out2Values::add); + + BoundedSourceRunner<BoundedSource<Long>, Long> runner = new BoundedSourceRunner<>( + PipelineOptionsFactory.create(), + RunnerApi.FunctionSpec.getDefaultInstance(), + consumers); + + runner.runReadLoop(valueInGlobalWindow(CountingSource.upTo(2))); + runner.runReadLoop(valueInGlobalWindow(CountingSource.upTo(1))); + + assertThat(out1Values, + contains(valueInGlobalWindow(0L), valueInGlobalWindow(1L), valueInGlobalWindow(0L))); + assertThat(out2Values, + contains(valueInGlobalWindow(0L), valueInGlobalWindow(1L), valueInGlobalWindow(0L))); + } + + @Test + public void testRunReadLoopWithEmptySource() throws Exception { + List<WindowedValue<Long>> outValues = new ArrayList<>(); + Collection<ThrowingConsumer<WindowedValue<Long>>> consumers = + ImmutableList.of(outValues::add); + + BoundedSourceRunner<BoundedSource<Long>, Long> runner = new BoundedSourceRunner<>( + PipelineOptionsFactory.create(), + RunnerApi.FunctionSpec.getDefaultInstance(), + consumers); + + runner.runReadLoop(valueInGlobalWindow(CountingSource.upTo(0))); + + assertThat(outValues, empty()); + } + + @Test + public void testStart() throws Exception { + List<WindowedValue<Long>> outValues = new ArrayList<>(); + Collection<ThrowingConsumer<WindowedValue<Long>>> consumers = + ImmutableList.of(outValues::add); + + ByteString encodedSource = + ByteString.copyFrom(SerializableUtils.serializeToByteArray(CountingSource.upTo(3))); + + BoundedSourceRunner<BoundedSource<Long>, Long> runner = new BoundedSourceRunner<>( + PipelineOptionsFactory.create(), + RunnerApi.FunctionSpec.newBuilder().setParameter( + Any.pack(BytesValue.newBuilder().setValue(encodedSource).build())).build(), + consumers); + + runner.start(); + + assertThat(outValues, + contains(valueInGlobalWindow(0L), valueInGlobalWindow(1L), valueInGlobalWindow(2L))); + } + + @Test + public void testCreatingAndProcessingSourceFromFactory() throws Exception { + List<WindowedValue<String>> outputValues = new ArrayList<>(); + + Multimap<String, ThrowingConsumer<WindowedValue<?>>> consumers = HashMultimap.create(); + consumers.put("outputPC", + (ThrowingConsumer) (ThrowingConsumer<WindowedValue<String>>) outputValues::add); + List<ThrowingRunnable> startFunctions = new ArrayList<>(); + List<ThrowingRunnable> finishFunctions = new ArrayList<>(); + + RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder() + .setUrn("urn:org.apache.beam:source:java:0.1") + .setParameter(Any.pack(BytesValue.newBuilder() + .setValue(ByteString.copyFrom( + SerializableUtils.serializeToByteArray(CountingSource.upTo(3)))) + .build())) + .build(); + + RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder() + .setSpec(functionSpec) + .putInputs("input", "inputPC") + .putOutputs("output", "outputPC") + .build(); + + new BoundedSourceRunner.Factory<>().createRunnerForPTransform( + PipelineOptionsFactory.create(), + null /* beamFnDataClient */, + "pTransformId", + pTransform, + Suppliers.ofInstance("57L")::get, + ImmutableMap.of(), + ImmutableMap.of(), + consumers, + startFunctions::add, + finishFunctions::add); + + // This is testing a deprecated way of running sources and should be removed + // once all source definitions are instead propagated along the input edge. + Iterables.getOnlyElement(startFunctions).run(); + assertThat(outputValues, contains( + valueInGlobalWindow(0L), + valueInGlobalWindow(1L), + valueInGlobalWindow(2L))); + outputValues.clear(); + + // Check that when passing a source along as an input, the source is processed. + assertThat(consumers.keySet(), containsInAnyOrder("inputPC", "outputPC")); + Iterables.getOnlyElement(consumers.get("inputPC")).accept( + valueInGlobalWindow(CountingSource.upTo(2))); + assertThat(outputValues, contains( + valueInGlobalWindow(0L), + valueInGlobalWindow(1L))); + + assertThat(finishFunctions, Matchers.empty()); + } + + @Test + public void testRegistration() { + for (Registrar registrar : + ServiceLoader.load(Registrar.class)) { + if (registrar instanceof BoundedSourceRunner.Registrar) { + assertThat(registrar.getPTransformRunnerFactories(), IsMapContaining.hasKey(URN)); + return; + } + } + fail("Expected registrar not found."); + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/f1b4700f/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java new file mode 100644 index 0000000..98362a2 --- /dev/null +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.fn.harness; + +import static org.apache.beam.sdk.util.WindowedValue.timestampedValueInGlobalWindow; +import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.fail; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.base.Suppliers; +import com.google.common.collect.HashMultimap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; +import com.google.common.collect.Multimap; +import com.google.protobuf.Any; +import com.google.protobuf.ByteString; +import com.google.protobuf.BytesValue; +import com.google.protobuf.Message; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.ServiceLoader; +import org.apache.beam.fn.harness.PTransformRunnerFactory.Registrar; +import org.apache.beam.fn.harness.fn.ThrowingConsumer; +import org.apache.beam.fn.harness.fn.ThrowingRunnable; +import org.apache.beam.runners.core.construction.ParDoTranslation; +import org.apache.beam.runners.dataflow.util.CloudObjects; +import org.apache.beam.runners.dataflow.util.DoFnInfo; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.common.runner.v1.RunnerApi; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.util.SerializableUtils; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.hamcrest.collection.IsMapContaining; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link FnApiDoFnRunner}. */ +@RunWith(JUnit4.class) +public class FnApiDoFnRunnerTest { + + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + private static final Coder<WindowedValue<String>> STRING_CODER = + WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE); + private static final String STRING_CODER_SPEC_ID = "999L"; + private static final RunnerApi.Coder STRING_CODER_SPEC; + + static { + try { + STRING_CODER_SPEC = RunnerApi.Coder.newBuilder() + .setSpec(RunnerApi.SdkFunctionSpec.newBuilder() + .setSpec(RunnerApi.FunctionSpec.newBuilder() + .setParameter(Any.pack(BytesValue.newBuilder().setValue(ByteString.copyFrom( + OBJECT_MAPPER.writeValueAsBytes(CloudObjects.asCloudObject(STRING_CODER)))) + .build()))) + .build()) + .build(); + } catch (IOException e) { + throw new ExceptionInInitializerError(e); + } + } + + private static class TestDoFn extends DoFn<String, String> { + private static final TupleTag<String> mainOutput = new TupleTag<>("mainOutput"); + private static final TupleTag<String> additionalOutput = new TupleTag<>("output"); + + private BoundedWindow window; + + @ProcessElement + public void processElement(ProcessContext context, BoundedWindow window) { + context.output("MainOutput" + context.element()); + context.output(additionalOutput, "AdditionalOutput" + context.element()); + this.window = window; + } + + @FinishBundle + public void finishBundle(FinishBundleContext context) { + if (window != null) { + context.output("FinishBundle", window.maxTimestamp(), window); + window = null; + } + } + } + + /** + * Create a DoFn that has 3 inputs (inputATarget1, inputATarget2, inputBTarget) and 2 outputs + * (mainOutput, output). Validate that inputs are fed to the {@link DoFn} and that outputs + * are directed to the correct consumers. + */ + @Test + public void testCreatingAndProcessingDoFn() throws Exception { + Map<String, Message> fnApiRegistry = ImmutableMap.of(STRING_CODER_SPEC_ID, STRING_CODER_SPEC); + String pTransformId = "pTransformId"; + String mainOutputId = "101"; + String additionalOutputId = "102"; + + DoFnInfo<?, ?> doFnInfo = DoFnInfo.forFn( + new TestDoFn(), + WindowingStrategy.globalDefault(), + ImmutableList.of(), + StringUtf8Coder.of(), + Long.parseLong(mainOutputId), + ImmutableMap.of( + Long.parseLong(mainOutputId), TestDoFn.mainOutput, + Long.parseLong(additionalOutputId), TestDoFn.additionalOutput)); + RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder() + .setUrn(ParDoTranslation.CUSTOM_JAVA_DO_FN_URN) + .setParameter(Any.pack(BytesValue.newBuilder() + .setValue(ByteString.copyFrom(SerializableUtils.serializeToByteArray(doFnInfo))) + .build())) + .build(); + RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder() + .setSpec(functionSpec) + .putInputs("inputA", "inputATarget") + .putInputs("inputB", "inputBTarget") + .putOutputs(mainOutputId, "mainOutputTarget") + .putOutputs(additionalOutputId, "additionalOutputTarget") + .build(); + + List<WindowedValue<String>> mainOutputValues = new ArrayList<>(); + List<WindowedValue<String>> additionalOutputValues = new ArrayList<>(); + Multimap<String, ThrowingConsumer<WindowedValue<?>>> consumers = HashMultimap.create(); + consumers.put("mainOutputTarget", + (ThrowingConsumer) (ThrowingConsumer<WindowedValue<String>>) mainOutputValues::add); + consumers.put("additionalOutputTarget", + (ThrowingConsumer) (ThrowingConsumer<WindowedValue<String>>) additionalOutputValues::add); + List<ThrowingRunnable> startFunctions = new ArrayList<>(); + List<ThrowingRunnable> finishFunctions = new ArrayList<>(); + + new FnApiDoFnRunner.Factory<>().createRunnerForPTransform( + PipelineOptionsFactory.create(), + null /* beamFnDataClient */, + pTransformId, + pTransform, + Suppliers.ofInstance("57L")::get, + ImmutableMap.of(), + ImmutableMap.of(), + consumers, + startFunctions::add, + finishFunctions::add); + + Iterables.getOnlyElement(startFunctions).run(); + mainOutputValues.clear(); + + assertThat(consumers.keySet(), containsInAnyOrder( + "inputATarget", "inputBTarget", "mainOutputTarget", "additionalOutputTarget")); + + Iterables.getOnlyElement(consumers.get("inputATarget")).accept(valueInGlobalWindow("A1")); + Iterables.getOnlyElement(consumers.get("inputATarget")).accept(valueInGlobalWindow("A2")); + Iterables.getOnlyElement(consumers.get("inputATarget")).accept(valueInGlobalWindow("B")); + assertThat(mainOutputValues, contains( + valueInGlobalWindow("MainOutputA1"), + valueInGlobalWindow("MainOutputA2"), + valueInGlobalWindow("MainOutputB"))); + assertThat(additionalOutputValues, contains( + valueInGlobalWindow("AdditionalOutputA1"), + valueInGlobalWindow("AdditionalOutputA2"), + valueInGlobalWindow("AdditionalOutputB"))); + mainOutputValues.clear(); + additionalOutputValues.clear(); + + Iterables.getOnlyElement(finishFunctions).run(); + assertThat( + mainOutputValues, + contains( + timestampedValueInGlobalWindow("FinishBundle", GlobalWindow.INSTANCE.maxTimestamp()))); + mainOutputValues.clear(); + } + + @Test + public void testRegistration() { + for (Registrar registrar : + ServiceLoader.load(Registrar.class)) { + if (registrar instanceof FnApiDoFnRunner.Registrar) { + assertThat(registrar.getPTransformRunnerFactories(), + IsMapContaining.hasKey(ParDoTranslation.CUSTOM_JAVA_DO_FN_URN)); + return; + } + } + fail("Expected registrar not found."); + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/f1b4700f/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java index a616b2c..0a94b5b 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java @@ -31,11 +31,11 @@ import java.util.List; import java.util.Map; import java.util.function.Consumer; import java.util.function.Supplier; +import org.apache.beam.fn.harness.PTransformRunnerFactory; import org.apache.beam.fn.harness.data.BeamFnDataClient; import org.apache.beam.fn.harness.fn.ThrowingConsumer; import org.apache.beam.fn.harness.fn.ThrowingRunnable; import org.apache.beam.fn.v1.BeamFnApi; -import org.apache.beam.runners.core.PTransformRunnerFactory; import org.apache.beam.sdk.common.runner.v1.RunnerApi; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; http://git-wip-us.apache.org/repos/asf/beam/blob/f1b4700f/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataReadRunnerTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataReadRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataReadRunnerTest.java deleted file mode 100644 index d6a476e..0000000 --- a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataReadRunnerTest.java +++ /dev/null @@ -1,281 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.beam.runners.core; - -import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow; -import static org.hamcrest.Matchers.contains; -import static org.hamcrest.Matchers.containsInAnyOrder; -import static org.junit.Assert.assertThat; -import static org.junit.Assert.fail; -import static org.mockito.Matchers.any; -import static org.mockito.Matchers.eq; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.verifyZeroInteractions; -import static org.mockito.Mockito.when; - -import com.fasterxml.jackson.databind.ObjectMapper; -import com.google.common.base.Suppliers; -import com.google.common.collect.HashMultimap; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Iterables; -import com.google.common.collect.Multimap; -import com.google.common.util.concurrent.Uninterruptibles; -import com.google.protobuf.Any; -import com.google.protobuf.ByteString; -import com.google.protobuf.BytesValue; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.ServiceLoader; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.Executors; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; -import org.apache.beam.fn.harness.data.BeamFnDataClient; -import org.apache.beam.fn.harness.fn.ThrowingConsumer; -import org.apache.beam.fn.harness.fn.ThrowingRunnable; -import org.apache.beam.fn.harness.test.TestExecutors; -import org.apache.beam.fn.harness.test.TestExecutors.TestExecutorService; -import org.apache.beam.fn.v1.BeamFnApi; -import org.apache.beam.runners.core.PTransformRunnerFactory.Registrar; -import org.apache.beam.runners.dataflow.util.CloudObjects; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.StringUtf8Coder; -import org.apache.beam.sdk.common.runner.v1.RunnerApi; -import org.apache.beam.sdk.options.PipelineOptionsFactory; -import org.apache.beam.sdk.transforms.windowing.GlobalWindow; -import org.apache.beam.sdk.util.WindowedValue; -import org.apache.beam.sdk.values.KV; -import org.hamcrest.collection.IsMapContaining; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -import org.mockito.ArgumentCaptor; -import org.mockito.Captor; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; - -/** Tests for {@link BeamFnDataReadRunner}. */ -@RunWith(JUnit4.class) -public class BeamFnDataReadRunnerTest { - - private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); - private static final BeamFnApi.RemoteGrpcPort PORT_SPEC = BeamFnApi.RemoteGrpcPort.newBuilder() - .setApiServiceDescriptor(BeamFnApi.ApiServiceDescriptor.getDefaultInstance()).build(); - private static final RunnerApi.FunctionSpec FUNCTION_SPEC = RunnerApi.FunctionSpec.newBuilder() - .setParameter(Any.pack(PORT_SPEC)).build(); - private static final Coder<WindowedValue<String>> CODER = - WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE); - private static final String CODER_SPEC_ID = "string-coder-id"; - private static final RunnerApi.Coder CODER_SPEC; - private static final String URN = "urn:org.apache.beam:source:runner:0.1"; - - static { - try { - CODER_SPEC = RunnerApi.Coder.newBuilder().setSpec( - RunnerApi.SdkFunctionSpec.newBuilder().setSpec( - RunnerApi.FunctionSpec.newBuilder().setParameter( - Any.pack(BytesValue.newBuilder().setValue(ByteString.copyFrom( - OBJECT_MAPPER.writeValueAsBytes(CloudObjects.asCloudObject(CODER)))) - .build())) - .build()) - .build()) - .build(); - } catch (IOException e) { - throw new ExceptionInInitializerError(e); - } - } - private static final BeamFnApi.Target INPUT_TARGET = BeamFnApi.Target.newBuilder() - .setPrimitiveTransformReference("1") - .setName("out") - .build(); - - @Rule public TestExecutorService executor = TestExecutors.from(Executors::newCachedThreadPool); - @Mock private BeamFnDataClient mockBeamFnDataClient; - @Captor private ArgumentCaptor<ThrowingConsumer<WindowedValue<String>>> consumerCaptor; - - @Before - public void setUp() { - MockitoAnnotations.initMocks(this); - } - - @Test - public void testCreatingAndProcessingBeamFnDataReadRunner() throws Exception { - String bundleId = "57"; - String outputId = "101"; - - List<WindowedValue<String>> outputValues = new ArrayList<>(); - - Multimap<String, ThrowingConsumer<WindowedValue<?>>> consumers = HashMultimap.create(); - consumers.put("outputPC", - (ThrowingConsumer) (ThrowingConsumer<WindowedValue<String>>) outputValues::add); - List<ThrowingRunnable> startFunctions = new ArrayList<>(); - List<ThrowingRunnable> finishFunctions = new ArrayList<>(); - - RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder() - .setUrn("urn:org.apache.beam:source:runner:0.1") - .setParameter(Any.pack(PORT_SPEC)) - .build(); - - RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder() - .setSpec(functionSpec) - .putOutputs(outputId, "outputPC") - .build(); - - new BeamFnDataReadRunner.Factory<String>().createRunnerForPTransform( - PipelineOptionsFactory.create(), - mockBeamFnDataClient, - "pTransformId", - pTransform, - Suppliers.ofInstance(bundleId)::get, - ImmutableMap.of("outputPC", - RunnerApi.PCollection.newBuilder().setCoderId(CODER_SPEC_ID).build()), - ImmutableMap.of(CODER_SPEC_ID, CODER_SPEC), - consumers, - startFunctions::add, - finishFunctions::add); - - verifyZeroInteractions(mockBeamFnDataClient); - - CompletableFuture<Void> completionFuture = new CompletableFuture<>(); - when(mockBeamFnDataClient.forInboundConsumer(any(), any(), any(), any())) - .thenReturn(completionFuture); - Iterables.getOnlyElement(startFunctions).run(); - verify(mockBeamFnDataClient).forInboundConsumer( - eq(PORT_SPEC.getApiServiceDescriptor()), - eq(KV.of(bundleId, BeamFnApi.Target.newBuilder() - .setPrimitiveTransformReference("pTransformId") - .setName(outputId) - .build())), - eq(CODER), - consumerCaptor.capture()); - - consumerCaptor.getValue().accept(valueInGlobalWindow("TestValue")); - assertThat(outputValues, contains(valueInGlobalWindow("TestValue"))); - outputValues.clear(); - - assertThat(consumers.keySet(), containsInAnyOrder("outputPC")); - - completionFuture.complete(null); - Iterables.getOnlyElement(finishFunctions).run(); - - verifyNoMoreInteractions(mockBeamFnDataClient); - } - - @Test - public void testReuseForMultipleBundles() throws Exception { - CompletableFuture<Void> bundle1Future = new CompletableFuture<>(); - CompletableFuture<Void> bundle2Future = new CompletableFuture<>(); - when(mockBeamFnDataClient.forInboundConsumer( - any(), - any(), - any(), - any())).thenReturn(bundle1Future).thenReturn(bundle2Future); - List<WindowedValue<String>> valuesA = new ArrayList<>(); - List<WindowedValue<String>> valuesB = new ArrayList<>(); - - AtomicReference<String> bundleId = new AtomicReference<>("0"); - BeamFnDataReadRunner<String> readRunner = new BeamFnDataReadRunner<>( - FUNCTION_SPEC, - bundleId::get, - INPUT_TARGET, - CODER_SPEC, - mockBeamFnDataClient, - ImmutableList.of(valuesA::add, valuesB::add)); - - // Process for bundle id 0 - readRunner.registerInputLocation(); - - verify(mockBeamFnDataClient).forInboundConsumer( - eq(PORT_SPEC.getApiServiceDescriptor()), - eq(KV.of(bundleId.get(), INPUT_TARGET)), - eq(CODER), - consumerCaptor.capture()); - - executor.submit(new Runnable() { - @Override - public void run() { - // Sleep for some small amount of time simulating the parent blocking - Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS); - try { - consumerCaptor.getValue().accept(valueInGlobalWindow("ABC")); - consumerCaptor.getValue().accept(valueInGlobalWindow("DEF")); - } catch (Exception e) { - bundle1Future.completeExceptionally(e); - } finally { - bundle1Future.complete(null); - } - } - }); - - readRunner.blockTillReadFinishes(); - assertThat(valuesA, contains(valueInGlobalWindow("ABC"), valueInGlobalWindow("DEF"))); - assertThat(valuesB, contains(valueInGlobalWindow("ABC"), valueInGlobalWindow("DEF"))); - - // Process for bundle id 1 - bundleId.set("1"); - valuesA.clear(); - valuesB.clear(); - readRunner.registerInputLocation(); - - verify(mockBeamFnDataClient).forInboundConsumer( - eq(PORT_SPEC.getApiServiceDescriptor()), - eq(KV.of(bundleId.get(), INPUT_TARGET)), - eq(CODER), - consumerCaptor.capture()); - - executor.submit(new Runnable() { - @Override - public void run() { - // Sleep for some small amount of time simulating the parent blocking - Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS); - try { - consumerCaptor.getValue().accept(valueInGlobalWindow("GHI")); - consumerCaptor.getValue().accept(valueInGlobalWindow("JKL")); - } catch (Exception e) { - bundle2Future.completeExceptionally(e); - } finally { - bundle2Future.complete(null); - } - } - }); - - readRunner.blockTillReadFinishes(); - assertThat(valuesA, contains(valueInGlobalWindow("GHI"), valueInGlobalWindow("JKL"))); - assertThat(valuesB, contains(valueInGlobalWindow("GHI"), valueInGlobalWindow("JKL"))); - - verifyNoMoreInteractions(mockBeamFnDataClient); - } - - @Test - public void testRegistration() { - for (Registrar registrar : - ServiceLoader.load(Registrar.class)) { - if (registrar instanceof BeamFnDataReadRunner.Registrar) { - assertThat(registrar.getPTransformRunnerFactories(), IsMapContaining.hasKey(URN)); - return; - } - } - fail("Expected registrar not found."); - } -}
