[BEAM-1498] Use Flink-native side outputs
Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/b0601fd4 Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/b0601fd4 Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/b0601fd4 Branch: refs/heads/master Commit: b0601fd43e0929e8b925dbe566e564460f91d9fc Parents: 88f78fa Author: JingsongLi <[email protected]> Authored: Sun Jun 4 21:56:10 2017 +0800 Committer: Aljoscha Krettek <[email protected]> Committed: Tue Jun 6 14:33:36 2017 +0200 ---------------------------------------------------------------------- .../FlinkStreamingTransformTranslators.java | 145 ++++++------------- .../wrappers/streaming/DoFnOperator.java | 40 +++-- .../wrappers/streaming/WindowDoFnOperator.java | 4 +- .../beam/runners/flink/PipelineOptionsTest.java | 5 +- .../flink/streaming/DoFnOperatorTest.java | 65 +++++---- 5 files changed, 112 insertions(+), 147 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/b0601fd4/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java index 00e9934..d8c3049 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java @@ -18,9 +18,6 @@ package org.apache.beam.runners.flink; -import static com.google.common.base.Preconditions.checkArgument; - -import com.google.common.collect.Lists; import com.google.common.collect.Maps; import java.nio.ByteBuffer; import java.util.ArrayList; @@ -29,7 +26,6 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Map.Entry; import org.apache.beam.runners.core.KeyedWorkItem; import org.apache.beam.runners.core.SplittableParDoViaKeyedWorkItems; import org.apache.beam.runners.core.SystemReduceFn; @@ -84,16 +80,15 @@ import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.typeutils.GenericTypeInfo; import org.apache.flink.api.java.typeutils.ResultTypeQueryable; -import org.apache.flink.streaming.api.collector.selector.OutputSelector; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.datastream.DataStreamSource; import org.apache.flink.streaming.api.datastream.KeyedStream; import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; -import org.apache.flink.streaming.api.datastream.SplitStream; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; import org.apache.flink.streaming.api.transformations.TwoInputTransformation; import org.apache.flink.util.Collector; +import org.apache.flink.util.OutputTag; /** * This class contains all the mappings between Beam and Flink @@ -337,7 +332,7 @@ class FlinkStreamingTransformTranslators { static class ParDoTranslationHelper { interface DoFnOperatorFactory<InputT, OutputT> { - DoFnOperator<InputT, OutputT, RawUnionValue> createDoFnOperator( + DoFnOperator<InputT, OutputT, OutputT> createDoFnOperator( DoFn<InputT, OutputT> doFn, String stepName, List<PCollectionView<?>> sideInputs, @@ -345,7 +340,7 @@ class FlinkStreamingTransformTranslators { List<TupleTag<?>> additionalOutputTags, FlinkStreamingTranslationContext context, WindowingStrategy<?, ?> windowingStrategy, - Map<TupleTag<?>, Integer> tagsToLabels, + Map<TupleTag<?>, OutputTag<WindowedValue<?>>> tagsToLabels, Coder<WindowedValue<InputT>> inputCoder, Coder keyCoder, Map<Integer, PCollectionView<?>> transformedSideInputs); @@ -354,7 +349,6 @@ class FlinkStreamingTransformTranslators { static <InputT, OutputT> void translateParDo( String transformName, DoFn<InputT, OutputT> doFn, - String stepName, PCollection<InputT> input, List<PCollectionView<?>> sideInputs, Map<TupleTag<?>, PValue> outputs, @@ -366,10 +360,15 @@ class FlinkStreamingTransformTranslators { // we assume that the transformation does not change the windowing strategy. WindowingStrategy<?, ?> windowingStrategy = input.getWindowingStrategy(); - Map<TupleTag<?>, Integer> tagsToLabels = - transformTupleTagsToLabels(mainOutputTag, outputs); + Map<TupleTag<?>, OutputTag<WindowedValue<?>>> tagsToOutputTags = Maps.newHashMap(); + for (Map.Entry<TupleTag<?>, PValue> entry : outputs.entrySet()) { + if (!tagsToOutputTags.containsKey(entry.getKey())) { + tagsToOutputTags.put(entry.getKey(), new OutputTag<>(entry.getKey().getId(), + (TypeInformation) context.getTypeInfo((PCollection<?>) entry.getValue()))); + } + } - SingleOutputStreamOperator<RawUnionValue> unionOutputStream; + SingleOutputStreamOperator<WindowedValue<OutputT>> outputStream; Coder<WindowedValue<InputT>> inputCoder = context.getCoder(input); @@ -391,8 +390,12 @@ class FlinkStreamingTransformTranslators { stateful = true; } + CoderTypeInformation<WindowedValue<OutputT>> outputTypeInformation = + new CoderTypeInformation<>( + context.getCoder((PCollection<OutputT>) outputs.get(mainOutputTag))); + if (sideInputs.isEmpty()) { - DoFnOperator<InputT, OutputT, RawUnionValue> doFnOperator = + DoFnOperator<InputT, OutputT, OutputT> doFnOperator = doFnOperatorFactory.createDoFnOperator( doFn, context.getCurrentTransform().getFullName(), @@ -401,24 +404,19 @@ class FlinkStreamingTransformTranslators { additionalOutputTags, context, windowingStrategy, - tagsToLabels, + tagsToOutputTags, inputCoder, keyCoder, new HashMap<Integer, PCollectionView<?>>() /* side-input mapping */); - UnionCoder outputUnionCoder = createUnionCoder(outputs); - - CoderTypeInformation<RawUnionValue> outputUnionTypeInformation = - new CoderTypeInformation<>(outputUnionCoder); - - unionOutputStream = inputDataStream - .transform(transformName, outputUnionTypeInformation, doFnOperator); + outputStream = inputDataStream + .transform(transformName, outputTypeInformation, doFnOperator); } else { Tuple2<Map<Integer, PCollectionView<?>>, DataStream<RawUnionValue>> transformedSideInputs = transformSideInputs(sideInputs, context); - DoFnOperator<InputT, OutputT, RawUnionValue> doFnOperator = + DoFnOperator<InputT, OutputT, OutputT> doFnOperator = doFnOperatorFactory.createDoFnOperator( doFn, context.getCurrentTransform().getFullName(), @@ -427,16 +425,11 @@ class FlinkStreamingTransformTranslators { additionalOutputTags, context, windowingStrategy, - tagsToLabels, + tagsToOutputTags, inputCoder, keyCoder, transformedSideInputs.f0); - UnionCoder outputUnionCoder = createUnionCoder(outputs); - - CoderTypeInformation<RawUnionValue> outputUnionTypeInformation = - new CoderTypeInformation<>(outputUnionCoder); - if (stateful) { // we have to manually contruct the two-input transform because we're not // allowed to have only one input keyed, normally. @@ -448,83 +441,35 @@ class FlinkStreamingTransformTranslators { keyedStream.getTransformation(), transformedSideInputs.f1.broadcast().getTransformation(), transformName, - (TwoInputStreamOperator) doFnOperator, - outputUnionTypeInformation, + doFnOperator, + outputTypeInformation, keyedStream.getParallelism()); rawFlinkTransform.setStateKeyType(keyedStream.getKeyType()); rawFlinkTransform.setStateKeySelectors(keyedStream.getKeySelector(), null); - unionOutputStream = new SingleOutputStreamOperator( - keyedStream.getExecutionEnvironment(), - rawFlinkTransform) {}; // we have to cheat around the ctor being protected + outputStream = new SingleOutputStreamOperator( + keyedStream.getExecutionEnvironment(), + rawFlinkTransform) { + }; // we have to cheat around the ctor being protected keyedStream.getExecutionEnvironment().addOperator(rawFlinkTransform); } else { - unionOutputStream = inputDataStream + outputStream = inputDataStream .connect(transformedSideInputs.f1.broadcast()) - .transform(transformName, outputUnionTypeInformation, doFnOperator); + .transform(transformName, outputTypeInformation, doFnOperator); } } - SplitStream<RawUnionValue> splitStream = unionOutputStream - .split(new OutputSelector<RawUnionValue>() { - @Override - public Iterable<String> select(RawUnionValue value) { - return Collections.singletonList(Integer.toString(value.getUnionTag())); - } - }); - - for (Entry<TupleTag<?>, PValue> output : outputs.entrySet()) { - final int outputTag = tagsToLabels.get(output.getKey()); - - TypeInformation outputTypeInfo = context.getTypeInfo((PCollection<?>) output.getValue()); - - @SuppressWarnings("unchecked") - DataStream unwrapped = splitStream.select(String.valueOf(outputTag)) - .flatMap(new FlatMapFunction<RawUnionValue, Object>() { - @Override - public void flatMap(RawUnionValue value, Collector<Object> out) throws Exception { - out.collect(value.getValue()); - } - }).returns(outputTypeInfo); - - context.setOutputDataStream(output.getValue(), unwrapped); - } - } - - private static Map<TupleTag<?>, Integer> transformTupleTagsToLabels( - TupleTag<?> mainTag, - Map<TupleTag<?>, PValue> allTaggedValues) { + context.setOutputDataStream(outputs.get(mainOutputTag), outputStream); - Map<TupleTag<?>, Integer> tagToLabelMap = Maps.newHashMap(); - int count = 0; - tagToLabelMap.put(mainTag, count++); - for (TupleTag<?> key : allTaggedValues.keySet()) { - if (!tagToLabelMap.containsKey(key)) { - tagToLabelMap.put(key, count++); + for (Map.Entry<TupleTag<?>, PValue> entry : outputs.entrySet()) { + if (!entry.getKey().equals(mainOutputTag)) { + context.setOutputDataStream(entry.getValue(), + outputStream.getSideOutput(tagsToOutputTags.get(entry.getKey()))); } } - return tagToLabelMap; - } - - private static UnionCoder createUnionCoder(Map<TupleTag<?>, PValue> taggedCollections) { - List<Coder<?>> outputCoders = Lists.newArrayList(); - for (PValue taggedColl : taggedCollections.values()) { - checkArgument( - taggedColl instanceof PCollection, - "A Union Coder can only be created for a Collection of Tagged %s. Got %s", - PCollection.class.getSimpleName(), - taggedColl.getClass().getSimpleName()); - PCollection<?> coll = (PCollection<?>) taggedColl; - WindowedValue.FullWindowedValueCoder<?> windowedValueCoder = - WindowedValue.getFullCoder( - coll.getCoder(), - coll.getWindowingStrategy().getWindowFn().windowCoder()); - outputCoders.add(windowedValueCoder); - } - return UnionCoder.of(outputCoders); } } @@ -540,7 +485,6 @@ class FlinkStreamingTransformTranslators { ParDoTranslationHelper.translateParDo( transform.getName(), transform.getFn(), - context.getCurrentTransform().getFullName(), (PCollection<InputT>) context.getInput(transform), transform.getSideInputs(), context.getOutputs(transform), @@ -549,7 +493,7 @@ class FlinkStreamingTransformTranslators { context, new ParDoTranslationHelper.DoFnOperatorFactory<InputT, OutputT>() { @Override - public DoFnOperator<InputT, OutputT, RawUnionValue> createDoFnOperator( + public DoFnOperator<InputT, OutputT, OutputT> createDoFnOperator( DoFn<InputT, OutputT> doFn, String stepName, List<PCollectionView<?>> sideInputs, @@ -557,7 +501,7 @@ class FlinkStreamingTransformTranslators { List<TupleTag<?>> additionalOutputTags, FlinkStreamingTranslationContext context, WindowingStrategy<?, ?> windowingStrategy, - Map<TupleTag<?>, Integer> tagsToLabels, + Map<TupleTag<?>, OutputTag<WindowedValue<?>>> tagsToOutputTags, Coder<WindowedValue<InputT>> inputCoder, Coder keyCoder, Map<Integer, PCollectionView<?>> transformedSideInputs) { @@ -567,7 +511,7 @@ class FlinkStreamingTransformTranslators { inputCoder, mainOutputTag, additionalOutputTags, - new DoFnOperator.MultiOutputOutputManagerFactory(tagsToLabels), + new DoFnOperator.MultiOutputOutputManagerFactory(mainOutputTag, tagsToOutputTags), windowingStrategy, transformedSideInputs, sideInputs, @@ -592,7 +536,6 @@ class FlinkStreamingTransformTranslators { ParDoTranslationHelper.translateParDo( transform.getName(), transform.newProcessFn(transform.getFn()), - context.getCurrentTransform().getFullName(), context.getInput(transform), transform.getSideInputs(), context.getOutputs(transform), @@ -604,8 +547,7 @@ class FlinkStreamingTransformTranslators { @Override public DoFnOperator< KeyedWorkItem<String, ElementAndRestriction<InputT, RestrictionT>>, - OutputT, - RawUnionValue> createDoFnOperator( + OutputT, OutputT> createDoFnOperator( DoFn< KeyedWorkItem<String, ElementAndRestriction<InputT, RestrictionT>>, OutputT> doFn, @@ -615,7 +557,7 @@ class FlinkStreamingTransformTranslators { List<TupleTag<?>> additionalOutputTags, FlinkStreamingTranslationContext context, WindowingStrategy<?, ?> windowingStrategy, - Map<TupleTag<?>, Integer> tagsToLabels, + Map<TupleTag<?>, OutputTag<WindowedValue<?>>> tagsToOutputTags, Coder< WindowedValue< KeyedWorkItem< @@ -629,7 +571,7 @@ class FlinkStreamingTransformTranslators { inputCoder, mainOutputTag, additionalOutputTags, - new DoFnOperator.MultiOutputOutputManagerFactory(tagsToLabels), + new DoFnOperator.MultiOutputOutputManagerFactory(mainOutputTag, tagsToOutputTags), windowingStrategy, transformedSideInputs, sideInputs, @@ -756,8 +698,7 @@ class FlinkStreamingTransformTranslators { TypeInformation<WindowedValue<KV<K, Iterable<InputT>>>> outputTypeInfo = context.getTypeInfo(context.getOutput(transform)); - DoFnOperator.DefaultOutputManagerFactory< - WindowedValue<KV<K, Iterable<InputT>>>> outputManagerFactory = + DoFnOperator.DefaultOutputManagerFactory<KV<K, Iterable<InputT>>> outputManagerFactory = new DoFnOperator.DefaultOutputManagerFactory<>(); WindowDoFnOperator<K, InputT, Iterable<InputT>> doFnOperator = @@ -868,7 +809,7 @@ class FlinkStreamingTransformTranslators { (Coder) windowedWorkItemCoder, new TupleTag<KV<K, OutputT>>("main output"), Collections.<TupleTag<?>>emptyList(), - new DoFnOperator.DefaultOutputManagerFactory<WindowedValue<KV<K, OutputT>>>(), + new DoFnOperator.DefaultOutputManagerFactory<KV<K, OutputT>>(), windowingStrategy, new HashMap<Integer, PCollectionView<?>>(), /* side-input mapping */ Collections.<PCollectionView<?>>emptyList(), /* side inputs */ @@ -894,7 +835,7 @@ class FlinkStreamingTransformTranslators { (Coder) windowedWorkItemCoder, new TupleTag<KV<K, OutputT>>("main output"), Collections.<TupleTag<?>>emptyList(), - new DoFnOperator.DefaultOutputManagerFactory<WindowedValue<KV<K, OutputT>>>(), + new DoFnOperator.DefaultOutputManagerFactory<KV<K, OutputT>>(), windowingStrategy, transformSideInputs.f0, sideInputs, http://git-wip-us.apache.org/repos/asf/beam/blob/b0601fd4/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java index 594fe0e..8c27ed9 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java @@ -87,6 +87,7 @@ import org.apache.flink.streaming.api.operators.Triggerable; import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.OutputTag; import org.joda.time.Instant; /** @@ -98,9 +99,9 @@ import org.joda.time.Instant; * type when we have additional tagged outputs */ public class DoFnOperator<InputT, FnOutputT, OutputT> - extends AbstractStreamOperator<OutputT> - implements OneInputStreamOperator<WindowedValue<InputT>, OutputT>, - TwoInputStreamOperator<WindowedValue<InputT>, RawUnionValue, OutputT>, + extends AbstractStreamOperator<WindowedValue<OutputT>> + implements OneInputStreamOperator<WindowedValue<InputT>, WindowedValue<OutputT>>, + TwoInputStreamOperator<WindowedValue<InputT>, RawUnionValue, WindowedValue<OutputT>>, KeyGroupCheckpointedOperator, Triggerable<Object, TimerData> { protected DoFn<InputT, FnOutputT> doFn; @@ -662,7 +663,7 @@ public class DoFnOperator<InputT, FnOutputT, OutputT> * a Flink {@link Output}. */ interface OutputManagerFactory<OutputT> extends Serializable { - DoFnRunners.OutputManager create(Output<StreamRecord<OutputT>> output); + DoFnRunners.OutputManager create(Output<StreamRecord<WindowedValue<OutputT>>> output); } /** @@ -673,14 +674,15 @@ public class DoFnOperator<InputT, FnOutputT, OutputT> public static class DefaultOutputManagerFactory<OutputT> implements OutputManagerFactory<OutputT> { @Override - public DoFnRunners.OutputManager create(final Output<StreamRecord<OutputT>> output) { + public DoFnRunners.OutputManager create( + final Output<StreamRecord<WindowedValue<OutputT>>> output) { return new DoFnRunners.OutputManager() { @Override public <T> void output(TupleTag<T> tag, WindowedValue<T> value) { // with tagged outputs we can't get around this because we don't // know our own output type... @SuppressWarnings("unchecked") - OutputT castValue = (OutputT) value; + WindowedValue<OutputT> castValue = (WindowedValue<OutputT>) value; output.collect(new StreamRecord<>(castValue)); } }; @@ -692,22 +694,34 @@ public class DoFnOperator<InputT, FnOutputT, OutputT> * {@link DoFnRunners.OutputManager} that can write to multiple logical * outputs by unioning them in a {@link RawUnionValue}. */ - public static class MultiOutputOutputManagerFactory - implements OutputManagerFactory<RawUnionValue> { + public static class MultiOutputOutputManagerFactory<OutputT> + implements OutputManagerFactory<OutputT> { - Map<TupleTag<?>, Integer> mapping; + private TupleTag<?> mainTag; + Map<TupleTag<?>, OutputTag<WindowedValue<?>>> mapping; - public MultiOutputOutputManagerFactory(Map<TupleTag<?>, Integer> mapping) { + public MultiOutputOutputManagerFactory( + TupleTag<?> mainTag, + Map<TupleTag<?>, OutputTag<WindowedValue<?>>> mapping) { + this.mainTag = mainTag; this.mapping = mapping; } @Override - public DoFnRunners.OutputManager create(final Output<StreamRecord<RawUnionValue>> output) { + public DoFnRunners.OutputManager create( + final Output<StreamRecord<WindowedValue<OutputT>>> output) { return new DoFnRunners.OutputManager() { @Override public <T> void output(TupleTag<T> tag, WindowedValue<T> value) { - int intTag = mapping.get(tag); - output.collect(new StreamRecord<>(new RawUnionValue(intTag, value))); + if (tag.equals(mainTag)) { + @SuppressWarnings("unchecked") + WindowedValue<OutputT> outputValue = (WindowedValue<OutputT>) value; + output.collect(new StreamRecord<>(outputValue)); + } else { + @SuppressWarnings("unchecked") + OutputTag<WindowedValue<T>> outputTag = (OutputTag) mapping.get(tag); + output.<WindowedValue<T>>collect(outputTag, new StreamRecord<>(value)); + } } }; } http://git-wip-us.apache.org/repos/asf/beam/blob/b0601fd4/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperator.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperator.java index bf64ede..ea578b9 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperator.java @@ -46,7 +46,7 @@ import org.apache.flink.streaming.api.operators.InternalTimer; * Flink operator for executing window {@link DoFn DoFns}. */ public class WindowDoFnOperator<K, InputT, OutputT> - extends DoFnOperator<KeyedWorkItem<K, InputT>, KV<K, OutputT>, WindowedValue<KV<K, OutputT>>> { + extends DoFnOperator<KeyedWorkItem<K, InputT>, KV<K, OutputT>, KV<K, OutputT>> { private final SystemReduceFn<K, InputT, ?, OutputT, BoundedWindow> systemReduceFn; @@ -56,7 +56,7 @@ public class WindowDoFnOperator<K, InputT, OutputT> Coder<WindowedValue<KeyedWorkItem<K, InputT>>> inputCoder, TupleTag<KV<K, OutputT>> mainOutputTag, List<TupleTag<?>> additionalOutputTags, - OutputManagerFactory<WindowedValue<KV<K, OutputT>>> outputManagerFactory, + OutputManagerFactory<KV<K, OutputT>> outputManagerFactory, WindowingStrategy<?, ?> windowingStrategy, Map<Integer, PCollectionView<?>> sideInputTagMapping, Collection<PCollectionView<?>> sideInputs, http://git-wip-us.apache.org/repos/asf/beam/blob/b0601fd4/runners/flink/src/test/java/org/apache/beam/runners/flink/PipelineOptionsTest.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/PipelineOptionsTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/PipelineOptionsTest.java index 8382a2d..bc0b1c2 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/PipelineOptionsTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/PipelineOptionsTest.java @@ -173,13 +173,12 @@ public class PipelineOptionsTest { final byte[] serialized = SerializationUtils.serialize(doFnOperator); @SuppressWarnings("unchecked") - DoFnOperator<Object, Object, Object> deserialized = - (DoFnOperator<Object, Object, Object>) SerializationUtils.deserialize(serialized); + DoFnOperator<Object, Object, Object> deserialized = SerializationUtils.deserialize(serialized); TypeInformation<WindowedValue<Object>> typeInformation = TypeInformation.of( new TypeHint<WindowedValue<Object>>() {}); - OneInputStreamOperatorTestHarness<WindowedValue<Object>, Object> testHarness = + OneInputStreamOperatorTestHarness<WindowedValue<Object>, WindowedValue<Object>> testHarness = new OneInputStreamOperatorTestHarness<>(deserialized, typeInformation.createSerializer(new ExecutionConfig())); http://git-wip-us.apache.org/repos/asf/beam/blob/b0601fd4/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/DoFnOperatorTest.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/DoFnOperatorTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/DoFnOperatorTest.java index 79bc0e0..132242e 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/DoFnOperatorTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/DoFnOperatorTest.java @@ -65,6 +65,7 @@ import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness; import org.apache.flink.streaming.util.KeyedTwoInputStreamOperatorTestHarness; import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; import org.apache.flink.streaming.util.TwoInputStreamOperatorTestHarness; +import org.apache.flink.util.OutputTag; import org.joda.time.Duration; import org.joda.time.Instant; import org.junit.Test; @@ -123,7 +124,7 @@ public class DoFnOperatorTest { PipelineOptionsFactory.as(FlinkPipelineOptions.class), null); - OneInputStreamOperatorTestHarness<WindowedValue<String>, String> testHarness = + OneInputStreamOperatorTestHarness<WindowedValue<String>, WindowedValue<String>> testHarness = new OneInputStreamOperatorTestHarness<>(doFnOperator); testHarness.open(); @@ -147,26 +148,27 @@ public class DoFnOperatorTest { TupleTag<String> mainOutput = new TupleTag<>("main-output"); TupleTag<String> additionalOutput1 = new TupleTag<>("output-1"); TupleTag<String> additionalOutput2 = new TupleTag<>("output-2"); - ImmutableMap<TupleTag<?>, Integer> outputMapping = ImmutableMap.<TupleTag<?>, Integer>builder() - .put(mainOutput, 1) - .put(additionalOutput1, 2) - .put(additionalOutput2, 3) - .build(); + ImmutableMap<TupleTag<?>, OutputTag<?>> outputMapping = + ImmutableMap.<TupleTag<?>, OutputTag<?>>builder() + .put(mainOutput, new OutputTag<String>(mainOutput.getId()){}) + .put(additionalOutput1, new OutputTag<String>(additionalOutput1.getId()){}) + .put(additionalOutput2, new OutputTag<String>(additionalOutput2.getId()){}) + .build(); - DoFnOperator<String, String, RawUnionValue> doFnOperator = new DoFnOperator<>( + DoFnOperator<String, String, String> doFnOperator = new DoFnOperator<>( new MultiOutputDoFn(additionalOutput1, additionalOutput2), "stepName", windowedValueCoder, mainOutput, ImmutableList.<TupleTag<?>>of(additionalOutput1, additionalOutput2), - new DoFnOperator.MultiOutputOutputManagerFactory(outputMapping), + new DoFnOperator.MultiOutputOutputManagerFactory(mainOutput, outputMapping), WindowingStrategy.globalDefault(), new HashMap<Integer, PCollectionView<?>>(), /* side-input mapping */ Collections.<PCollectionView<?>>emptyList(), /* side inputs */ PipelineOptionsFactory.as(FlinkPipelineOptions.class), null); - OneInputStreamOperatorTestHarness<WindowedValue<String>, RawUnionValue> testHarness = + OneInputStreamOperatorTestHarness<WindowedValue<String>, WindowedValue<String>> testHarness = new OneInputStreamOperatorTestHarness<>(doFnOperator); testHarness.open(); @@ -176,17 +178,26 @@ public class DoFnOperatorTest { testHarness.processElement(new StreamRecord<>(WindowedValue.valueInGlobalWindow("hello"))); assertThat( - this.stripStreamRecordFromRawUnion(testHarness.getOutput()), + this.stripStreamRecord(testHarness.getOutput()), + contains( + WindowedValue.valueInGlobalWindow("got: hello"))); + + assertThat( + this.stripStreamRecord(testHarness.getSideOutput(outputMapping.get(additionalOutput1))), contains( - new RawUnionValue(2, WindowedValue.valueInGlobalWindow("extra: one")), - new RawUnionValue(3, WindowedValue.valueInGlobalWindow("extra: two")), - new RawUnionValue(1, WindowedValue.valueInGlobalWindow("got: hello")), - new RawUnionValue(2, WindowedValue.valueInGlobalWindow("got: hello")), - new RawUnionValue(3, WindowedValue.valueInGlobalWindow("got: hello")))); + WindowedValue.valueInGlobalWindow("extra: one"), + WindowedValue.valueInGlobalWindow("got: hello"))); + + assertThat( + this.stripStreamRecord(testHarness.getSideOutput(outputMapping.get(additionalOutput2))), + contains( + WindowedValue.valueInGlobalWindow("extra: two"), + WindowedValue.valueInGlobalWindow("got: hello"))); testHarness.close(); } + @Test public void testLateDroppingForStatefulFn() throws Exception { @@ -212,13 +223,13 @@ public class DoFnOperatorTest { TupleTag<String> outputTag = new TupleTag<>("main-output"); - DoFnOperator<Integer, String, WindowedValue<String>> doFnOperator = new DoFnOperator<>( + DoFnOperator<Integer, String, String> doFnOperator = new DoFnOperator<>( fn, "stepName", windowedValueCoder, outputTag, Collections.<TupleTag<?>>emptyList(), - new DoFnOperator.DefaultOutputManagerFactory<WindowedValue<String>>(), + new DoFnOperator.DefaultOutputManagerFactory<String>(), windowingStrategy, new HashMap<Integer, PCollectionView<?>>(), /* side-input mapping */ Collections.<PCollectionView<?>>emptyList(), /* side inputs */ @@ -325,14 +336,14 @@ public class DoFnOperatorTest { TupleTag<KV<String, Integer>> outputTag = new TupleTag<>("main-output"); DoFnOperator< - KV<String, Integer>, KV<String, Integer>, WindowedValue<KV<String, Integer>>> doFnOperator = + KV<String, Integer>, KV<String, Integer>, KV<String, Integer>> doFnOperator = new DoFnOperator<>( fn, "stepName", windowedValueCoder, outputTag, Collections.<TupleTag<?>>emptyList(), - new DoFnOperator.DefaultOutputManagerFactory<WindowedValue<KV<String, Integer>>>(), + new DoFnOperator.DefaultOutputManagerFactory<KV<String, Integer>>(), windowingStrategy, new HashMap<Integer, PCollectionView<?>>(), /* side-input mapping */ Collections.<PCollectionView<?>>emptyList(), /* side inputs */ @@ -435,8 +446,8 @@ public class DoFnOperatorTest { PipelineOptionsFactory.as(FlinkPipelineOptions.class), keyCoder); - TwoInputStreamOperatorTestHarness<WindowedValue<String>, RawUnionValue, String> testHarness = - new TwoInputStreamOperatorTestHarness<>(doFnOperator); + TwoInputStreamOperatorTestHarness<WindowedValue<String>, RawUnionValue, WindowedValue<String>> + testHarness = new TwoInputStreamOperatorTestHarness<>(doFnOperator); if (keyed) { // we use a dummy key for the second input since it is considered to be broadcast @@ -527,19 +538,19 @@ public class DoFnOperatorTest { }); } - private Iterable<RawUnionValue> stripStreamRecordFromRawUnion(Iterable<Object> input) { + private Iterable<WindowedValue<String>> stripStreamRecord(Iterable<?> input) { return FluentIterable.from(input).filter(new Predicate<Object>() { @Override public boolean apply(@Nullable Object o) { - return o instanceof StreamRecord && ((StreamRecord) o).getValue() instanceof RawUnionValue; + return o instanceof StreamRecord; } - }).transform(new Function<Object, RawUnionValue>() { + }).transform(new Function<Object, WindowedValue<String>>() { @Nullable @Override @SuppressWarnings({"unchecked", "rawtypes"}) - public RawUnionValue apply(@Nullable Object o) { - if (o instanceof StreamRecord && ((StreamRecord) o).getValue() instanceof RawUnionValue) { - return (RawUnionValue) ((StreamRecord) o).getValue(); + public WindowedValue<String> apply(@Nullable Object o) { + if (o instanceof StreamRecord) { + return (WindowedValue<String>) ((StreamRecord) o).getValue(); } throw new RuntimeException("unreachable"); }
