http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataReadRunner.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataReadRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataReadRunner.java index 9339347..f0fe274 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataReadRunner.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataReadRunner.java @@ -18,28 +18,22 @@ package org.apache.beam.runners.core; -import static com.google.common.collect.Iterables.getOnlyElement; - import com.fasterxml.jackson.databind.ObjectMapper; -import com.google.auto.service.AutoService; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Multimap; +import com.google.common.collect.FluentIterable; +import com.google.common.collect.ImmutableList; import com.google.protobuf.BytesValue; import java.io.IOException; import java.util.Collection; import java.util.Map; import java.util.concurrent.CompletableFuture; -import java.util.function.Consumer; import java.util.function.Supplier; import org.apache.beam.fn.harness.data.BeamFnDataClient; import org.apache.beam.fn.harness.fn.ThrowingConsumer; -import org.apache.beam.fn.harness.fn.ThrowingRunnable; import org.apache.beam.fn.v1.BeamFnApi; import org.apache.beam.runners.dataflow.util.CloudObject; import org.apache.beam.runners.dataflow.util.CloudObjects; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.common.runner.v1.RunnerApi; -import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; import org.slf4j.Logger; @@ -54,61 +48,9 @@ import org.slf4j.LoggerFactory; * {@link #blockTillReadFinishes()} to finish. 
*/ public class BeamFnDataReadRunner<OutputT> { - private static final Logger LOG = LoggerFactory.getLogger(BeamFnDataReadRunner.class); - private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); - private static final String URN = "urn:org.apache.beam:source:runner:0.1"; - - /** A registrar which provides a factory to handle reading from the Fn Api Data Plane. */ - @AutoService(PTransformRunnerFactory.Registrar.class) - public static class Registrar implements - PTransformRunnerFactory.Registrar { - - @Override - public Map<String, PTransformRunnerFactory> getPTransformRunnerFactories() { - return ImmutableMap.of(URN, new Factory()); - } - } - /** A factory for {@link BeamFnDataReadRunner}s. */ - static class Factory<OutputT> - implements PTransformRunnerFactory<BeamFnDataReadRunner<OutputT>> { - - @Override - public BeamFnDataReadRunner<OutputT> createRunnerForPTransform( - PipelineOptions pipelineOptions, - BeamFnDataClient beamFnDataClient, - String pTransformId, - RunnerApi.PTransform pTransform, - Supplier<String> processBundleInstructionId, - Map<String, RunnerApi.PCollection> pCollections, - Map<String, RunnerApi.Coder> coders, - Multimap<String, ThrowingConsumer<WindowedValue<?>>> pCollectionIdsToConsumers, - Consumer<ThrowingRunnable> addStartFunction, - Consumer<ThrowingRunnable> addFinishFunction) throws IOException { - - BeamFnApi.Target target = BeamFnApi.Target.newBuilder() - .setPrimitiveTransformReference(pTransformId) - .setName(getOnlyElement(pTransform.getOutputsMap().keySet())) - .build(); - RunnerApi.Coder coderSpec = coders.get(pCollections.get( - getOnlyElement(pTransform.getOutputsMap().values())).getCoderId()); - Collection<ThrowingConsumer<WindowedValue<OutputT>>> consumers = - (Collection) pCollectionIdsToConsumers.get( - getOnlyElement(pTransform.getOutputsMap().values())); - - BeamFnDataReadRunner<OutputT> runner = new BeamFnDataReadRunner<>( - pTransform.getSpec(), - processBundleInstructionId, - target, - coderSpec, - 
beamFnDataClient, - consumers); - addStartFunction.accept(runner::registerInputLocation); - addFinishFunction.accept(runner::blockTillReadFinishes); - return runner; - } - } + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); private final BeamFnApi.ApiServiceDescriptor apiServiceDescriptor; private final Collection<ThrowingConsumer<WindowedValue<OutputT>>> consumers; @@ -119,20 +61,20 @@ public class BeamFnDataReadRunner<OutputT> { private CompletableFuture<Void> readFuture; - BeamFnDataReadRunner( + public BeamFnDataReadRunner( RunnerApi.FunctionSpec functionSpec, Supplier<String> processBundleInstructionIdSupplier, BeamFnApi.Target inputTarget, RunnerApi.Coder coderSpec, BeamFnDataClient beamFnDataClientFactory, - Collection<ThrowingConsumer<WindowedValue<OutputT>>> consumers) + Map<String, Collection<ThrowingConsumer<WindowedValue<OutputT>>>> outputMap) throws IOException { this.apiServiceDescriptor = functionSpec.getParameter().unpack(BeamFnApi.RemoteGrpcPort.class) .getApiServiceDescriptor(); this.inputTarget = inputTarget; this.processBundleInstructionIdSupplier = processBundleInstructionIdSupplier; this.beamFnDataClientFactory = beamFnDataClientFactory; - this.consumers = consumers; + this.consumers = ImmutableList.copyOf(FluentIterable.concat(outputMap.values())); @SuppressWarnings("unchecked") Coder<WindowedValue<OutputT>> coder =
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataWriteRunner.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataWriteRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataWriteRunner.java index c2a996b..a48df12 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataWriteRunner.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BeamFnDataWriteRunner.java @@ -18,91 +18,30 @@ package org.apache.beam.runners.core; -import static com.google.common.collect.Iterables.getOnlyElement; - import com.fasterxml.jackson.databind.ObjectMapper; -import com.google.auto.service.AutoService; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Multimap; import com.google.protobuf.BytesValue; import java.io.IOException; import java.util.Map; -import java.util.function.Consumer; import java.util.function.Supplier; import org.apache.beam.fn.harness.data.BeamFnDataClient; import org.apache.beam.fn.harness.fn.CloseableThrowingConsumer; -import org.apache.beam.fn.harness.fn.ThrowingConsumer; -import org.apache.beam.fn.harness.fn.ThrowingRunnable; import org.apache.beam.fn.v1.BeamFnApi; import org.apache.beam.runners.dataflow.util.CloudObject; import org.apache.beam.runners.dataflow.util.CloudObjects; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.common.runner.v1.RunnerApi; -import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; /** - * Registers as a consumer with the Beam Fn Data Api. Consumes elements and encodes them for - * transmission. + * Registers as a consumer with the Beam Fn Data API. Propagates any elements consumed to + * the registered consumer. 
* * <p>Can be re-used serially across {@link org.apache.beam.fn.v1.BeamFnApi.ProcessBundleRequest}s. * For each request, call {@link #registerForOutput()} to start and call {@link #close()} to finish. */ public class BeamFnDataWriteRunner<InputT> { - private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); - private static final String URN = "urn:org.apache.beam:sink:runner:0.1"; - - /** A registrar which provides a factory to handle writing to the Fn Api Data Plane. */ - @AutoService(PTransformRunnerFactory.Registrar.class) - public static class Registrar implements - PTransformRunnerFactory.Registrar { - - @Override - public Map<String, PTransformRunnerFactory> getPTransformRunnerFactories() { - return ImmutableMap.of(URN, new Factory()); - } - } - - /** A factory for {@link BeamFnDataWriteRunner}s. */ - static class Factory<InputT> - implements PTransformRunnerFactory<BeamFnDataWriteRunner<InputT>> { - - @Override - public BeamFnDataWriteRunner<InputT> createRunnerForPTransform( - PipelineOptions pipelineOptions, - BeamFnDataClient beamFnDataClient, - String pTransformId, - RunnerApi.PTransform pTransform, - Supplier<String> processBundleInstructionId, - Map<String, RunnerApi.PCollection> pCollections, - Map<String, RunnerApi.Coder> coders, - Multimap<String, ThrowingConsumer<WindowedValue<?>>> pCollectionIdsToConsumers, - Consumer<ThrowingRunnable> addStartFunction, - Consumer<ThrowingRunnable> addFinishFunction) throws IOException { - BeamFnApi.Target target = BeamFnApi.Target.newBuilder() - .setPrimitiveTransformReference(pTransformId) - .setName(getOnlyElement(pTransform.getInputsMap().keySet())) - .build(); - RunnerApi.Coder coderSpec = coders.get( - pCollections.get(getOnlyElement(pTransform.getInputsMap().values())).getCoderId()); - BeamFnDataWriteRunner<InputT> runner = - new BeamFnDataWriteRunner<>( - pTransform.getSpec(), - processBundleInstructionId, - target, - coderSpec, - beamFnDataClient); - 
addStartFunction.accept(runner::registerForOutput); - pCollectionIdsToConsumers.put( - getOnlyElement(pTransform.getInputsMap().values()), - (ThrowingConsumer) - (ThrowingConsumer<WindowedValue<InputT>>) runner::consume); - addFinishFunction.accept(runner::close); - return runner; - } - } private final BeamFnApi.ApiServiceDescriptor apiServiceDescriptor; private final BeamFnApi.Target outputTarget; @@ -112,7 +51,7 @@ public class BeamFnDataWriteRunner<InputT> { private CloseableThrowingConsumer<WindowedValue<InputT>> consumer; - BeamFnDataWriteRunner( + public BeamFnDataWriteRunner( RunnerApi.FunctionSpec functionSpec, Supplier<String> processBundleInstructionIdSupplier, BeamFnApi.Target outputTarget, http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BoundedSourceRunner.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BoundedSourceRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BoundedSourceRunner.java index 3338c3a..4d530b8 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BoundedSourceRunner.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/BoundedSourceRunner.java @@ -18,20 +18,14 @@ package org.apache.beam.runners.core; -import com.google.auto.service.AutoService; +import com.google.common.collect.FluentIterable; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Multimap; import com.google.protobuf.BytesValue; import com.google.protobuf.InvalidProtocolBufferException; import java.io.IOException; import java.util.Collection; import java.util.Map; -import java.util.function.Consumer; -import java.util.function.Supplier; -import org.apache.beam.fn.harness.data.BeamFnDataClient; import org.apache.beam.fn.harness.fn.ThrowingConsumer; -import 
org.apache.beam.fn.harness.fn.ThrowingRunnable; import org.apache.beam.sdk.common.runner.v1.RunnerApi; import org.apache.beam.sdk.io.BoundedSource; import org.apache.beam.sdk.io.Source.Reader; @@ -40,77 +34,21 @@ import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.util.WindowedValue; /** - * A runner which creates {@link Reader}s for each {@link BoundedSource} sent as an input and - * executes the {@link Reader}s read loop. + * A runner which creates {@link Reader}s for each {@link BoundedSource} and executes + * the {@link Reader}s read loop. */ public class BoundedSourceRunner<InputT extends BoundedSource<OutputT>, OutputT> { - - private static final String URN = "urn:org.apache.beam:source:java:0.1"; - - /** A registrar which provides a factory to handle Java {@link BoundedSource}s. */ - @AutoService(PTransformRunnerFactory.Registrar.class) - public static class Registrar implements - PTransformRunnerFactory.Registrar { - - @Override - public Map<String, PTransformRunnerFactory> getPTransformRunnerFactories() { - return ImmutableMap.of(URN, new Factory()); - } - } - - /** A factory for {@link BoundedSourceRunner}. 
*/ - static class Factory<InputT extends BoundedSource<OutputT>, OutputT> - implements PTransformRunnerFactory<BoundedSourceRunner<InputT, OutputT>> { - @Override - public BoundedSourceRunner<InputT, OutputT> createRunnerForPTransform( - PipelineOptions pipelineOptions, - BeamFnDataClient beamFnDataClient, - String pTransformId, - RunnerApi.PTransform pTransform, - Supplier<String> processBundleInstructionId, - Map<String, RunnerApi.PCollection> pCollections, - Map<String, RunnerApi.Coder> coders, - Multimap<String, ThrowingConsumer<WindowedValue<?>>> pCollectionIdsToConsumers, - Consumer<ThrowingRunnable> addStartFunction, - Consumer<ThrowingRunnable> addFinishFunction) { - - ImmutableList.Builder<ThrowingConsumer<WindowedValue<?>>> consumers = ImmutableList.builder(); - for (String pCollectionId : pTransform.getOutputsMap().values()) { - consumers.addAll(pCollectionIdsToConsumers.get(pCollectionId)); - } - - @SuppressWarnings({"rawtypes", "unchecked"}) - BoundedSourceRunner<InputT, OutputT> runner = new BoundedSourceRunner( - pipelineOptions, - pTransform.getSpec(), - consumers.build()); - - // TODO: Remove and replace with source being sent across gRPC port - addStartFunction.accept(runner::start); - - ThrowingConsumer runReadLoop = - (ThrowingConsumer<WindowedValue<InputT>>) runner::runReadLoop; - for (String pCollectionId : pTransform.getInputsMap().values()) { - pCollectionIdsToConsumers.put( - pCollectionId, - runReadLoop); - } - - return runner; - } - } - private final PipelineOptions pipelineOptions; private final RunnerApi.FunctionSpec definition; private final Collection<ThrowingConsumer<WindowedValue<OutputT>>> consumers; - BoundedSourceRunner( + public BoundedSourceRunner( PipelineOptions pipelineOptions, RunnerApi.FunctionSpec definition, - Collection<ThrowingConsumer<WindowedValue<OutputT>>> consumers) { + Map<String, Collection<ThrowingConsumer<WindowedValue<OutputT>>>> outputMap) { this.pipelineOptions = pipelineOptions; this.definition = 
definition; - this.consumers = consumers; + this.consumers = ImmutableList.copyOf(FluentIterable.concat(outputMap.values())); } /** http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/FnApiDoFnRunner.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/FnApiDoFnRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/FnApiDoFnRunner.java deleted file mode 100644 index b3cf3a7..0000000 --- a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/FnApiDoFnRunner.java +++ /dev/null @@ -1,547 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.runners.core; - -import static com.google.common.base.Preconditions.checkArgument; - -import com.google.auto.service.AutoService; -import com.google.common.collect.Collections2; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableMultimap; -import com.google.common.collect.Multimap; -import com.google.protobuf.ByteString; -import com.google.protobuf.BytesValue; -import com.google.protobuf.InvalidProtocolBufferException; -import java.util.Collection; -import java.util.HashSet; -import java.util.Iterator; -import java.util.Map; -import java.util.Objects; -import java.util.function.Consumer; -import java.util.function.Supplier; -import org.apache.beam.fn.harness.data.BeamFnDataClient; -import org.apache.beam.fn.harness.fn.ThrowingConsumer; -import org.apache.beam.fn.harness.fn.ThrowingRunnable; -import org.apache.beam.runners.core.construction.ParDoTranslation; -import org.apache.beam.runners.dataflow.util.DoFnInfo; -import org.apache.beam.sdk.common.runner.v1.RunnerApi; -import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.state.State; -import org.apache.beam.sdk.state.TimeDomain; -import org.apache.beam.sdk.state.Timer; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.DoFn.OnTimerContext; -import org.apache.beam.sdk.transforms.DoFn.ProcessContext; -import org.apache.beam.sdk.transforms.reflect.DoFnInvoker; -import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; -import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; -import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.transforms.windowing.PaneInfo; -import org.apache.beam.sdk.util.SerializableUtils; -import org.apache.beam.sdk.util.UserCodeException; -import org.apache.beam.sdk.util.WindowedValue; -import org.apache.beam.sdk.values.PCollectionView; -import org.apache.beam.sdk.values.TupleTag; -import 
org.apache.beam.sdk.values.WindowingStrategy; -import org.joda.time.Instant; - -/** - * A {@link DoFnRunner} specific to integrating with the Fn Api. This is to remove the layers - * of abstraction caused by StateInternals/TimerInternals since they model state and timer - * concepts differently. - */ -public class FnApiDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, OutputT> { - /** - * A registrar which provides a factory to handle Java {@link DoFn}s. - */ - @AutoService(PTransformRunnerFactory.Registrar.class) - public static class Registrar implements - PTransformRunnerFactory.Registrar { - - @Override - public Map<String, PTransformRunnerFactory> getPTransformRunnerFactories() { - return ImmutableMap.of(ParDoTranslation.CUSTOM_JAVA_DO_FN_URN, new Factory()); - } - } - - /** A factory for {@link FnApiDoFnRunner}. */ - static class Factory<InputT, OutputT> - implements PTransformRunnerFactory<DoFnRunner<InputT, OutputT>> { - - @Override - public DoFnRunner<InputT, OutputT> createRunnerForPTransform( - PipelineOptions pipelineOptions, - BeamFnDataClient beamFnDataClient, - String pTransformId, - RunnerApi.PTransform pTransform, - Supplier<String> processBundleInstructionId, - Map<String, RunnerApi.PCollection> pCollections, - Map<String, RunnerApi.Coder> coders, - Multimap<String, ThrowingConsumer<WindowedValue<?>>> pCollectionIdsToConsumers, - Consumer<ThrowingRunnable> addStartFunction, - Consumer<ThrowingRunnable> addFinishFunction) { - - // For every output PCollection, create a map from output name to Consumer - ImmutableMap.Builder<String, Collection<ThrowingConsumer<WindowedValue<?>>>> - outputMapBuilder = ImmutableMap.builder(); - for (Map.Entry<String, String> entry : pTransform.getOutputsMap().entrySet()) { - outputMapBuilder.put( - entry.getKey(), - pCollectionIdsToConsumers.get(entry.getValue())); - } - ImmutableMap<String, Collection<ThrowingConsumer<WindowedValue<?>>>> outputMap = - outputMapBuilder.build(); - - // Get the DoFnInfo from the 
serialized blob. - ByteString serializedFn; - try { - serializedFn = pTransform.getSpec().getParameter().unpack(BytesValue.class).getValue(); - } catch (InvalidProtocolBufferException e) { - throw new IllegalArgumentException( - String.format("Unable to unwrap DoFn %s", pTransform.getSpec()), e); - } - @SuppressWarnings({"unchecked", "rawtypes"}) - DoFnInfo<InputT, OutputT> doFnInfo = (DoFnInfo) SerializableUtils.deserializeFromByteArray( - serializedFn.toByteArray(), "DoFnInfo"); - - // Verify that the DoFnInfo tag to output map matches the output map on the PTransform. - checkArgument( - Objects.equals( - new HashSet<>(Collections2.transform(outputMap.keySet(), Long::parseLong)), - doFnInfo.getOutputMap().keySet()), - "Unexpected mismatch between transform output map %s and DoFnInfo output map %s.", - outputMap.keySet(), - doFnInfo.getOutputMap()); - - ImmutableMultimap.Builder<TupleTag<?>, - ThrowingConsumer<WindowedValue<?>>> tagToOutputMapBuilder = - ImmutableMultimap.builder(); - for (Map.Entry<Long, TupleTag<?>> entry : doFnInfo.getOutputMap().entrySet()) { - @SuppressWarnings({"unchecked", "rawtypes"}) - Collection<ThrowingConsumer<WindowedValue<?>>> consumers = - outputMap.get(Long.toString(entry.getKey())); - tagToOutputMapBuilder.putAll(entry.getValue(), consumers); - } - - ImmutableMultimap<TupleTag<?>, ThrowingConsumer<WindowedValue<?>>> tagToOutputMap = - tagToOutputMapBuilder.build(); - - @SuppressWarnings({"unchecked", "rawtypes"}) - DoFnRunner<InputT, OutputT> runner = new FnApiDoFnRunner<>( - pipelineOptions, - doFnInfo.getDoFn(), - (Collection<ThrowingConsumer<WindowedValue<OutputT>>>) (Collection) - tagToOutputMap.get(doFnInfo.getOutputMap().get(doFnInfo.getMainOutput())), - tagToOutputMap, - doFnInfo.getWindowingStrategy()); - - // Register the appropriate handlers. 
- addStartFunction.accept(runner::startBundle); - for (String pcollectionId : pTransform.getInputsMap().values()) { - pCollectionIdsToConsumers.put( - pcollectionId, - (ThrowingConsumer) (ThrowingConsumer<WindowedValue<InputT>>) runner::processElement); - } - addFinishFunction.accept(runner::finishBundle); - return runner; - } - } - - ////////////////////////////////////////////////////////////////////////////////////////////////// - - private final PipelineOptions pipelineOptions; - private final DoFn<InputT, OutputT> doFn; - private final Collection<ThrowingConsumer<WindowedValue<OutputT>>> mainOutputConsumers; - private final Multimap<TupleTag<?>, ThrowingConsumer<WindowedValue<?>>> outputMap; - private final DoFnInvoker<InputT, OutputT> doFnInvoker; - private final StartBundleContext startBundleContext; - private final ProcessBundleContext processBundleContext; - private final FinishBundleContext finishBundleContext; - - /** - * The lifetime of this member is only valid during {@link #processElement(WindowedValue)}. - */ - private WindowedValue<InputT> currentElement; - - /** - * The lifetime of this member is only valid during {@link #processElement(WindowedValue)}. 
- */ - private BoundedWindow currentWindow; - - FnApiDoFnRunner( - PipelineOptions pipelineOptions, - DoFn<InputT, OutputT> doFn, - Collection<ThrowingConsumer<WindowedValue<OutputT>>> mainOutputConsumers, - Multimap<TupleTag<?>, ThrowingConsumer<WindowedValue<?>>> outputMap, - WindowingStrategy windowingStrategy) { - this.pipelineOptions = pipelineOptions; - this.doFn = doFn; - this.mainOutputConsumers = mainOutputConsumers; - this.outputMap = outputMap; - this.doFnInvoker = DoFnInvokers.invokerFor(doFn); - this.startBundleContext = new StartBundleContext(); - this.processBundleContext = new ProcessBundleContext(); - this.finishBundleContext = new FinishBundleContext(); - } - - @Override - public void startBundle() { - doFnInvoker.invokeStartBundle(startBundleContext); - } - - @Override - public void processElement(WindowedValue<InputT> elem) { - currentElement = elem; - try { - Iterator<BoundedWindow> windowIterator = - (Iterator<BoundedWindow>) elem.getWindows().iterator(); - while (windowIterator.hasNext()) { - currentWindow = windowIterator.next(); - doFnInvoker.invokeProcessElement(processBundleContext); - } - } finally { - currentElement = null; - currentWindow = null; - } - } - - @Override - public void onTimer( - String timerId, - BoundedWindow window, - Instant timestamp, - TimeDomain timeDomain) { - throw new UnsupportedOperationException("TODO: Add support for timers"); - } - - @Override - public void finishBundle() { - doFnInvoker.invokeFinishBundle(finishBundleContext); - } - - /** - * Outputs the given element to the specified set of consumers wrapping any exceptions. 
- */ - private <T> void outputTo( - Collection<ThrowingConsumer<WindowedValue<T>>> consumers, - WindowedValue<T> output) { - Iterator<ThrowingConsumer<WindowedValue<T>>> consumerIterator; - try { - for (ThrowingConsumer<WindowedValue<T>> consumer : consumers) { - consumer.accept(output); - } - } catch (Throwable t) { - throw UserCodeException.wrap(t); - } - } - - /** - * Provides arguments for a {@link DoFnInvoker} for {@link DoFn.StartBundle @StartBundle}. - */ - private class StartBundleContext - extends DoFn<InputT, OutputT>.StartBundleContext - implements DoFnInvoker.ArgumentProvider<InputT, OutputT> { - - private StartBundleContext() { - doFn.super(); - } - - @Override - public PipelineOptions getPipelineOptions() { - return pipelineOptions; - } - - @Override - public PipelineOptions pipelineOptions() { - return pipelineOptions; - } - - @Override - public BoundedWindow window() { - throw new UnsupportedOperationException( - "Cannot access window outside of @ProcessElement and @OnTimer methods."); - } - - @Override - public DoFn<InputT, OutputT>.StartBundleContext startBundleContext( - DoFn<InputT, OutputT> doFn) { - return this; - } - - @Override - public DoFn<InputT, OutputT>.FinishBundleContext finishBundleContext( - DoFn<InputT, OutputT> doFn) { - throw new UnsupportedOperationException( - "Cannot access FinishBundleContext outside of @FinishBundle method."); - } - - @Override - public DoFn<InputT, OutputT>.ProcessContext processContext(DoFn<InputT, OutputT> doFn) { - throw new UnsupportedOperationException( - "Cannot access ProcessContext outside of @ProcessElement method."); - } - - @Override - public DoFn<InputT, OutputT>.OnTimerContext onTimerContext(DoFn<InputT, OutputT> doFn) { - throw new UnsupportedOperationException( - "Cannot access OnTimerContext outside of @OnTimer methods."); - } - - @Override - public RestrictionTracker<?> restrictionTracker() { - throw new UnsupportedOperationException( - "Cannot access RestrictionTracker outside of 
@ProcessElement method."); - } - - @Override - public State state(String stateId) { - throw new UnsupportedOperationException( - "Cannot access state outside of @ProcessElement and @OnTimer methods."); - } - - @Override - public Timer timer(String timerId) { - throw new UnsupportedOperationException( - "Cannot access timers outside of @ProcessElement and @OnTimer methods."); - } - } - - /** - * Provides arguments for a {@link DoFnInvoker} for {@link DoFn.ProcessElement @ProcessElement}. - */ - private class ProcessBundleContext - extends DoFn<InputT, OutputT>.ProcessContext - implements DoFnInvoker.ArgumentProvider<InputT, OutputT> { - - private ProcessBundleContext() { - doFn.super(); - } - - @Override - public BoundedWindow window() { - return currentWindow; - } - - @Override - public DoFn.StartBundleContext startBundleContext(DoFn<InputT, OutputT> doFn) { - throw new UnsupportedOperationException( - "Cannot access StartBundleContext outside of @StartBundle method."); - } - - @Override - public DoFn.FinishBundleContext finishBundleContext(DoFn<InputT, OutputT> doFn) { - throw new UnsupportedOperationException( - "Cannot access FinishBundleContext outside of @FinishBundle method."); - } - - @Override - public ProcessContext processContext(DoFn<InputT, OutputT> doFn) { - return this; - } - - @Override - public OnTimerContext onTimerContext(DoFn<InputT, OutputT> doFn) { - throw new UnsupportedOperationException("TODO: Add support for timers"); - } - - @Override - public RestrictionTracker<?> restrictionTracker() { - throw new UnsupportedOperationException("TODO: Add support for SplittableDoFn"); - } - - @Override - public State state(String stateId) { - throw new UnsupportedOperationException("TODO: Add support for state"); - } - - @Override - public Timer timer(String timerId) { - throw new UnsupportedOperationException("TODO: Add support for timers"); - } - - @Override - public PipelineOptions getPipelineOptions() { - return pipelineOptions; - } - - @Override - 
public PipelineOptions pipelineOptions() { - return pipelineOptions; - } - - @Override - public void output(OutputT output) { - outputTo(mainOutputConsumers, - WindowedValue.of( - output, - currentElement.getTimestamp(), - currentWindow, - currentElement.getPane())); - } - - @Override - public void outputWithTimestamp(OutputT output, Instant timestamp) { - outputTo(mainOutputConsumers, - WindowedValue.of( - output, - timestamp, - currentWindow, - currentElement.getPane())); - } - - @Override - public <T> void output(TupleTag<T> tag, T output) { - Collection<ThrowingConsumer<WindowedValue<T>>> consumers = (Collection) outputMap.get(tag); - if (consumers == null) { - throw new IllegalArgumentException(String.format("Unknown output tag %s", tag)); - } - outputTo(consumers, - WindowedValue.of( - output, - currentElement.getTimestamp(), - currentWindow, - currentElement.getPane())); - } - - @Override - public <T> void outputWithTimestamp(TupleTag<T> tag, T output, Instant timestamp) { - Collection<ThrowingConsumer<WindowedValue<T>>> consumers = (Collection) outputMap.get(tag); - if (consumers == null) { - throw new IllegalArgumentException(String.format("Unknown output tag %s", tag)); - } - outputTo(consumers, - WindowedValue.of( - output, - timestamp, - currentWindow, - currentElement.getPane())); - } - - @Override - public InputT element() { - return currentElement.getValue(); - } - - @Override - public <T> T sideInput(PCollectionView<T> view) { - throw new UnsupportedOperationException("TODO: Support side inputs"); - } - - @Override - public Instant timestamp() { - return currentElement.getTimestamp(); - } - - @Override - public PaneInfo pane() { - return currentElement.getPane(); - } - - @Override - public void updateWatermark(Instant watermark) { - throw new UnsupportedOperationException("TODO: Add support for SplittableDoFn"); - } - } - - /** - * Provides arguments for a {@link DoFnInvoker} for {@link DoFn.FinishBundle @FinishBundle}. 
- */ - private class FinishBundleContext - extends DoFn<InputT, OutputT>.FinishBundleContext - implements DoFnInvoker.ArgumentProvider<InputT, OutputT> { - - private FinishBundleContext() { - doFn.super(); - } - - @Override - public PipelineOptions getPipelineOptions() { - return pipelineOptions; - } - - @Override - public PipelineOptions pipelineOptions() { - return pipelineOptions; - } - - @Override - public BoundedWindow window() { - throw new UnsupportedOperationException( - "Cannot access window outside of @ProcessElement and @OnTimer methods."); - } - - @Override - public DoFn<InputT, OutputT>.StartBundleContext startBundleContext( - DoFn<InputT, OutputT> doFn) { - throw new UnsupportedOperationException( - "Cannot access StartBundleContext outside of @StartBundle method."); - } - - @Override - public DoFn<InputT, OutputT>.FinishBundleContext finishBundleContext( - DoFn<InputT, OutputT> doFn) { - return this; - } - - @Override - public DoFn<InputT, OutputT>.ProcessContext processContext(DoFn<InputT, OutputT> doFn) { - throw new UnsupportedOperationException( - "Cannot access ProcessContext outside of @ProcessElement method."); - } - - @Override - public DoFn<InputT, OutputT>.OnTimerContext onTimerContext(DoFn<InputT, OutputT> doFn) { - throw new UnsupportedOperationException( - "Cannot access OnTimerContext outside of @OnTimer methods."); - } - - @Override - public RestrictionTracker<?> restrictionTracker() { - throw new UnsupportedOperationException( - "Cannot access RestrictionTracker outside of @ProcessElement method."); - } - - @Override - public State state(String stateId) { - throw new UnsupportedOperationException( - "Cannot access state outside of @ProcessElement and @OnTimer methods."); - } - - @Override - public Timer timer(String timerId) { - throw new UnsupportedOperationException( - "Cannot access timers outside of @ProcessElement and @OnTimer methods."); - } - - @Override - public void output(OutputT output, Instant timestamp, BoundedWindow 
window) { - outputTo(mainOutputConsumers, - WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING)); - } - - @Override - public <T> void output(TupleTag<T> tag, T output, Instant timestamp, BoundedWindow window) { - Collection<ThrowingConsumer<WindowedValue<T>>> consumers = (Collection) outputMap.get(tag); - if (consumers == null) { - throw new IllegalArgumentException(String.format("Unknown output tag %s", tag)); - } - outputTo(consumers, - WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING)); - } - } -} http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/PTransformRunnerFactory.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/PTransformRunnerFactory.java b/sdks/java/harness/src/main/java/org/apache/beam/runners/core/PTransformRunnerFactory.java deleted file mode 100644 index b325db4..0000000 --- a/sdks/java/harness/src/main/java/org/apache/beam/runners/core/PTransformRunnerFactory.java +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.runners.core; - -import com.google.common.collect.Multimap; -import java.io.IOException; -import java.util.Map; -import java.util.function.Consumer; -import java.util.function.Supplier; -import org.apache.beam.fn.harness.data.BeamFnDataClient; -import org.apache.beam.fn.harness.fn.ThrowingConsumer; -import org.apache.beam.fn.harness.fn.ThrowingRunnable; -import org.apache.beam.sdk.common.runner.v1.RunnerApi; -import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.util.WindowedValue; - -/** - * A factory able to instantiate an appropriate handler for a given PTransform. - */ -public interface PTransformRunnerFactory<T> { - - /** - * Creates and returns a handler for a given PTransform. Note that the handler must support - * processing multiple bundles. The handler will be discarded if an error is thrown during - * element processing, or during execution of start/finish. - * - * @param pipelineOptions Pipeline options - * @param beamFnDataClient - * @param pTransformId The id of the PTransform. - * @param pTransform The PTransform definition. - * @param processBundleInstructionId A supplier containing the active process bundle instruction - * id. - * @param pCollections A mapping from PCollection id to PCollection definition. - * @param coders A mapping from coder id to coder definition. - * @param pCollectionIdsToConsumers A mapping from PCollection id to a collection of consumers. - * Note that if this handler is a consumer, it should register itself within this multimap under - * the appropriate PCollection ids. Also note that all output consumers needed by this PTransform - * (based on the values of the {@link RunnerApi.PTransform#getOutputsMap()} will have already - * registered within this multimap. - * @param addStartFunction A consumer to register a start bundle handler with. - * @param addFinishFunction A consumer to register a finish bundle handler with. 
- */ - T createRunnerForPTransform( - PipelineOptions pipelineOptions, - BeamFnDataClient beamFnDataClient, - String pTransformId, - RunnerApi.PTransform pTransform, - Supplier<String> processBundleInstructionId, - Map<String, RunnerApi.PCollection> pCollections, - Map<String, RunnerApi.Coder> coders, - Multimap<String, ThrowingConsumer<WindowedValue<?>>> pCollectionIdsToConsumers, - Consumer<ThrowingRunnable> addStartFunction, - Consumer<ThrowingRunnable> addFinishFunction) throws IOException; - - /** - * A registrar which can return a mapping from {@link RunnerApi.FunctionSpec#getUrn()} to - * a factory capable of instantiating an appropriate handler. - */ - interface Registrar { - /** - * Returns a mapping from {@link RunnerApi.FunctionSpec#getUrn()} to a factory capable of - * instantiating an appropriate handler. - */ - Map<String, PTransformRunnerFactory> getPTransformRunnerFactories(); - } -} http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java index a616b2c..562f91f 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java @@ -18,28 +18,62 @@ package org.apache.beam.fn.harness.control; +import static org.apache.beam.sdk.util.WindowedValue.timestampedValueInGlobalWindow; +import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.empty; import static 
org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.when; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.base.Suppliers; +import com.google.common.collect.HashMultimap; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; import com.google.common.collect.Multimap; +import com.google.protobuf.Any; +import com.google.protobuf.ByteString; +import com.google.protobuf.BytesValue; import com.google.protobuf.Message; import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Consumer; import java.util.function.Supplier; import org.apache.beam.fn.harness.data.BeamFnDataClient; +import org.apache.beam.fn.harness.fn.CloseableThrowingConsumer; import org.apache.beam.fn.harness.fn.ThrowingConsumer; import org.apache.beam.fn.harness.fn.ThrowingRunnable; import org.apache.beam.fn.v1.BeamFnApi; -import org.apache.beam.runners.core.PTransformRunnerFactory; +import org.apache.beam.runners.dataflow.util.CloudObjects; +import org.apache.beam.runners.dataflow.util.DoFnInfo; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarLongCoder; import org.apache.beam.sdk.common.runner.v1.RunnerApi; -import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.io.CountingSource; import 
org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.WindowingStrategy; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -48,14 +82,55 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; import org.mockito.Captor; +import org.mockito.Matchers; import org.mockito.Mock; import org.mockito.MockitoAnnotations; /** Tests for {@link ProcessBundleHandler}. */ @RunWith(JUnit4.class) public class ProcessBundleHandlerTest { + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + private static final Coder<WindowedValue<String>> STRING_CODER = + WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE); + private static final String LONG_CODER_SPEC_ID = "998L"; + private static final String STRING_CODER_SPEC_ID = "999L"; + private static final BeamFnApi.RemoteGrpcPort REMOTE_PORT = BeamFnApi.RemoteGrpcPort.newBuilder() + .setApiServiceDescriptor(BeamFnApi.ApiServiceDescriptor.newBuilder() + .setId("58L") + .setUrl("TestUrl")) + .build(); + private static final RunnerApi.Coder LONG_CODER_SPEC; + private static final RunnerApi.Coder STRING_CODER_SPEC; + static { + try { + STRING_CODER_SPEC = RunnerApi.Coder.newBuilder() + .setSpec(RunnerApi.SdkFunctionSpec.newBuilder() + .setSpec(RunnerApi.FunctionSpec.newBuilder() + .setParameter(Any.pack(BytesValue.newBuilder().setValue(ByteString.copyFrom( + OBJECT_MAPPER.writeValueAsBytes(CloudObjects.asCloudObject(STRING_CODER)))) + .build()))) + .build()) + .build(); + LONG_CODER_SPEC = RunnerApi.Coder.newBuilder() + 
.setSpec(RunnerApi.SdkFunctionSpec.newBuilder() + .setSpec(RunnerApi.FunctionSpec.newBuilder() + .setParameter(Any.pack(BytesValue.newBuilder().setValue(ByteString.copyFrom( + OBJECT_MAPPER.writeValueAsBytes( + CloudObjects.asCloudObject(WindowedValue.getFullCoder(VarLongCoder.of(), + GlobalWindow.Coder.INSTANCE))))) + .build()))) + .build()) + .build(); + } catch (IOException e) { + throw new ExceptionInInitializerError(e); + } + } + private static final String DATA_INPUT_URN = "urn:org.apache.beam:source:runner:0.1"; private static final String DATA_OUTPUT_URN = "urn:org.apache.beam:sink:runner:0.1"; + private static final String JAVA_DO_FN_URN = "urn:org.apache.beam:dofn:java:0.1"; + private static final String JAVA_SOURCE_URN = "urn:org.apache.beam:source:java:0.1"; @Rule public ExpectedException thrown = ExpectedException.none(); @@ -86,16 +161,16 @@ public class ProcessBundleHandlerTest { List<RunnerApi.PTransform> transformsProcessed = new ArrayList<>(); List<String> orderOfOperations = new ArrayList<>(); - PTransformRunnerFactory<Object> startFinishRecorder = new PTransformRunnerFactory<Object>() { + ProcessBundleHandler handler = new ProcessBundleHandler( + PipelineOptionsFactory.create(), + fnApiRegistry::get, + beamFnDataClient) { @Override - public Object createRunnerForPTransform( - PipelineOptions pipelineOptions, - BeamFnDataClient beamFnDataClient, + protected void createRunnerForPTransform( String pTransformId, RunnerApi.PTransform pTransform, Supplier<String> processBundleInstructionId, Map<String, RunnerApi.PCollection> pCollections, - Map<String, RunnerApi.Coder> coders, Multimap<String, ThrowingConsumer<WindowedValue<?>>> pCollectionIdsToConsumers, Consumer<ThrowingRunnable> addStartFunction, Consumer<ThrowingRunnable> addFinishFunction) throws IOException { @@ -107,18 +182,8 @@ public class ProcessBundleHandlerTest { () -> orderOfOperations.add("Start" + pTransformId)); addFinishFunction.accept( () -> orderOfOperations.add("Finish" + 
pTransformId)); - return null; } }; - - ProcessBundleHandler handler = new ProcessBundleHandler( - PipelineOptionsFactory.create(), - fnApiRegistry::get, - beamFnDataClient, - ImmutableMap.of( - DATA_INPUT_URN, startFinishRecorder, - DATA_OUTPUT_URN, startFinishRecorder)); - handler.processBundle(BeamFnApi.InstructionRequest.newBuilder() .setInstructionId("999L") .setProcessBundle( @@ -146,25 +211,21 @@ public class ProcessBundleHandlerTest { ProcessBundleHandler handler = new ProcessBundleHandler( PipelineOptionsFactory.create(), fnApiRegistry::get, - beamFnDataClient, - ImmutableMap.of(DATA_INPUT_URN, new PTransformRunnerFactory<Object>() { - @Override - public Object createRunnerForPTransform( - PipelineOptions pipelineOptions, - BeamFnDataClient beamFnDataClient, - String pTransformId, - RunnerApi.PTransform pTransform, - Supplier<String> processBundleInstructionId, - Map<String, RunnerApi.PCollection> pCollections, - Map<String, RunnerApi.Coder> coders, - Multimap<String, ThrowingConsumer<WindowedValue<?>>> pCollectionIdsToConsumers, - Consumer<ThrowingRunnable> addStartFunction, - Consumer<ThrowingRunnable> addFinishFunction) throws IOException { - thrown.expect(IllegalStateException.class); - thrown.expectMessage("TestException"); - throw new IllegalStateException("TestException"); - } - })); + beamFnDataClient) { + @Override + protected void createRunnerForPTransform( + String pTransformId, + RunnerApi.PTransform pTransform, + Supplier<String> processBundleInstructionId, + Map<String, RunnerApi.PCollection> pCollections, + Multimap<String, ThrowingConsumer<WindowedValue<?>>> pCollectionIdsToConsumers, + Consumer<ThrowingRunnable> addStartFunction, + Consumer<ThrowingRunnable> addFinishFunction) throws IOException { + thrown.expect(IllegalStateException.class); + thrown.expectMessage("TestException"); + throw new IllegalStateException("TestException"); + } + }; handler.processBundle( BeamFnApi.InstructionRequest.newBuilder().setProcessBundle( 
BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorReference("1L")) @@ -184,26 +245,25 @@ public class ProcessBundleHandlerTest { ProcessBundleHandler handler = new ProcessBundleHandler( PipelineOptionsFactory.create(), fnApiRegistry::get, - beamFnDataClient, - ImmutableMap.of(DATA_INPUT_URN, new PTransformRunnerFactory<Object>() { - @Override - public Object createRunnerForPTransform( - PipelineOptions pipelineOptions, - BeamFnDataClient beamFnDataClient, - String pTransformId, - RunnerApi.PTransform pTransform, - Supplier<String> processBundleInstructionId, - Map<String, RunnerApi.PCollection> pCollections, - Map<String, RunnerApi.Coder> coders, - Multimap<String, ThrowingConsumer<WindowedValue<?>>> pCollectionIdsToConsumers, - Consumer<ThrowingRunnable> addStartFunction, - Consumer<ThrowingRunnable> addFinishFunction) throws IOException { - thrown.expect(IllegalStateException.class); - thrown.expectMessage("TestException"); - addStartFunction.accept(ProcessBundleHandlerTest::throwException); - return null; - } - })); + beamFnDataClient) { + @Override + protected void createRunnerForPTransform( + String pTransformId, + RunnerApi.PTransform pTransform, + Supplier<String> processBundleInstructionId, + Map<String, RunnerApi.PCollection> pCollections, + Multimap<String, ThrowingConsumer<WindowedValue<?>>> pCollectionIdsToConsumers, + Consumer<ThrowingRunnable> addStartFunction, + Consumer<ThrowingRunnable> addFinishFunction) throws IOException { + thrown.expect(IllegalStateException.class); + thrown.expectMessage("TestException"); + addStartFunction.accept(this::throwException); + } + + private void throwException() { + throw new IllegalStateException("TestException"); + } + }; handler.processBundle( BeamFnApi.InstructionRequest.newBuilder().setProcessBundle( BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorReference("1L")) @@ -223,33 +283,338 @@ public class ProcessBundleHandlerTest { ProcessBundleHandler handler = new 
ProcessBundleHandler( PipelineOptionsFactory.create(), fnApiRegistry::get, - beamFnDataClient, - ImmutableMap.of(DATA_INPUT_URN, new PTransformRunnerFactory<Object>() { - @Override - public Object createRunnerForPTransform( - PipelineOptions pipelineOptions, - BeamFnDataClient beamFnDataClient, - String pTransformId, - RunnerApi.PTransform pTransform, - Supplier<String> processBundleInstructionId, - Map<String, RunnerApi.PCollection> pCollections, - Map<String, RunnerApi.Coder> coders, - Multimap<String, ThrowingConsumer<WindowedValue<?>>> pCollectionIdsToConsumers, - Consumer<ThrowingRunnable> addStartFunction, - Consumer<ThrowingRunnable> addFinishFunction) throws IOException { - thrown.expect(IllegalStateException.class); - thrown.expectMessage("TestException"); - addFinishFunction.accept(ProcessBundleHandlerTest::throwException); - return null; - } - })); + beamFnDataClient) { + @Override + protected void createRunnerForPTransform( + String pTransformId, + RunnerApi.PTransform pTransform, + Supplier<String> processBundleInstructionId, + Map<String, RunnerApi.PCollection> pCollections, + Multimap<String, ThrowingConsumer<WindowedValue<?>>> pCollectionIdsToConsumers, + Consumer<ThrowingRunnable> addStartFunction, + Consumer<ThrowingRunnable> addFinishFunction) throws IOException { + thrown.expect(IllegalStateException.class); + thrown.expectMessage("TestException"); + addFinishFunction.accept(this::throwException); + } + + private void throwException() { + throw new IllegalStateException("TestException"); + } + }; handler.processBundle( BeamFnApi.InstructionRequest.newBuilder().setProcessBundle( BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorReference("1L")) .build()); } - private static void throwException() { - throw new IllegalStateException("TestException"); + private static class TestDoFn extends DoFn<String, String> { + private static final TupleTag<String> mainOutput = new TupleTag<>("mainOutput"); + private static final 
TupleTag<String> additionalOutput = new TupleTag<>("output"); + + private BoundedWindow window; + + @ProcessElement + public void processElement(ProcessContext context, BoundedWindow window) { + context.output("MainOutput" + context.element()); + context.output(additionalOutput, "AdditionalOutput" + context.element()); + this.window = window; + } + + @FinishBundle + public void finishBundle(FinishBundleContext context) { + if (window != null) { + context.output("FinishBundle", window.maxTimestamp(), window); + window = null; + } + } + } + + /** + * Create a DoFn that has 3 inputs (inputATarget1, inputATarget2, inputBTarget) and 2 outputs + * (mainOutput, output). Validate that inputs are fed to the {@link DoFn} and that outputs + * are directed to the correct consumers. + */ + @Test + public void testCreatingAndProcessingDoFn() throws Exception { + Map<String, Message> fnApiRegistry = ImmutableMap.of(STRING_CODER_SPEC_ID, STRING_CODER_SPEC); + String pTransformId = "100L"; + String mainOutputId = "101"; + String additionalOutputId = "102"; + + DoFnInfo<?, ?> doFnInfo = DoFnInfo.forFn( + new TestDoFn(), + WindowingStrategy.globalDefault(), + ImmutableList.of(), + StringUtf8Coder.of(), + Long.parseLong(mainOutputId), + ImmutableMap.of( + Long.parseLong(mainOutputId), TestDoFn.mainOutput, + Long.parseLong(additionalOutputId), TestDoFn.additionalOutput)); + RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder() + .setUrn(JAVA_DO_FN_URN) + .setParameter(Any.pack(BytesValue.newBuilder() + .setValue(ByteString.copyFrom(SerializableUtils.serializeToByteArray(doFnInfo))) + .build())) + .build(); + RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder() + .setSpec(functionSpec) + .putInputs("inputA", "inputATarget") + .putInputs("inputB", "inputBTarget") + .putOutputs(mainOutputId, "mainOutputTarget") + .putOutputs(additionalOutputId, "additionalOutputTarget") + .build(); + + List<WindowedValue<String>> mainOutputValues = new ArrayList<>(); + 
List<WindowedValue<String>> additionalOutputValues = new ArrayList<>(); + Multimap<String, ThrowingConsumer<WindowedValue<?>>> consumers = HashMultimap.create(); + consumers.put("mainOutputTarget", + (ThrowingConsumer) (ThrowingConsumer<WindowedValue<String>>) mainOutputValues::add); + consumers.put("additionalOutputTarget", + (ThrowingConsumer) (ThrowingConsumer<WindowedValue<String>>) additionalOutputValues::add); + List<ThrowingRunnable> startFunctions = new ArrayList<>(); + List<ThrowingRunnable> finishFunctions = new ArrayList<>(); + + ProcessBundleHandler handler = new ProcessBundleHandler( + PipelineOptionsFactory.create(), + fnApiRegistry::get, + beamFnDataClient); + handler.createRunnerForPTransform( + pTransformId, + pTransform, + Suppliers.ofInstance("57L")::get, + ImmutableMap.of(), + consumers, + startFunctions::add, + finishFunctions::add); + + Iterables.getOnlyElement(startFunctions).run(); + mainOutputValues.clear(); + + assertThat(consumers.keySet(), containsInAnyOrder( + "inputATarget", "inputBTarget", "mainOutputTarget", "additionalOutputTarget")); + + Iterables.getOnlyElement(consumers.get("inputATarget")).accept(valueInGlobalWindow("A1")); + Iterables.getOnlyElement(consumers.get("inputATarget")).accept(valueInGlobalWindow("A2")); + Iterables.getOnlyElement(consumers.get("inputATarget")).accept(valueInGlobalWindow("B")); + assertThat(mainOutputValues, contains( + valueInGlobalWindow("MainOutputA1"), + valueInGlobalWindow("MainOutputA2"), + valueInGlobalWindow("MainOutputB"))); + assertThat(additionalOutputValues, contains( + valueInGlobalWindow("AdditionalOutputA1"), + valueInGlobalWindow("AdditionalOutputA2"), + valueInGlobalWindow("AdditionalOutputB"))); + mainOutputValues.clear(); + additionalOutputValues.clear(); + + Iterables.getOnlyElement(finishFunctions).run(); + assertThat( + mainOutputValues, + contains( + timestampedValueInGlobalWindow("FinishBundle", GlobalWindow.INSTANCE.maxTimestamp()))); + mainOutputValues.clear(); + } + + @Test 
+ public void testCreatingAndProcessingSource() throws Exception { + Map<String, Message> fnApiRegistry = ImmutableMap.of(LONG_CODER_SPEC_ID, LONG_CODER_SPEC); + List<WindowedValue<String>> outputValues = new ArrayList<>(); + + Multimap<String, ThrowingConsumer<WindowedValue<?>>> consumers = HashMultimap.create(); + consumers.put("outputPC", + (ThrowingConsumer) (ThrowingConsumer<WindowedValue<String>>) outputValues::add); + List<ThrowingRunnable> startFunctions = new ArrayList<>(); + List<ThrowingRunnable> finishFunctions = new ArrayList<>(); + + RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder() + .setUrn(JAVA_SOURCE_URN) + .setParameter(Any.pack(BytesValue.newBuilder() + .setValue(ByteString.copyFrom( + SerializableUtils.serializeToByteArray(CountingSource.upTo(3)))) + .build())) + .build(); + + RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder() + .setSpec(functionSpec) + .putInputs("input", "inputPC") + .putOutputs("output", "outputPC") + .build(); + + ProcessBundleHandler handler = new ProcessBundleHandler( + PipelineOptionsFactory.create(), + fnApiRegistry::get, + beamFnDataClient); + + handler.createRunnerForPTransform( + "pTransformId", + pTransform, + Suppliers.ofInstance("57L")::get, + ImmutableMap.of(), + consumers, + startFunctions::add, + finishFunctions::add); + + // This is testing a deprecated way of running sources and should be removed + // once all source definitions are instead propagated along the input edge. + Iterables.getOnlyElement(startFunctions).run(); + assertThat(outputValues, contains( + valueInGlobalWindow(0L), + valueInGlobalWindow(1L), + valueInGlobalWindow(2L))); + outputValues.clear(); + + // Check that when passing a source along as an input, the source is processed. 
+ assertThat(consumers.keySet(), containsInAnyOrder("inputPC", "outputPC")); + Iterables.getOnlyElement(consumers.get("inputPC")).accept( + valueInGlobalWindow(CountingSource.upTo(2))); + assertThat(outputValues, contains( + valueInGlobalWindow(0L), + valueInGlobalWindow(1L))); + + assertThat(finishFunctions, empty()); + } + + @Test + public void testCreatingAndProcessingBeamFnDataReadRunner() throws Exception { + Map<String, Message> fnApiRegistry = ImmutableMap.of(STRING_CODER_SPEC_ID, STRING_CODER_SPEC); + String bundleId = "57"; + String outputId = "101"; + + List<WindowedValue<String>> outputValues = new ArrayList<>(); + + Multimap<String, ThrowingConsumer<WindowedValue<?>>> consumers = HashMultimap.create(); + consumers.put("outputPC", + (ThrowingConsumer) (ThrowingConsumer<WindowedValue<String>>) outputValues::add); + List<ThrowingRunnable> startFunctions = new ArrayList<>(); + List<ThrowingRunnable> finishFunctions = new ArrayList<>(); + + RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder() + .setUrn(DATA_INPUT_URN) + .setParameter(Any.pack(REMOTE_PORT)) + .build(); + + RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder() + .setSpec(functionSpec) + .putOutputs(outputId, "outputPC") + .build(); + + ProcessBundleHandler handler = new ProcessBundleHandler( + PipelineOptionsFactory.create(), + fnApiRegistry::get, + beamFnDataClient); + + handler.createRunnerForPTransform( + "pTransformId", + pTransform, + Suppliers.ofInstance(bundleId)::get, + ImmutableMap.of("outputPC", + RunnerApi.PCollection.newBuilder().setCoderId(STRING_CODER_SPEC_ID).build()), + consumers, + startFunctions::add, + finishFunctions::add); + + verifyZeroInteractions(beamFnDataClient); + + CompletableFuture<Void> completionFuture = new CompletableFuture<>(); + when(beamFnDataClient.forInboundConsumer(any(), any(), any(), any())) + .thenReturn(completionFuture); + Iterables.getOnlyElement(startFunctions).run(); + verify(beamFnDataClient).forInboundConsumer( 
+ eq(REMOTE_PORT.getApiServiceDescriptor()), + eq(KV.of(bundleId, BeamFnApi.Target.newBuilder() + .setPrimitiveTransformReference("pTransformId") + .setName(outputId) + .build())), + eq(STRING_CODER), + consumerCaptor.capture()); + + consumerCaptor.getValue().accept(valueInGlobalWindow("TestValue")); + assertThat(outputValues, contains(valueInGlobalWindow("TestValue"))); + outputValues.clear(); + + assertThat(consumers.keySet(), containsInAnyOrder("outputPC")); + + completionFuture.complete(null); + Iterables.getOnlyElement(finishFunctions).run(); + + verifyNoMoreInteractions(beamFnDataClient); + } + + @Test + public void testCreatingAndProcessingBeamFnDataWriteRunner() throws Exception { + Map<String, Message> fnApiRegistry = ImmutableMap.of(STRING_CODER_SPEC_ID, STRING_CODER_SPEC); + String bundleId = "57L"; + String inputId = "100L"; + + Multimap<String, ThrowingConsumer<WindowedValue<?>>> consumers = HashMultimap.create(); + List<ThrowingRunnable> startFunctions = new ArrayList<>(); + List<ThrowingRunnable> finishFunctions = new ArrayList<>(); + + RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder() + .setUrn(DATA_OUTPUT_URN) + .setParameter(Any.pack(REMOTE_PORT)) + .build(); + + RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder() + .setSpec(functionSpec) + .putInputs(inputId, "inputPC") + .build(); + + ProcessBundleHandler handler = new ProcessBundleHandler( + PipelineOptionsFactory.create(), + fnApiRegistry::get, + beamFnDataClient); + + handler.createRunnerForPTransform( + "ptransformId", + pTransform, + Suppliers.ofInstance(bundleId)::get, + ImmutableMap.of("inputPC", + RunnerApi.PCollection.newBuilder().setCoderId(STRING_CODER_SPEC_ID).build()), + consumers, + startFunctions::add, + finishFunctions::add); + + verifyZeroInteractions(beamFnDataClient); + + List<WindowedValue<String>> outputValues = new ArrayList<>(); + AtomicBoolean wasCloseCalled = new AtomicBoolean(); + CloseableThrowingConsumer<WindowedValue<String>> 
outputConsumer = + new CloseableThrowingConsumer<WindowedValue<String>>(){ + @Override + public void close() throws Exception { + wasCloseCalled.set(true); + } + + @Override + public void accept(WindowedValue<String> t) throws Exception { + outputValues.add(t); + } + }; + + when(beamFnDataClient.forOutboundConsumer( + any(), + any(), + Matchers.<Coder<WindowedValue<String>>>any())).thenReturn(outputConsumer); + Iterables.getOnlyElement(startFunctions).run(); + verify(beamFnDataClient).forOutboundConsumer( + eq(REMOTE_PORT.getApiServiceDescriptor()), + eq(KV.of(bundleId, BeamFnApi.Target.newBuilder() + .setPrimitiveTransformReference("ptransformId") + .setName(inputId) + .build())), + eq(STRING_CODER)); + + assertThat(consumers.keySet(), containsInAnyOrder("inputPC")); + Iterables.getOnlyElement(consumers.get("inputPC")).accept(valueInGlobalWindow("TestValue")); + assertThat(outputValues, contains(valueInGlobalWindow("TestValue"))); + outputValues.clear(); + + assertFalse(wasCloseCalled.get()); + Iterables.getOnlyElement(finishFunctions).run(); + assertTrue(wasCloseCalled.get()); + + verifyNoMoreInteractions(beamFnDataClient); } } http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/RegisterHandlerTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/RegisterHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/RegisterHandlerTest.java index 2b275af..b1f4410 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/RegisterHandlerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/RegisterHandlerTest.java @@ -44,14 +44,14 @@ public class RegisterHandlerTest { .setRegister(BeamFnApi.RegisterRequest.newBuilder() .addProcessBundleDescriptor(BeamFnApi.ProcessBundleDescriptor.newBuilder() .setId("1L") - 
.putCoders("10L", RunnerApi.Coder.newBuilder() + .putCodersyyy("10L", RunnerApi.Coder.newBuilder() .setSpec(RunnerApi.SdkFunctionSpec.newBuilder() .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn("urn:10L").build()) .build()) .build()) .build()) .addProcessBundleDescriptor(BeamFnApi.ProcessBundleDescriptor.newBuilder().setId("2L") - .putCoders("20L", RunnerApi.Coder.newBuilder() + .putCodersyyy("20L", RunnerApi.Coder.newBuilder() .setSpec(RunnerApi.SdkFunctionSpec.newBuilder() .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn("urn:20L").build()) .build()) @@ -82,10 +82,10 @@ public class RegisterHandlerTest { assertEquals(REGISTER_REQUEST.getRegister().getProcessBundleDescriptor(1), handler.getById("2L")); assertEquals( - REGISTER_REQUEST.getRegister().getProcessBundleDescriptor(0).getCodersOrThrow("10L"), + REGISTER_REQUEST.getRegister().getProcessBundleDescriptor(0).getCodersyyyOrThrow("10L"), handler.getById("10L")); assertEquals( - REGISTER_REQUEST.getRegister().getProcessBundleDescriptor(1).getCodersOrThrow("20L"), + REGISTER_REQUEST.getRegister().getProcessBundleDescriptor(1).getCodersyyyOrThrow("20L"), handler.getById("20L")); assertEquals(REGISTER_RESPONSE, responseFuture.get()); } http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataReadRunnerTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataReadRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataReadRunnerTest.java index d6a476e..7e8ab1a 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataReadRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataReadRunnerTest.java @@ -20,51 +20,41 @@ package org.apache.beam.runners.core; import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow; import 
static org.hamcrest.Matchers.contains; -import static org.hamcrest.Matchers.containsInAnyOrder; import static org.junit.Assert.assertThat; -import static org.junit.Assert.fail; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.when; import com.fasterxml.jackson.databind.ObjectMapper; -import com.google.common.base.Suppliers; -import com.google.common.collect.HashMultimap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Iterables; -import com.google.common.collect.Multimap; import com.google.common.util.concurrent.Uninterruptibles; import com.google.protobuf.Any; import com.google.protobuf.ByteString; import com.google.protobuf.BytesValue; import java.io.IOException; import java.util.ArrayList; +import java.util.Collection; import java.util.List; -import java.util.ServiceLoader; +import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import org.apache.beam.fn.harness.data.BeamFnDataClient; import org.apache.beam.fn.harness.fn.ThrowingConsumer; -import org.apache.beam.fn.harness.fn.ThrowingRunnable; import org.apache.beam.fn.harness.test.TestExecutors; import org.apache.beam.fn.harness.test.TestExecutors.TestExecutorService; import org.apache.beam.fn.v1.BeamFnApi; -import org.apache.beam.runners.core.PTransformRunnerFactory.Registrar; import org.apache.beam.runners.dataflow.util.CloudObjects; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.common.runner.v1.RunnerApi; -import org.apache.beam.sdk.options.PipelineOptionsFactory; import 
org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; -import org.hamcrest.collection.IsMapContaining; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -78,18 +68,15 @@ import org.mockito.MockitoAnnotations; /** Tests for {@link BeamFnDataReadRunner}. */ @RunWith(JUnit4.class) public class BeamFnDataReadRunnerTest { - private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + private static final BeamFnApi.RemoteGrpcPort PORT_SPEC = BeamFnApi.RemoteGrpcPort.newBuilder() .setApiServiceDescriptor(BeamFnApi.ApiServiceDescriptor.getDefaultInstance()).build(); private static final RunnerApi.FunctionSpec FUNCTION_SPEC = RunnerApi.FunctionSpec.newBuilder() .setParameter(Any.pack(PORT_SPEC)).build(); private static final Coder<WindowedValue<String>> CODER = WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE); - private static final String CODER_SPEC_ID = "string-coder-id"; private static final RunnerApi.Coder CODER_SPEC; - private static final String URN = "urn:org.apache.beam:source:runner:0.1"; - static { try { CODER_SPEC = RunnerApi.Coder.newBuilder().setSpec( @@ -111,7 +98,7 @@ public class BeamFnDataReadRunnerTest { .build(); @Rule public TestExecutorService executor = TestExecutors.from(Executors::newCachedThreadPool); - @Mock private BeamFnDataClient mockBeamFnDataClient; + @Mock private BeamFnDataClient mockBeamFnDataClientFactory; @Captor private ArgumentCaptor<ThrowingConsumer<WindowedValue<String>>> consumerCaptor; @Before @@ -120,93 +107,32 @@ public class BeamFnDataReadRunnerTest { } @Test - public void testCreatingAndProcessingBeamFnDataReadRunner() throws Exception { - String bundleId = "57"; - String outputId = "101"; - - List<WindowedValue<String>> outputValues = new ArrayList<>(); - - Multimap<String, ThrowingConsumer<WindowedValue<?>>> consumers = HashMultimap.create(); - consumers.put("outputPC", - 
(ThrowingConsumer) (ThrowingConsumer<WindowedValue<String>>) outputValues::add); - List<ThrowingRunnable> startFunctions = new ArrayList<>(); - List<ThrowingRunnable> finishFunctions = new ArrayList<>(); - - RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder() - .setUrn("urn:org.apache.beam:source:runner:0.1") - .setParameter(Any.pack(PORT_SPEC)) - .build(); - - RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder() - .setSpec(functionSpec) - .putOutputs(outputId, "outputPC") - .build(); - - new BeamFnDataReadRunner.Factory<String>().createRunnerForPTransform( - PipelineOptionsFactory.create(), - mockBeamFnDataClient, - "pTransformId", - pTransform, - Suppliers.ofInstance(bundleId)::get, - ImmutableMap.of("outputPC", - RunnerApi.PCollection.newBuilder().setCoderId(CODER_SPEC_ID).build()), - ImmutableMap.of(CODER_SPEC_ID, CODER_SPEC), - consumers, - startFunctions::add, - finishFunctions::add); - - verifyZeroInteractions(mockBeamFnDataClient); - - CompletableFuture<Void> completionFuture = new CompletableFuture<>(); - when(mockBeamFnDataClient.forInboundConsumer(any(), any(), any(), any())) - .thenReturn(completionFuture); - Iterables.getOnlyElement(startFunctions).run(); - verify(mockBeamFnDataClient).forInboundConsumer( - eq(PORT_SPEC.getApiServiceDescriptor()), - eq(KV.of(bundleId, BeamFnApi.Target.newBuilder() - .setPrimitiveTransformReference("pTransformId") - .setName(outputId) - .build())), - eq(CODER), - consumerCaptor.capture()); - - consumerCaptor.getValue().accept(valueInGlobalWindow("TestValue")); - assertThat(outputValues, contains(valueInGlobalWindow("TestValue"))); - outputValues.clear(); - - assertThat(consumers.keySet(), containsInAnyOrder("outputPC")); - - completionFuture.complete(null); - Iterables.getOnlyElement(finishFunctions).run(); - - verifyNoMoreInteractions(mockBeamFnDataClient); - } - - @Test public void testReuseForMultipleBundles() throws Exception { CompletableFuture<Void> bundle1Future = new 
CompletableFuture<>(); CompletableFuture<Void> bundle2Future = new CompletableFuture<>(); - when(mockBeamFnDataClient.forInboundConsumer( + when(mockBeamFnDataClientFactory.forInboundConsumer( any(), any(), any(), any())).thenReturn(bundle1Future).thenReturn(bundle2Future); List<WindowedValue<String>> valuesA = new ArrayList<>(); List<WindowedValue<String>> valuesB = new ArrayList<>(); - + Map<String, Collection<ThrowingConsumer<WindowedValue<String>>>> outputMap = ImmutableMap.of( + "outA", ImmutableList.of(valuesA::add), + "outB", ImmutableList.of(valuesB::add)); AtomicReference<String> bundleId = new AtomicReference<>("0"); BeamFnDataReadRunner<String> readRunner = new BeamFnDataReadRunner<>( FUNCTION_SPEC, bundleId::get, INPUT_TARGET, CODER_SPEC, - mockBeamFnDataClient, - ImmutableList.of(valuesA::add, valuesB::add)); + mockBeamFnDataClientFactory, + outputMap); // Process for bundle id 0 readRunner.registerInputLocation(); - verify(mockBeamFnDataClient).forInboundConsumer( + verify(mockBeamFnDataClientFactory).forInboundConsumer( eq(PORT_SPEC.getApiServiceDescriptor()), eq(KV.of(bundleId.get(), INPUT_TARGET)), eq(CODER), @@ -238,7 +164,7 @@ public class BeamFnDataReadRunnerTest { valuesB.clear(); readRunner.registerInputLocation(); - verify(mockBeamFnDataClient).forInboundConsumer( + verify(mockBeamFnDataClientFactory).forInboundConsumer( eq(PORT_SPEC.getApiServiceDescriptor()), eq(KV.of(bundleId.get(), INPUT_TARGET)), eq(CODER), @@ -264,18 +190,6 @@ public class BeamFnDataReadRunnerTest { assertThat(valuesA, contains(valueInGlobalWindow("GHI"), valueInGlobalWindow("JKL"))); assertThat(valuesB, contains(valueInGlobalWindow("GHI"), valueInGlobalWindow("JKL"))); - verifyNoMoreInteractions(mockBeamFnDataClient); - } - - @Test - public void testRegistration() { - for (Registrar registrar : - ServiceLoader.load(Registrar.class)) { - if (registrar instanceof BeamFnDataReadRunner.Registrar) { - assertThat(registrar.getPTransformRunnerFactories(), 
IsMapContaining.hasKey(URN)); - return; - } - } - fail("Expected registrar not found."); + verifyNoMoreInteractions(mockBeamFnDataClientFactory); } } http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataWriteRunnerTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataWriteRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataWriteRunnerTest.java index 64d9ea6..a3c874e 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataWriteRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataWriteRunnerTest.java @@ -20,48 +20,31 @@ package org.apache.beam.runners.core; import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow; import static org.hamcrest.Matchers.contains; -import static org.hamcrest.Matchers.containsInAnyOrder; -import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.when; import com.fasterxml.jackson.databind.ObjectMapper; -import com.google.common.base.Suppliers; -import com.google.common.collect.HashMultimap; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Iterables; -import com.google.common.collect.Multimap; import com.google.protobuf.Any; import com.google.protobuf.ByteString; import com.google.protobuf.BytesValue; import java.io.IOException; import java.util.ArrayList; -import java.util.List; -import java.util.ServiceLoader; -import 
java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import org.apache.beam.fn.harness.data.BeamFnDataClient; import org.apache.beam.fn.harness.fn.CloseableThrowingConsumer; -import org.apache.beam.fn.harness.fn.ThrowingConsumer; -import org.apache.beam.fn.harness.fn.ThrowingRunnable; import org.apache.beam.fn.v1.BeamFnApi; -import org.apache.beam.runners.core.PTransformRunnerFactory.Registrar; import org.apache.beam.runners.dataflow.util.CloudObjects; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.common.runner.v1.RunnerApi; -import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; -import org.hamcrest.collection.IsMapContaining; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -73,18 +56,15 @@ import org.mockito.MockitoAnnotations; /** Tests for {@link BeamFnDataWriteRunner}. 
*/ @RunWith(JUnit4.class) public class BeamFnDataWriteRunnerTest { - private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + private static final BeamFnApi.RemoteGrpcPort PORT_SPEC = BeamFnApi.RemoteGrpcPort.newBuilder() .setApiServiceDescriptor(BeamFnApi.ApiServiceDescriptor.getDefaultInstance()).build(); private static final RunnerApi.FunctionSpec FUNCTION_SPEC = RunnerApi.FunctionSpec.newBuilder() .setParameter(Any.pack(PORT_SPEC)).build(); - private static final String CODER_ID = "string-coder-id"; private static final Coder<WindowedValue<String>> CODER = WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE); private static final RunnerApi.Coder CODER_SPEC; - private static final String URN = "urn:org.apache.beam:sink:runner:0.1"; - static { try { CODER_SPEC = RunnerApi.Coder.newBuilder().setSpec( @@ -105,93 +85,18 @@ public class BeamFnDataWriteRunnerTest { .setName("out") .build(); - @Mock private BeamFnDataClient mockBeamFnDataClient; + @Mock private BeamFnDataClient mockBeamFnDataClientFactory; @Before public void setUp() { MockitoAnnotations.initMocks(this); } - - @Test - public void testCreatingAndProcessingBeamFnDataWriteRunner() throws Exception { - String bundleId = "57L"; - String inputId = "100L"; - - Multimap<String, ThrowingConsumer<WindowedValue<?>>> consumers = HashMultimap.create(); - List<ThrowingRunnable> startFunctions = new ArrayList<>(); - List<ThrowingRunnable> finishFunctions = new ArrayList<>(); - - RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder() - .setUrn("urn:org.apache.beam:sink:runner:0.1") - .setParameter(Any.pack(PORT_SPEC)) - .build(); - - RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder() - .setSpec(functionSpec) - .putInputs(inputId, "inputPC") - .build(); - - new BeamFnDataWriteRunner.Factory<String>().createRunnerForPTransform( - PipelineOptionsFactory.create(), - mockBeamFnDataClient, - "ptransformId", - pTransform, - 
Suppliers.ofInstance(bundleId)::get, - ImmutableMap.of("inputPC", - RunnerApi.PCollection.newBuilder().setCoderId(CODER_ID).build()), - ImmutableMap.of(CODER_ID, CODER_SPEC), - consumers, - startFunctions::add, - finishFunctions::add); - - verifyZeroInteractions(mockBeamFnDataClient); - - List<WindowedValue<String>> outputValues = new ArrayList<>(); - AtomicBoolean wasCloseCalled = new AtomicBoolean(); - CloseableThrowingConsumer<WindowedValue<String>> outputConsumer = - new CloseableThrowingConsumer<WindowedValue<String>>(){ - @Override - public void close() throws Exception { - wasCloseCalled.set(true); - } - - @Override - public void accept(WindowedValue<String> t) throws Exception { - outputValues.add(t); - } - }; - - when(mockBeamFnDataClient.forOutboundConsumer( - any(), - any(), - Matchers.<Coder<WindowedValue<String>>>any())).thenReturn(outputConsumer); - Iterables.getOnlyElement(startFunctions).run(); - verify(mockBeamFnDataClient).forOutboundConsumer( - eq(PORT_SPEC.getApiServiceDescriptor()), - eq(KV.of(bundleId, BeamFnApi.Target.newBuilder() - .setPrimitiveTransformReference("ptransformId") - .setName(inputId) - .build())), - eq(CODER)); - - assertThat(consumers.keySet(), containsInAnyOrder("inputPC")); - Iterables.getOnlyElement(consumers.get("inputPC")).accept(valueInGlobalWindow("TestValue")); - assertThat(outputValues, contains(valueInGlobalWindow("TestValue"))); - outputValues.clear(); - - assertFalse(wasCloseCalled.get()); - Iterables.getOnlyElement(finishFunctions).run(); - assertTrue(wasCloseCalled.get()); - - verifyNoMoreInteractions(mockBeamFnDataClient); - } - @Test public void testReuseForMultipleBundles() throws Exception { RecordingConsumer<WindowedValue<String>> valuesA = new RecordingConsumer<>(); RecordingConsumer<WindowedValue<String>> valuesB = new RecordingConsumer<>(); - when(mockBeamFnDataClient.forOutboundConsumer( + when(mockBeamFnDataClientFactory.forOutboundConsumer( any(), any(), 
Matchers.<Coder<WindowedValue<String>>>any())).thenReturn(valuesA).thenReturn(valuesB); @@ -201,12 +106,12 @@ public class BeamFnDataWriteRunnerTest { bundleId::get, OUTPUT_TARGET, CODER_SPEC, - mockBeamFnDataClient); + mockBeamFnDataClientFactory); // Process for bundle id 0 writeRunner.registerForOutput(); - verify(mockBeamFnDataClient).forOutboundConsumer( + verify(mockBeamFnDataClientFactory).forOutboundConsumer( eq(PORT_SPEC.getApiServiceDescriptor()), eq(KV.of(bundleId.get(), OUTPUT_TARGET)), eq(CODER)); @@ -224,7 +129,7 @@ public class BeamFnDataWriteRunnerTest { valuesB.clear(); writeRunner.registerForOutput(); - verify(mockBeamFnDataClient).forOutboundConsumer( + verify(mockBeamFnDataClientFactory).forOutboundConsumer( eq(PORT_SPEC.getApiServiceDescriptor()), eq(KV.of(bundleId.get(), OUTPUT_TARGET)), eq(CODER)); @@ -235,7 +140,7 @@ public class BeamFnDataWriteRunnerTest { assertTrue(valuesB.closed); assertThat(valuesB, contains(valueInGlobalWindow("GHI"), valueInGlobalWindow("JKL"))); - verifyNoMoreInteractions(mockBeamFnDataClient); + verifyNoMoreInteractions(mockBeamFnDataClientFactory); } private static class RecordingConsumer<T> extends ArrayList<T> @@ -253,17 +158,6 @@ public class BeamFnDataWriteRunnerTest { } add(t); } - } - @Test - public void testRegistration() { - for (Registrar registrar : - ServiceLoader.load(Registrar.class)) { - if (registrar instanceof BeamFnDataWriteRunner.Registrar) { - assertThat(registrar.getPTransformRunnerFactories(), IsMapContaining.hasKey(URN)); - return; - } - } - fail("Expected registrar not found."); } }
