BEAM-261 Add support for ParDo.BoundMulti
Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/047cff49 Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/047cff49 Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/047cff49 Branch: refs/heads/master Commit: 047cff492f1f804785dee73b4768293d3569e8de Parents: 0975494 Author: Thomas Weise <t...@apache.org> Authored: Thu Oct 6 22:36:01 2016 -0700 Committer: Thomas Weise <t...@apache.org> Committed: Sun Oct 16 23:27:15 2016 -0700 ---------------------------------------------------------------------- runners/apex/pom.xml | 3 +- .../runners/apex/ApexPipelineTranslator.java | 2 + .../apache/beam/runners/apex/ApexRunner.java | 3 +- .../FlattenPCollectionTranslator.java | 1 + .../translators/ParDoBoundMultiTranslator.java | 74 ++++++++++++++++++++ .../apex/translators/ParDoBoundTranslator.java | 5 +- .../apex/translators/TranslationContext.java | 17 +++++ .../functions/ApexFlattenOperator.java | 2 + .../functions/ApexParDoOperator.java | 68 ++++++++++++++---- .../FlattenPCollectionTranslatorTest.java | 42 +++++------ .../translators/ParDoBoundTranslatorTest.java | 29 ++++---- 11 files changed, 194 insertions(+), 52 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/047cff49/runners/apex/pom.xml ---------------------------------------------------------------------- diff --git a/runners/apex/pom.xml b/runners/apex/pom.xml index e9377b4..929feb4 100644 --- a/runners/apex/pom.xml +++ b/runners/apex/pom.xml @@ -185,8 +185,7 @@ <systemPropertyVariables> <beamTestPipelineOptions> [ - "--runner=org.apache.beam.runners.apex.TestApexRunner", - "--streaming=true" + "--runner=org.apache.beam.runners.apex.TestApexRunner" ] </beamTestPipelineOptions> </systemPropertyVariables> http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/047cff49/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexPipelineTranslator.java ---------------------------------------------------------------------- diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexPipelineTranslator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexPipelineTranslator.java index ad8c283..40edfb1 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexPipelineTranslator.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexPipelineTranslator.java @@ -22,6 +22,7 @@ import org.apache.beam.runners.apex.ApexRunner.CreateApexPCollectionView; import org.apache.beam.runners.apex.translators.CreateValuesTranslator; import org.apache.beam.runners.apex.translators.FlattenPCollectionTranslator; import org.apache.beam.runners.apex.translators.GroupByKeyTranslator; +import org.apache.beam.runners.apex.translators.ParDoBoundMultiTranslator; import org.apache.beam.runners.apex.translators.ParDoBoundTranslator; import org.apache.beam.runners.apex.translators.ReadUnboundedTranslator; import org.apache.beam.runners.apex.translators.TransformTranslator; @@ -66,6 +67,7 @@ public class ApexPipelineTranslator implements Pipeline.PipelineVisitor { static { // register TransformTranslators registerTransformTranslator(ParDo.Bound.class, new ParDoBoundTranslator()); + registerTransformTranslator(ParDo.BoundMulti.class, new ParDoBoundMultiTranslator<>()); registerTransformTranslator(Read.Unbounded.class, new ReadUnboundedTranslator()); registerTransformTranslator(Read.Bounded.class, new ReadBoundedTranslator()); registerTransformTranslator(GroupByKey.class, new GroupByKeyTranslator()); http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/047cff49/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunner.java ---------------------------------------------------------------------- diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunner.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunner.java index ae79a20..e2ebc29 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunner.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunner.java @@ -230,7 +230,7 @@ public class ApexRunner extends PipelineRunner<ApexRunnerResult> { * Records that the {@link PTransform} requires a deterministic key coder. */ private void recordViewUsesNonDeterministicKeyCoder(PTransform<?, ?> ptransform) { - throw new UnsupportedOperationException(); + //throw new UnsupportedOperationException(); } /** @@ -369,7 +369,6 @@ public class ApexRunner extends PipelineRunner<ApexRunnerResult> { private final ApexRunner runner; - @SuppressWarnings("unused") // used via reflection in FlinkRunner#apply() public StreamingViewAsMap(ApexRunner runner, View.AsMap<K, V> transform) { this.runner = runner; } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/047cff49/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/FlattenPCollectionTranslator.java ---------------------------------------------------------------------- diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/FlattenPCollectionTranslator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/FlattenPCollectionTranslator.java index 712466a..90ab81f 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/FlattenPCollectionTranslator.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/FlattenPCollectionTranslator.java @@ -80,6 +80,7 @@ public class FlattenPCollectionTranslator<T> implements if (firstCollection != null) { // push to next merge level remainingCollections.add(firstCollection); + firstCollection = null; } if (remainingCollections.size() > 1) { collections = remainingCollections; http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/047cff49/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundMultiTranslator.java ---------------------------------------------------------------------- diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundMultiTranslator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundMultiTranslator.java new file mode 100644 index 0000000..6488bf6 --- /dev/null +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundMultiTranslator.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.apex.translators; + +import java.util.List; +import java.util.Map; + +import org.apache.beam.runners.apex.translators.functions.ApexParDoOperator; +import org.apache.beam.sdk.transforms.OldDoFn; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TupleTag; + +import com.datatorrent.api.Operator; +import com.datatorrent.api.Operator.OutputPort; +import com.google.common.collect.Maps; + +/** + * {@link ParDo.BoundMulti} is translated to Apex operator that wraps the {@link DoFn} + */ +public class ParDoBoundMultiTranslator<InputT, OutputT> implements TransformTranslator<ParDo.BoundMulti<InputT, OutputT>> { + private static final long serialVersionUID = 1L; + + @Override + public void translate(ParDo.BoundMulti<InputT, OutputT> transform, TranslationContext context) { + OldDoFn<InputT, OutputT> doFn = transform.getFn(); + PCollectionTuple output = context.getOutput(); + List<PCollectionView<?>> sideInputs = transform.getSideInputs(); + ApexParDoOperator<InputT, OutputT> operator = new ApexParDoOperator<>(context.getPipelineOptions(), + doFn, transform.getMainOutputTag(), transform.getSideOutputTags().getAll(), + context.<PCollection<?>>getInput().getWindowingStrategy(), sideInputs); + + Map<TupleTag<?>, PCollection<?>> outputs = output.getAll(); + Map<PCollection<?>, OutputPort<?>> ports = Maps.newHashMapWithExpectedSize(outputs.size()); + int i = 0; + for (Map.Entry<TupleTag<?>, PCollection<?>> outputEntry : outputs.entrySet()) { + ports.put(outputEntry.getValue(), operator.sideOutputPorts[i++]); + } + context.addOperator(operator, ports); + + context.addStream(context.getInput(), operator.input); + if (!sideInputs.isEmpty()) { + Operator.InputPort<?>[] sideInputPorts = {operator.sideInput1}; + for (i=0; i<sideInputs.size(); i++) { + // the number of input ports for side inputs are fixed and each port can only take one input. + // more (optional) ports can be added to give reasonable capacity or an explicit union operation introduced. + if (i == sideInputPorts.length) { + String msg = String.format("Too many side inputs in %s (currently only supporting %s).", + transform.toString(), sideInputPorts.length); + throw new UnsupportedOperationException(msg); + } + context.addStream(context.getViewInput(sideInputs.get(i)), sideInputPorts[i]); + } + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/047cff49/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundTranslator.java ---------------------------------------------------------------------- diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundTranslator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundTranslator.java index 632829a..fa3df7c 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundTranslator.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundTranslator.java @@ -25,6 +25,8 @@ import org.apache.beam.sdk.transforms.OldDoFn; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TupleTagList; import com.datatorrent.api.Operator; @@ -41,7 +43,8 @@ public class ParDoBoundTranslator<InputT, OutputT> implements PCollection<OutputT> output = context.getOutput(); List<PCollectionView<?>> sideInputs = transform.getSideInputs(); ApexParDoOperator<InputT, OutputT> operator = new ApexParDoOperator<>(context.getPipelineOptions(), - doFn, output.getWindowingStrategy(), sideInputs); + doFn, new TupleTag<OutputT>(), TupleTagList.empty().getAll() /*sideOutputTags*/, + output.getWindowingStrategy(), sideInputs); context.addOperator(operator, operator.output); context.addStream(context.getInput(), operator.input); if (!sideInputs.isEmpty()) { http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/047cff49/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/TranslationContext.java ---------------------------------------------------------------------- diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/TranslationContext.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/TranslationContext.java index 163cfd4..bd44a20 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/TranslationContext.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/TranslationContext.java @@ -98,6 +98,23 @@ public class TranslationContext { } /** + * Register operator and output ports for the given collections. + * @param operator + * @param ports + */ + public void addOperator(Operator operator, Map<PCollection<?>, OutputPort<?>> ports) { + boolean first = true; + for (Map.Entry<PCollection<?>, OutputPort<?>> portEntry : ports.entrySet()) { + if (first) { + addOperator(operator, portEntry.getValue(), portEntry.getKey()); + first = false; + } else { + this.streams.put(portEntry.getKey(), (Pair)new ImmutablePair<>(portEntry.getValue(), new ArrayList<>())); + } + } + } + + /** * Add intermediate operator for the current transformation. * @param operator * @param port http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/047cff49/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexFlattenOperator.java ---------------------------------------------------------------------- diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexFlattenOperator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexFlattenOperator.java index ce27abb..4675244 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexFlattenOperator.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexFlattenOperator.java @@ -26,6 +26,7 @@ import org.slf4j.LoggerFactory; import com.datatorrent.api.DefaultInputPort; import com.datatorrent.api.DefaultOutputPort; +import com.datatorrent.api.annotation.OutputPortFieldAnnotation; import com.datatorrent.common.util.BaseOperator; /** @@ -109,5 +110,6 @@ public class ApexFlattenOperator<InputT> extends BaseOperator /** * Output port. */ + @OutputPortFieldAnnotation(optional=true) public final transient DefaultOutputPort<ApexStreamTuple<WindowedValue<InputT>>> out = new DefaultOutputPort<ApexStreamTuple<WindowedValue<InputT>>>(); } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/047cff49/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexParDoOperator.java ---------------------------------------------------------------------- diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexParDoOperator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexParDoOperator.java index 13a8fc9..995fee1 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexParDoOperator.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexParDoOperator.java @@ -15,11 +15,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.beam.runners.apex.translators.functions; import java.util.ArrayList; import java.util.List; +import java.util.Map; import org.apache.beam.runners.apex.ApexPipelineOptions; import org.apache.beam.runners.apex.ApexRunner; @@ -47,7 +47,6 @@ import org.apache.beam.sdk.util.state.InMemoryStateInternals; import org.apache.beam.sdk.util.state.StateInternals; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TupleTag; -import org.apache.beam.sdk.values.TupleTagList; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -59,6 +58,7 @@ import com.datatorrent.api.annotation.OutputPortFieldAnnotation; import com.datatorrent.common.util.BaseOperator; import com.esotericsoftware.kryo.serializers.FieldSerializer.Bind; import com.google.common.collect.Iterables; +import com.google.common.collect.Maps; import com.esotericsoftware.kryo.serializers.JavaSerializer; /** @@ -68,43 +68,58 @@ public class ApexParDoOperator<InputT, OutputT> extends BaseOperator implements private static final Logger LOG = LoggerFactory.getLogger(ApexParDoOperator.class); private boolean traceTuples = true; - private transient final TupleTag<OutputT> mainTag = new TupleTag<OutputT>(); - private transient PushbackSideInputDoFnRunner<InputT, OutputT> pushbackDoFnRunner; - @Bind(JavaSerializer.class) private final SerializablePipelineOptions pipelineOptions; @Bind(JavaSerializer.class) private final OldDoFn<InputT, OutputT> doFn; @Bind(JavaSerializer.class) + private final TupleTag<OutputT> mainOutputTag; + @Bind(JavaSerializer.class) + private final List<TupleTag<?>> sideOutputTags; + @Bind(JavaSerializer.class) private final WindowingStrategy<?, ?> windowingStrategy; @Bind(JavaSerializer.class) - List<PCollectionView<?>> sideInputs; + private final List<PCollectionView<?>> sideInputs; + // TODO: not Kryo serializable, integrate codec //@Bind(JavaSerializer.class) private transient StateInternals<Void> sideInputStateInternals = InMemoryStateInternals.forKey(null); - private transient SideInputHandler sideInputHandler; // TODO: not Kryo serializable, integrate codec private List<WindowedValue<InputT>> pushedBack = new ArrayList<>(); private LongMin pushedBackWatermark = new LongMin(); private long currentInputWatermark = Long.MIN_VALUE; private long currentOutputWatermark = currentInputWatermark; + private transient PushbackSideInputDoFnRunner<InputT, OutputT> pushbackDoFnRunner; + private transient SideInputHandler sideInputHandler; + private transient Map<TupleTag<?>, DefaultOutputPort<ApexStreamTuple<?>>> sideOutputPortMapping = Maps.newHashMapWithExpectedSize(5); + public ApexParDoOperator( ApexPipelineOptions pipelineOptions, OldDoFn<InputT, OutputT> doFn, + TupleTag<OutputT> mainOutputTag, + List<TupleTag<?>> sideOutputTags, WindowingStrategy<?, ?> windowingStrategy, List<PCollectionView<?>> sideInputs ) { this.pipelineOptions = new SerializablePipelineOptions(pipelineOptions); this.doFn = doFn; + this.mainOutputTag = mainOutputTag; + this.sideOutputTags = sideOutputTags; this.windowingStrategy = windowingStrategy; this.sideInputs = sideInputs; + + if (sideOutputTags != null && sideOutputTags.size() > sideOutputPorts.length) { + String msg = String.format("Too many side outputs (currently only supporting %s).", + sideOutputPorts.length); + throw new UnsupportedOperationException(msg); + } } @SuppressWarnings("unused") // for Kryo private ApexParDoOperator() { - this(null, null, null, null); + this(null, null, null, null, null, null); } @@ -167,10 +182,28 @@ private transient StateInternals<Void> sideInputStateInternals = InMemoryStateIn @OutputPortFieldAnnotation(optional=true) public final transient DefaultOutputPort<ApexStreamTuple<?>> output = new DefaultOutputPort<>(); + @OutputPortFieldAnnotation(optional=true) + public final transient DefaultOutputPort<ApexStreamTuple<?>> sideOutput1 = new DefaultOutputPort<>(); + @OutputPortFieldAnnotation(optional=true) + public final transient DefaultOutputPort<ApexStreamTuple<?>> sideOutput2 = new DefaultOutputPort<>(); + @OutputPortFieldAnnotation(optional=true) + public final transient DefaultOutputPort<ApexStreamTuple<?>> sideOutput3 = new DefaultOutputPort<>(); + @OutputPortFieldAnnotation(optional=true) + public final transient DefaultOutputPort<ApexStreamTuple<?>> sideOutput4 = new DefaultOutputPort<>(); + @OutputPortFieldAnnotation(optional=true) + public final transient DefaultOutputPort<ApexStreamTuple<?>> sideOutput5 = new DefaultOutputPort<>(); + + public final transient DefaultOutputPort<?>[] sideOutputPorts = {sideOutput1, sideOutput2, sideOutput3, sideOutput4, sideOutput5}; + @Override public <T> void output(TupleTag<T> tag, WindowedValue<T> tuple) { - output.emit(ApexStreamTuple.DataTuple.of(tuple)); + DefaultOutputPort<ApexStreamTuple<?>> sideOutputPort = sideOutputPortMapping.get(tag); + if (sideOutputPort != null) { + sideOutputPort.emit(ApexStreamTuple.DataTuple.of(tuple)); + } else { + output.emit(ApexStreamTuple.DataTuple.of(tuple)); + } if (traceTuples) { LOG.debug("\nemitting {}\n", tuple); } @@ -178,7 +211,10 @@ private transient StateInternals<Void> sideInputStateInternals = InMemoryStateIn private Iterable<WindowedValue<InputT>> processElementInReadyWindows(WindowedValue<InputT> elem) { try { - return pushbackDoFnRunner.processElementInReadyWindows(elem); + pushbackDoFnRunner.startBundle(); + Iterable<WindowedValue<InputT>> pushedBack = pushbackDoFnRunner.processElementInReadyWindows(elem); + pushbackDoFnRunner.finishBundle(); + return pushedBack; } catch (UserCodeException ue) { if (ue.getCause() instanceof AssertionError) { ApexRunner.assertionError = (AssertionError)ue.getCause(); @@ -220,13 +256,19 @@ private transient StateInternals<Void> sideInputStateInternals = InMemoryStateIn sideInputReader = sideInputHandler; } + for (int i=0; i < sideOutputTags.size(); i++) { + @SuppressWarnings("unchecked") + DefaultOutputPort<ApexStreamTuple<?>> port = (DefaultOutputPort<ApexStreamTuple<?>>)sideOutputPorts[i]; + sideOutputPortMapping.put(sideOutputTags.get(i), port); + } + DoFnRunner<InputT, OutputT> doFnRunner = DoFnRunners.createDefault( pipelineOptions.get(), doFn, sideInputReader, this, - mainTag, - TupleTagList.empty().getAll() /*sideOutputTags*/, + mainOutputTag, + sideOutputTags, new NoOpStepContext(), new NoOpAggregatorFactory(), windowingStrategy @@ -246,7 +288,6 @@ private transient StateInternals<Void> sideInputStateInternals = InMemoryStateIn @Override public void beginWindow(long windowId) { - pushbackDoFnRunner.startBundle(); /* Collection<Aggregator<?, ?>> aggregators = AggregatorRetriever.getAggregators(doFn); if (!aggregators.isEmpty()) { @@ -258,7 +299,6 @@ private transient StateInternals<Void> sideInputStateInternals = InMemoryStateIn @Override public void endWindow() { - pushbackDoFnRunner.finishBundle(); } /** http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/047cff49/runners/apex/src/test/java/org/apache/beam/runners/apex/translators/FlattenPCollectionTranslatorTest.java ---------------------------------------------------------------------- diff --git a/runners/apex/src/test/java/org/apache/beam/runners/apex/translators/FlattenPCollectionTranslatorTest.java b/runners/apex/src/test/java/org/apache/beam/runners/apex/translators/FlattenPCollectionTranslatorTest.java index d3b56bc..6b181ba 100644 --- a/runners/apex/src/test/java/org/apache/beam/runners/apex/translators/FlattenPCollectionTranslatorTest.java +++ b/runners/apex/src/test/java/org/apache/beam/runners/apex/translators/FlattenPCollectionTranslatorTest.java @@ -31,15 +31,17 @@ import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionList; -import com.google.common.collect.Lists; import com.google.common.collect.Sets; +import org.junit.Assert; import org.junit.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.util.HashSet; +import java.util.ArrayList; +import java.util.Arrays; import java.util.List; +import java.util.Set; /** * integration test for {@link FlattenPCollectionTranslator}. @@ -49,41 +51,41 @@ public class FlattenPCollectionTranslatorTest { @Test public void test() throws Exception { - ApexPipelineOptions options = - PipelineOptionsFactory.as(ApexPipelineOptions.class); + ApexPipelineOptions options = PipelineOptionsFactory.as(ApexPipelineOptions.class); options.setApplicationName("FlattenPCollection"); options.setRunner(ApexRunner.class); Pipeline p = Pipeline.create(options); - List<String> collection1 = Lists.newArrayList("1", "2", "3"); - List<String> collection2 = Lists.newArrayList("4", "5"); - List<String> expected = Lists.newArrayList("1", "2", "3", "4", "5"); - PCollection<String> pc1 = - p.apply(Create.of(collection1).withCoder(StringUtf8Coder.of())); - PCollection<String> pc2 = - p.apply(Create.of(collection2).withCoder(StringUtf8Coder.of())); - PCollectionList<String> pcs = PCollectionList.of(pc1).and(pc2); - PCollection<String> actual = pcs.apply(Flatten.<String>pCollections()); + String[][] collections = { + {"1"}, {"2"}, {"3"}, {"4"}, {"5"} + }; + + Set<String> expected = Sets.newHashSet(); + List<PCollection<String>> pcList = new ArrayList<PCollection<String>>(); + for (String[] collection : collections) { + pcList.add(p.apply(Create.of(collection).withCoder(StringUtf8Coder.of()))); + expected.addAll(Arrays.asList(collection)); + } + + PCollection<String> actual = PCollectionList.of(pcList).apply(Flatten.<String>pCollections()); actual.apply(ParDo.of(new EmbeddedCollector())); ApexRunnerResult result = (ApexRunnerResult)p.run(); // TODO: verify translation result.getApexDAG(); long timeout = System.currentTimeMillis() + 30000; - while (System.currentTimeMillis() < timeout) { - if (EmbeddedCollector.results.containsAll(expected)) { - break; - } + while (System.currentTimeMillis() < timeout && EmbeddedCollector.results.size() < expected.size()) { LOG.info("Waiting for expected results."); - Thread.sleep(1000); + Thread.sleep(500); } - org.junit.Assert.assertEquals(Sets.newHashSet(expected), EmbeddedCollector.results); + Assert.assertEquals("number results", expected.size(), EmbeddedCollector.results.size()); + Assert.assertEquals(expected, Sets.newHashSet(EmbeddedCollector.results)); } @SuppressWarnings("serial") private static class EmbeddedCollector extends OldDoFn<Object, Void> { - protected static final HashSet<Object> results = new HashSet<>(); + protected static final ArrayList<Object> results = new ArrayList<>(); public EmbeddedCollector() { } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/047cff49/runners/apex/src/test/java/org/apache/beam/runners/apex/translators/ParDoBoundTranslatorTest.java ---------------------------------------------------------------------- diff --git a/runners/apex/src/test/java/org/apache/beam/runners/apex/translators/ParDoBoundTranslatorTest.java b/runners/apex/src/test/java/org/apache/beam/runners/apex/translators/ParDoBoundTranslatorTest.java index 6239021..301f6f8 100644 --- a/runners/apex/src/test/java/org/apache/beam/runners/apex/translators/ParDoBoundTranslatorTest.java +++ b/runners/apex/src/test/java/org/apache/beam/runners/apex/translators/ParDoBoundTranslatorTest.java @@ -36,6 +36,8 @@ import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.util.WindowingStrategy; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TupleTagList; import com.datatorrent.api.DAG; import com.datatorrent.lib.util.KryoCloneUtils; @@ -129,6 +131,18 @@ public class ParDoBoundTranslatorTest { } } + private static Throwable runExpectingAssertionFailure(Pipeline pipeline) { + // We cannot use thrown.expect(AssertionError.class) because the AssertionError + // is first caught by JUnit and causes a test failure. + try { + pipeline.run(); + } catch (AssertionError exc) { + return exc; + } + fail("assertion should have failed"); + throw new RuntimeException("unreachable"); + } + @Test public void testAssertionFailure() throws Exception { ApexPipelineOptions options = PipelineOptionsFactory.create() @@ -163,24 +177,13 @@ public class ParDoBoundTranslatorTest { pipeline.run(); } - private static Throwable runExpectingAssertionFailure(Pipeline pipeline) { - // We cannot use thrown.expect(AssertionError.class) because the AssertionError - // is first caught by JUnit and causes a test failure. - try { - pipeline.run(); - } catch (AssertionError exc) { - return exc; - } - fail("assertion should have failed"); - throw new RuntimeException("unreachable"); - } - @Test public void testSerialization() throws Exception { ApexPipelineOptions options = PipelineOptionsFactory.create() .as(ApexPipelineOptions.class); ApexParDoOperator<Integer, Integer> operator = new ApexParDoOperator<>(options, - new Add(0), WindowingStrategy.globalDefault(), Collections.<PCollectionView<?>> emptyList()); + new Add(0), new TupleTag<Integer>(), TupleTagList.empty().getAll(), + WindowingStrategy.globalDefault(), Collections.<PCollectionView<?>> emptyList()); operator.setup(null); operator.beginWindow(0); WindowedValue<Integer> wv = WindowedValue.valueInGlobalWindow(0);