Implement Single-Output ParDo as a composite

This reduces the number of primitive transforms in the Java SDK. There is no functional change for any pipeline as a result of this change.
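For readers skimming the diff: the composite expansion in question follows the same pattern as the ParDoSingleViaMulti override that this commit deletes from the direct runner, with the expansion presumably moving into ParDo itself via the ParDo.java change in the diffstat. A minimal sketch of that pattern, assuming the pre-2.0 Beam Java API used here; the class name ParDoSingleAsComposite is hypothetical and this is not the verbatim ParDo.java change:

// Hypothetical sketch: a single-output ParDo expressed as a composite that
// expands into the multi-output primitive and returns only the main output.
// Mirrors the ParDoSingleViaMulti#expand code removed below.
static class ParDoSingleAsComposite<InputT, OutputT>
    extends PTransform<PCollection<? extends InputT>, PCollection<OutputT>> {
  private final DoFn<InputT, OutputT> fn;
  private final List<PCollectionView<?>> sideInputs;

  ParDoSingleAsComposite(DoFn<InputT, OutputT> fn, List<PCollectionView<?>> sideInputs) {
    this.fn = fn;
    this.sideInputs = sideInputs;
  }

  @Override
  public PCollection<OutputT> expand(PCollection<? extends InputT> input) {
    // Output tags for ParDo need only be unique up to the applied transform.
    TupleTag<OutputT> mainOutputTag = new TupleTag<OutputT>("main");

    // Delegate to the multi-output primitive with an empty side-output tag list.
    PCollectionTuple outputs =
        input.apply(
            ParDo.of(fn)
                .withSideInputs(sideInputs)
                .withOutputTags(mainOutputTag, TupleTagList.empty()));

    // Return only the main output, preserving the output type descriptor.
    PCollection<OutputT> output = outputs.get(mainOutputTag);
    output.setTypeDescriptor(fn.getOutputTypeDescriptor());
    return output;
  }
}

Because the expansion bottoms out in the BoundMulti primitive, runners need only translate the multi-output form; this is why the per-runner single-output translators below ("ParDo(Foo)" operators, now named "ParDo(Foo)/ParMultiDo(Foo)") can be deleted.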
Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/6253abaa Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/6253abaa Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/6253abaa Branch: refs/heads/master Commit: 6253abaac62979e8496a828c18c7d1aa7214be6a Parents: 7e9233b Author: Thomas Groh <[email protected]> Authored: Thu Mar 2 14:27:29 2017 -0800 Committer: Thomas Groh <[email protected]> Committed: Fri Mar 3 15:36:09 2017 -0800 ---------------------------------------------------------------------- .../translation/ApexPipelineTranslator.java | 3 +- .../translation/ParDoBoundMultiTranslator.java | 185 ---------- .../apex/translation/ParDoBoundTranslator.java | 95 ----- .../apex/translation/ParDoTranslator.java | 185 ++++++++++ .../FlattenPCollectionTranslatorTest.java | 3 +- .../translation/ParDoBoundTranslatorTest.java | 344 ------------------- .../apex/translation/ParDoTranslatorTest.java | 344 +++++++++++++++++++ .../beam/runners/direct/DirectRunner.java | 18 +- .../ParDoSingleViaMultiOverrideFactory.java | 70 ---- .../ParDoSingleViaMultiOverrideFactoryTest.java | 46 --- .../flink/FlinkBatchTransformTranslators.java | 78 +---- .../FlinkStreamingTransformTranslators.java | 113 +----- .../dataflow/DataflowPipelineTranslator.java | 29 -- .../DataflowPipelineTranslatorTest.java | 7 +- .../spark/translation/TransformTranslator.java | 100 +++--- .../streaming/StreamingTransformTranslator.java | 115 ++++--- .../streaming/TrackStreamingSourcesTest.java | 4 +- .../org/apache/beam/sdk/transforms/ParDo.java | 8 +- 18 files changed, 668 insertions(+), 1079 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/6253abaa/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ApexPipelineTranslator.java ---------------------------------------------------------------------- diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ApexPipelineTranslator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ApexPipelineTranslator.java index 951a286..7eb9551 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ApexPipelineTranslator.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ApexPipelineTranslator.java @@ -59,8 +59,7 @@ public class ApexPipelineTranslator implements Pipeline.PipelineVisitor { static { // register TransformTranslators - registerTransformTranslator(ParDo.Bound.class, new ParDoBoundTranslator()); - registerTransformTranslator(ParDo.BoundMulti.class, new ParDoBoundMultiTranslator<>()); + registerTransformTranslator(ParDo.BoundMulti.class, new ParDoTranslator<>()); registerTransformTranslator(Read.Unbounded.class, new ReadUnboundedTranslator()); registerTransformTranslator(Read.Bounded.class, new ReadBoundedTranslator()); registerTransformTranslator(GroupByKey.class, new GroupByKeyTranslator()); http://git-wip-us.apache.org/repos/asf/beam/blob/6253abaa/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoBoundMultiTranslator.java ---------------------------------------------------------------------- diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoBoundMultiTranslator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoBoundMultiTranslator.java deleted file mode 100644 index f55b48c..0000000 --- 
a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoBoundMultiTranslator.java +++ /dev/null @@ -1,185 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.beam.runners.apex.translation; - -import static com.google.common.base.Preconditions.checkArgument; - -import com.datatorrent.api.Operator; -import com.datatorrent.api.Operator.OutputPort; -import com.google.common.collect.Maps; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import org.apache.beam.runners.apex.ApexRunner; -import org.apache.beam.runners.apex.translation.operators.ApexParDoOperator; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.transforms.reflect.DoFnSignature; -import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; -import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder; -import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PCollectionView; -import org.apache.beam.sdk.values.TaggedPValue; -import org.apache.beam.sdk.values.TupleTag; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * {@link ParDo.BoundMulti} is translated to {@link ApexParDoOperator} that wraps the {@link DoFn}. 
- */ -class ParDoBoundMultiTranslator<InputT, OutputT> - implements TransformTranslator<ParDo.BoundMulti<InputT, OutputT>> { - private static final long serialVersionUID = 1L; - private static final Logger LOG = LoggerFactory.getLogger(ParDoBoundMultiTranslator.class); - - @Override - public void translate(ParDo.BoundMulti<InputT, OutputT> transform, TranslationContext context) { - DoFn<InputT, OutputT> doFn = transform.getFn(); - DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass()); - - if (signature.processElement().isSplittable()) { - throw new UnsupportedOperationException( - String.format( - "%s does not support splittable DoFn: %s", ApexRunner.class.getSimpleName(), doFn)); - } - if (signature.stateDeclarations().size() > 0) { - throw new UnsupportedOperationException( - String.format( - "Found %s annotations on %s, but %s cannot yet be used with state in the %s.", - DoFn.StateId.class.getSimpleName(), - doFn.getClass().getName(), - DoFn.class.getSimpleName(), - ApexRunner.class.getSimpleName())); - } - - if (signature.timerDeclarations().size() > 0) { - throw new UnsupportedOperationException( - String.format( - "Found %s annotations on %s, but %s cannot yet be used with timers in the %s.", - DoFn.TimerId.class.getSimpleName(), - doFn.getClass().getName(), - DoFn.class.getSimpleName(), - ApexRunner.class.getSimpleName())); - } - - List<TaggedPValue> outputs = context.getOutputs(); - PCollection<InputT> input = (PCollection<InputT>) context.getInput(); - List<PCollectionView<?>> sideInputs = transform.getSideInputs(); - Coder<InputT> inputCoder = input.getCoder(); - WindowedValueCoder<InputT> wvInputCoder = - FullWindowedValueCoder.of( - inputCoder, input.getWindowingStrategy().getWindowFn().windowCoder()); - - ApexParDoOperator<InputT, OutputT> operator = - new ApexParDoOperator<>( - context.getPipelineOptions(), - doFn, - transform.getMainOutputTag(), - transform.getSideOutputTags().getAll(), - ((PCollection<InputT>) context.getInput()).getWindowingStrategy(), - sideInputs, - wvInputCoder, - context.<Void>stateInternalsFactory()); - - Map<PCollection<?>, OutputPort<?>> ports = Maps.newHashMapWithExpectedSize(outputs.size()); - for (TaggedPValue output : outputs) { - checkArgument( - output.getValue() instanceof PCollection, - "%s %s outputs non-PCollection %s of type %s", - ParDo.BoundMulti.class.getSimpleName(), - context.getFullName(), - output.getValue(), - output.getValue().getClass().getSimpleName()); - PCollection<?> pc = (PCollection<?>) output.getValue(); - if (output.getTag().equals(transform.getMainOutputTag())) { - ports.put(pc, operator.output); - } else { - int portIndex = 0; - for (TupleTag<?> tag : transform.getSideOutputTags().getAll()) { - if (tag.equals(output.getTag())) { - ports.put(pc, operator.sideOutputPorts[portIndex]); - break; - } - portIndex++; - } - } - } - context.addOperator(operator, ports); - context.addStream(context.getInput(), operator.input); - if (!sideInputs.isEmpty()) { - addSideInputs(operator, sideInputs, context); - } - } - - static void addSideInputs( - ApexParDoOperator<?, ?> operator, - List<PCollectionView<?>> sideInputs, - TranslationContext context) { - Operator.InputPort<?>[] sideInputPorts = {operator.sideInput1}; - if (sideInputs.size() > sideInputPorts.length) { - PCollection<?> unionCollection = unionSideInputs(sideInputs, context); - context.addStream(unionCollection, sideInputPorts[0]); - } else { - // the number of ports for side inputs is fixed and each port can only take one input. 
- for (int i = 0; i < sideInputs.size(); i++) { - context.addStream(context.getViewInput(sideInputs.get(i)), sideInputPorts[i]); - } - } - } - - private static PCollection<?> unionSideInputs( - List<PCollectionView<?>> sideInputs, TranslationContext context) { - checkArgument(sideInputs.size() > 1, "requires multiple side inputs"); - // flatten and assign union tag - List<PCollection<Object>> sourceCollections = new ArrayList<>(); - Map<PCollection<?>, Integer> unionTags = new HashMap<>(); - PCollection<Object> firstSideInput = context.getViewInput(sideInputs.get(0)); - for (int i = 0; i < sideInputs.size(); i++) { - PCollectionView<?> sideInput = sideInputs.get(i); - PCollection<?> sideInputCollection = context.getViewInput(sideInput); - if (!sideInputCollection - .getWindowingStrategy() - .equals(firstSideInput.getWindowingStrategy())) { - // TODO: check how to handle this in stream codec - //String msg = "Multiple side inputs with different window strategies."; - //throw new UnsupportedOperationException(msg); - LOG.warn( - "Side inputs union with different windowing strategies {} {}", - firstSideInput.getWindowingStrategy(), - sideInputCollection.getWindowingStrategy()); - } - if (!sideInputCollection.getCoder().equals(firstSideInput.getCoder())) { - String msg = "Multiple side inputs with different coders."; - throw new UnsupportedOperationException(msg); - } - sourceCollections.add(context.<PCollection<Object>>getViewInput(sideInput)); - unionTags.put(sideInputCollection, i); - } - - PCollection<Object> resultCollection = - FlattenPCollectionTranslator.intermediateCollection( - firstSideInput, firstSideInput.getCoder()); - FlattenPCollectionTranslator.flattenCollections( - sourceCollections, unionTags, resultCollection, context); - return resultCollection; - } -} http://git-wip-us.apache.org/repos/asf/beam/blob/6253abaa/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoBoundTranslator.java ---------------------------------------------------------------------- diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoBoundTranslator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoBoundTranslator.java deleted file mode 100644 index 5195809..0000000 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoBoundTranslator.java +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.beam.runners.apex.translation; - -import java.util.List; -import org.apache.beam.runners.apex.ApexRunner; -import org.apache.beam.runners.apex.translation.operators.ApexParDoOperator; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.transforms.reflect.DoFnSignature; -import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; -import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder; -import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PCollectionView; -import org.apache.beam.sdk.values.TupleTag; -import org.apache.beam.sdk.values.TupleTagList; - -/** {@link ParDo.Bound} is translated to {link ApexParDoOperator} that wraps the {@link DoFn}. */ -class ParDoBoundTranslator<InputT, OutputT> - implements TransformTranslator<ParDo.Bound<InputT, OutputT>> { - private static final long serialVersionUID = 1L; - - @Override - public void translate(ParDo.Bound<InputT, OutputT> transform, TranslationContext context) { - DoFn<InputT, OutputT> doFn = transform.getFn(); - DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass()); - - if (signature.processElement().isSplittable()) { - throw new UnsupportedOperationException( - String.format( - "%s does not support splittable DoFn: %s", ApexRunner.class.getSimpleName(), doFn)); - } - if (signature.stateDeclarations().size() > 0) { - throw new UnsupportedOperationException( - String.format( - "Found %s annotations on %s, but %s cannot yet be used with state in the %s.", - DoFn.StateId.class.getSimpleName(), - doFn.getClass().getName(), - DoFn.class.getSimpleName(), - ApexRunner.class.getSimpleName())); - } - - if (signature.timerDeclarations().size() > 0) { - throw new UnsupportedOperationException( - String.format( - "Found %s annotations on %s, but %s cannot yet be used with timers in the %s.", - DoFn.TimerId.class.getSimpleName(), - doFn.getClass().getName(), - DoFn.class.getSimpleName(), - ApexRunner.class.getSimpleName())); - } - - PCollection<OutputT> output = (PCollection<OutputT>) context.getOutput(); - PCollection<InputT> input = (PCollection<InputT>) context.getInput(); - List<PCollectionView<?>> sideInputs = transform.getSideInputs(); - Coder<InputT> inputCoder = input.getCoder(); - WindowedValueCoder<InputT> wvInputCoder = - FullWindowedValueCoder.of( - inputCoder, input.getWindowingStrategy().getWindowFn().windowCoder()); - - ApexParDoOperator<InputT, OutputT> operator = - new ApexParDoOperator<>( - context.getPipelineOptions(), - doFn, - new TupleTag<OutputT>(), - TupleTagList.empty().getAll() /*sideOutputTags*/, - output.getWindowingStrategy(), - sideInputs, - wvInputCoder, - context.<Void>stateInternalsFactory()); - context.addOperator(operator, operator.output); - context.addStream(context.getInput(), operator.input); - if (!sideInputs.isEmpty()) { - ParDoBoundMultiTranslator.addSideInputs(operator, sideInputs, context); - } - } -} http://git-wip-us.apache.org/repos/asf/beam/blob/6253abaa/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoTranslator.java ---------------------------------------------------------------------- diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoTranslator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoTranslator.java new file mode 100644 index 0000000..5ffc3c3 --- 
/dev/null +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoTranslator.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.apex.translation; + +import static com.google.common.base.Preconditions.checkArgument; + +import com.datatorrent.api.Operator; +import com.datatorrent.api.Operator.OutputPort; +import com.google.common.collect.Maps; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.beam.runners.apex.ApexRunner; +import org.apache.beam.runners.apex.translation.operators.ApexParDoOperator; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature; +import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; +import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder; +import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TaggedPValue; +import org.apache.beam.sdk.values.TupleTag; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * {@link ParDo.BoundMulti} is translated to {@link ApexParDoOperator} that wraps the {@link DoFn}. 
+ */ +class ParDoTranslator<InputT, OutputT> + implements TransformTranslator<ParDo.BoundMulti<InputT, OutputT>> { + private static final long serialVersionUID = 1L; + private static final Logger LOG = LoggerFactory.getLogger(ParDoTranslator.class); + + @Override + public void translate(ParDo.BoundMulti<InputT, OutputT> transform, TranslationContext context) { + DoFn<InputT, OutputT> doFn = transform.getFn(); + DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass()); + + if (signature.processElement().isSplittable()) { + throw new UnsupportedOperationException( + String.format( + "%s does not support splittable DoFn: %s", ApexRunner.class.getSimpleName(), doFn)); + } + if (signature.stateDeclarations().size() > 0) { + throw new UnsupportedOperationException( + String.format( + "Found %s annotations on %s, but %s cannot yet be used with state in the %s.", + DoFn.StateId.class.getSimpleName(), + doFn.getClass().getName(), + DoFn.class.getSimpleName(), + ApexRunner.class.getSimpleName())); + } + + if (signature.timerDeclarations().size() > 0) { + throw new UnsupportedOperationException( + String.format( + "Found %s annotations on %s, but %s cannot yet be used with timers in the %s.", + DoFn.TimerId.class.getSimpleName(), + doFn.getClass().getName(), + DoFn.class.getSimpleName(), + ApexRunner.class.getSimpleName())); + } + + List<TaggedPValue> outputs = context.getOutputs(); + PCollection<InputT> input = (PCollection<InputT>) context.getInput(); + List<PCollectionView<?>> sideInputs = transform.getSideInputs(); + Coder<InputT> inputCoder = input.getCoder(); + WindowedValueCoder<InputT> wvInputCoder = + FullWindowedValueCoder.of( + inputCoder, input.getWindowingStrategy().getWindowFn().windowCoder()); + + ApexParDoOperator<InputT, OutputT> operator = + new ApexParDoOperator<>( + context.getPipelineOptions(), + doFn, + transform.getMainOutputTag(), + transform.getSideOutputTags().getAll(), + ((PCollection<InputT>) context.getInput()).getWindowingStrategy(), + sideInputs, + wvInputCoder, + context.<Void>stateInternalsFactory()); + + Map<PCollection<?>, OutputPort<?>> ports = Maps.newHashMapWithExpectedSize(outputs.size()); + for (TaggedPValue output : outputs) { + checkArgument( + output.getValue() instanceof PCollection, + "%s %s outputs non-PCollection %s of type %s", + ParDo.BoundMulti.class.getSimpleName(), + context.getFullName(), + output.getValue(), + output.getValue().getClass().getSimpleName()); + PCollection<?> pc = (PCollection<?>) output.getValue(); + if (output.getTag().equals(transform.getMainOutputTag())) { + ports.put(pc, operator.output); + } else { + int portIndex = 0; + for (TupleTag<?> tag : transform.getSideOutputTags().getAll()) { + if (tag.equals(output.getTag())) { + ports.put(pc, operator.sideOutputPorts[portIndex]); + break; + } + portIndex++; + } + } + } + context.addOperator(operator, ports); + context.addStream(context.getInput(), operator.input); + if (!sideInputs.isEmpty()) { + addSideInputs(operator, sideInputs, context); + } + } + + static void addSideInputs( + ApexParDoOperator<?, ?> operator, + List<PCollectionView<?>> sideInputs, + TranslationContext context) { + Operator.InputPort<?>[] sideInputPorts = {operator.sideInput1}; + if (sideInputs.size() > sideInputPorts.length) { + PCollection<?> unionCollection = unionSideInputs(sideInputs, context); + context.addStream(unionCollection, sideInputPorts[0]); + } else { + // the number of ports for side inputs is fixed and each port can only take one input. 
+ for (int i = 0; i < sideInputs.size(); i++) { + context.addStream(context.getViewInput(sideInputs.get(i)), sideInputPorts[i]); + } + } + } + + private static PCollection<?> unionSideInputs( + List<PCollectionView<?>> sideInputs, TranslationContext context) { + checkArgument(sideInputs.size() > 1, "requires multiple side inputs"); + // flatten and assign union tag + List<PCollection<Object>> sourceCollections = new ArrayList<>(); + Map<PCollection<?>, Integer> unionTags = new HashMap<>(); + PCollection<Object> firstSideInput = context.getViewInput(sideInputs.get(0)); + for (int i = 0; i < sideInputs.size(); i++) { + PCollectionView<?> sideInput = sideInputs.get(i); + PCollection<?> sideInputCollection = context.getViewInput(sideInput); + if (!sideInputCollection + .getWindowingStrategy() + .equals(firstSideInput.getWindowingStrategy())) { + // TODO: check how to handle this in stream codec + //String msg = "Multiple side inputs with different window strategies."; + //throw new UnsupportedOperationException(msg); + LOG.warn( + "Side inputs union with different windowing strategies {} {}", + firstSideInput.getWindowingStrategy(), + sideInputCollection.getWindowingStrategy()); + } + if (!sideInputCollection.getCoder().equals(firstSideInput.getCoder())) { + String msg = "Multiple side inputs with different coders."; + throw new UnsupportedOperationException(msg); + } + sourceCollections.add(context.<PCollection<Object>>getViewInput(sideInput)); + unionTags.put(sideInputCollection, i); + } + + PCollection<Object> resultCollection = + FlattenPCollectionTranslator.intermediateCollection( + firstSideInput, firstSideInput.getCoder()); + FlattenPCollectionTranslator.flattenCollections( + sourceCollections, unionTags, resultCollection, context); + return resultCollection; + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/6253abaa/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/FlattenPCollectionTranslatorTest.java ---------------------------------------------------------------------- diff --git a/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/FlattenPCollectionTranslatorTest.java b/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/FlattenPCollectionTranslatorTest.java index b2e29b6..64ca0ee 100644 --- a/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/FlattenPCollectionTranslatorTest.java +++ b/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/FlattenPCollectionTranslatorTest.java @@ -110,7 +110,8 @@ public class FlattenPCollectionTranslatorTest { PCollectionList.of(single).apply(Flatten.<String>pCollections()) .apply(ParDo.of(new EmbeddedCollector())); translator.translate(p, dag); - Assert.assertNotNull(dag.getOperatorMeta("ParDo(EmbeddedCollector)")); + Assert.assertNotNull( + dag.getOperatorMeta("ParDo(EmbeddedCollector)/ParMultiDo(EmbeddedCollector)")); } } http://git-wip-us.apache.org/repos/asf/beam/blob/6253abaa/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ParDoBoundTranslatorTest.java ---------------------------------------------------------------------- diff --git a/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ParDoBoundTranslatorTest.java b/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ParDoBoundTranslatorTest.java deleted file mode 100644 index 2aa0720..0000000 --- a/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ParDoBoundTranslatorTest.java +++ /dev/null @@ -1,344 +0,0 @@ -/* - * Licensed to the 
Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.beam.runners.apex.translation; - -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; - -import com.datatorrent.api.DAG; -import com.datatorrent.api.Sink; -import com.datatorrent.lib.util.KryoCloneUtils; -import com.google.common.collect.Lists; -import com.google.common.collect.Sets; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashSet; -import java.util.List; -import java.util.Set; -import java.util.regex.Pattern; -import org.apache.beam.runners.apex.ApexPipelineOptions; -import org.apache.beam.runners.apex.ApexRunner; -import org.apache.beam.runners.apex.ApexRunnerResult; -import org.apache.beam.runners.apex.TestApexRunner; -import org.apache.beam.runners.apex.translation.operators.ApexParDoOperator; -import org.apache.beam.runners.apex.translation.operators.ApexReadUnboundedInputOperator; -import org.apache.beam.runners.apex.translation.utils.ApexStateInternals; -import org.apache.beam.runners.apex.translation.utils.ApexStreamTuple; -import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.SerializableCoder; -import org.apache.beam.sdk.coders.VarIntCoder; -import org.apache.beam.sdk.coders.VoidCoder; -import org.apache.beam.sdk.options.PipelineOptionsFactory; -import org.apache.beam.sdk.testing.PAssert; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.transforms.Sum; -import org.apache.beam.sdk.transforms.View; -import org.apache.beam.sdk.util.WindowedValue; -import org.apache.beam.sdk.util.WindowingStrategy; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PCollectionTuple; -import org.apache.beam.sdk.values.PCollectionView; -import org.apache.beam.sdk.values.TupleTag; -import org.apache.beam.sdk.values.TupleTagList; -import org.junit.Assert; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * integration test for {@link ParDoBoundTranslator}. 
- */ -@RunWith(JUnit4.class) -public class ParDoBoundTranslatorTest { - private static final Logger LOG = LoggerFactory.getLogger(ParDoBoundTranslatorTest.class); - private static final long SLEEP_MILLIS = 500; - private static final long TIMEOUT_MILLIS = 30000; - - @Test - public void test() throws Exception { - ApexPipelineOptions options = PipelineOptionsFactory.create() - .as(ApexPipelineOptions.class); - options.setApplicationName("ParDoBound"); - options.setRunner(ApexRunner.class); - - Pipeline p = Pipeline.create(options); - - List<Integer> collection = Lists.newArrayList(1, 2, 3, 4, 5); - List<Integer> expected = Lists.newArrayList(6, 7, 8, 9, 10); - p.apply(Create.of(collection).withCoder(SerializableCoder.of(Integer.class))) - .apply(ParDo.of(new Add(5))) - .apply(ParDo.of(new EmbeddedCollector())); - - ApexRunnerResult result = (ApexRunnerResult) p.run(); - DAG dag = result.getApexDAG(); - - DAG.OperatorMeta om = dag.getOperatorMeta("Create.Values"); - Assert.assertNotNull(om); - Assert.assertEquals(om.getOperator().getClass(), ApexReadUnboundedInputOperator.class); - - om = dag.getOperatorMeta("ParDo(Add)"); - Assert.assertNotNull(om); - Assert.assertEquals(om.getOperator().getClass(), ApexParDoOperator.class); - - long timeout = System.currentTimeMillis() + TIMEOUT_MILLIS; - while (System.currentTimeMillis() < timeout) { - if (EmbeddedCollector.RESULTS.containsAll(expected)) { - break; - } - LOG.info("Waiting for expected results."); - Thread.sleep(SLEEP_MILLIS); - } - Assert.assertEquals(Sets.newHashSet(expected), EmbeddedCollector.RESULTS); - } - - private static class Add extends DoFn<Integer, Integer> { - private static final long serialVersionUID = 1L; - private Integer number; - private PCollectionView<Integer> sideInputView; - - private Add(Integer number) { - this.number = number; - } - - private Add(PCollectionView<Integer> sideInputView) { - this.sideInputView = sideInputView; - } - - @ProcessElement - public void processElement(ProcessContext c) throws Exception { - if (sideInputView != null) { - number = c.sideInput(sideInputView); - } - c.output(c.element() + number); - } - } - - private static class EmbeddedCollector extends DoFn<Object, Void> { - private static final long serialVersionUID = 1L; - private static final Set<Object> RESULTS = Collections.synchronizedSet(new HashSet<>()); - - public EmbeddedCollector() { - RESULTS.clear(); - } - - @ProcessElement - public void processElement(ProcessContext c) throws Exception { - RESULTS.add(c.element()); - } - } - - private static Throwable runExpectingAssertionFailure(Pipeline pipeline) { - // We cannot use thrown.expect(AssertionError.class) because the AssertionError - // is first caught by JUnit and causes a test failure. - try { - pipeline.run(); - } catch (AssertionError exc) { - return exc; - } - fail("assertion should have failed"); - throw new RuntimeException("unreachable"); - } - - @Test - public void testAssertionFailure() throws Exception { - ApexPipelineOptions options = PipelineOptionsFactory.create() - .as(ApexPipelineOptions.class); - options.setRunner(TestApexRunner.class); - Pipeline pipeline = Pipeline.create(options); - - PCollection<Integer> pcollection = pipeline - .apply(Create.of(1, 2, 3, 4)); - PAssert.that(pcollection).containsInAnyOrder(2, 1, 4, 3, 7); - - Throwable exc = runExpectingAssertionFailure(pipeline); - Pattern expectedPattern = Pattern.compile( - "Expected: iterable over \\[((<4>|<7>|<3>|<2>|<1>)(, )?){5}\\] in any order"); - // A loose pattern, but should get the job done. 
- assertTrue( - "Expected error message from PAssert with substring matching " - + expectedPattern - + " but the message was \"" - + exc.getMessage() - + "\"", - expectedPattern.matcher(exc.getMessage()).find()); - } - - @Test - public void testContainsInAnyOrder() throws Exception { - ApexPipelineOptions options = PipelineOptionsFactory.create().as(ApexPipelineOptions.class); - options.setRunner(TestApexRunner.class); - Pipeline pipeline = Pipeline.create(options); - PCollection<Integer> pcollection = pipeline.apply(Create.of(1, 2, 3, 4)); - PAssert.that(pcollection).containsInAnyOrder(2, 1, 4, 3); - // TODO: terminate faster based on processed assertion vs. auto-shutdown - pipeline.run(); - } - - @Test - public void testSerialization() throws Exception { - ApexPipelineOptions options = PipelineOptionsFactory.create() - .as(ApexPipelineOptions.class); - options.setRunner(TestApexRunner.class); - Pipeline pipeline = Pipeline.create(options); - Coder<WindowedValue<Integer>> coder = WindowedValue.getValueOnlyCoder(VarIntCoder.of()); - - PCollectionView<Integer> singletonView = pipeline.apply(Create.of(1)) - .apply(Sum.integersGlobally().asSingletonView()); - - ApexParDoOperator<Integer, Integer> operator = - new ApexParDoOperator<>( - options, - new Add(singletonView), - new TupleTag<Integer>(), - TupleTagList.empty().getAll(), - WindowingStrategy.globalDefault(), - Collections.<PCollectionView<?>>singletonList(singletonView), - coder, - new ApexStateInternals.ApexStateInternalsFactory<Void>()); - operator.setup(null); - operator.beginWindow(0); - WindowedValue<Integer> wv1 = WindowedValue.valueInGlobalWindow(1); - WindowedValue<Iterable<?>> sideInput = WindowedValue.<Iterable<?>>valueInGlobalWindow( - Lists.<Integer>newArrayList(22)); - operator.input.process(ApexStreamTuple.DataTuple.of(wv1)); // pushed back input - - final List<Object> results = Lists.newArrayList(); - Sink<Object> sink = new Sink<Object>() { - @Override - public void put(Object tuple) { - results.add(tuple); - } - @Override - public int getCount(boolean reset) { - return 0; - } - }; - - // verify pushed back input checkpointing - Assert.assertNotNull("Serialization", operator = KryoCloneUtils.cloneObject(operator)); - operator.output.setSink(sink); - operator.setup(null); - operator.beginWindow(1); - WindowedValue<Integer> wv2 = WindowedValue.valueInGlobalWindow(2); - operator.sideInput1.process(ApexStreamTuple.DataTuple.of(sideInput)); - Assert.assertEquals("number outputs", 1, results.size()); - Assert.assertEquals("result", WindowedValue.valueInGlobalWindow(23), - ((ApexStreamTuple.DataTuple<?>) results.get(0)).getValue()); - - // verify side input checkpointing - results.clear(); - Assert.assertNotNull("Serialization", operator = KryoCloneUtils.cloneObject(operator)); - operator.output.setSink(sink); - operator.setup(null); - operator.beginWindow(2); - operator.input.process(ApexStreamTuple.DataTuple.of(wv2)); - Assert.assertEquals("number outputs", 1, results.size()); - Assert.assertEquals("result", WindowedValue.valueInGlobalWindow(24), - ((ApexStreamTuple.DataTuple<?>) results.get(0)).getValue()); - } - - @Test - public void testMultiOutputParDoWithSideInputs() throws Exception { - ApexPipelineOptions options = PipelineOptionsFactory.create().as(ApexPipelineOptions.class); - options.setRunner(ApexRunner.class); // non-blocking run - Pipeline pipeline = Pipeline.create(options); - - List<Integer> inputs = Arrays.asList(3, -42, 666); - final TupleTag<String> mainOutputTag = new TupleTag<>("main"); - final 
TupleTag<Void> sideOutputTag = new TupleTag<>("sideOutput"); - - PCollectionView<Integer> sideInput1 = pipeline - .apply("CreateSideInput1", Create.of(11)) - .apply("ViewSideInput1", View.<Integer>asSingleton()); - PCollectionView<Integer> sideInputUnread = pipeline - .apply("CreateSideInputUnread", Create.of(-3333)) - .apply("ViewSideInputUnread", View.<Integer>asSingleton()); - PCollectionView<Integer> sideInput2 = pipeline - .apply("CreateSideInput2", Create.of(222)) - .apply("ViewSideInput2", View.<Integer>asSingleton()); - - PCollectionTuple outputs = pipeline - .apply(Create.of(inputs)) - .apply(ParDo.withSideInputs(sideInput1) - .withSideInputs(sideInputUnread) - .withSideInputs(sideInput2) - .withOutputTags(mainOutputTag, TupleTagList.of(sideOutputTag)) - .of(new TestMultiOutputWithSideInputsFn( - Arrays.asList(sideInput1, sideInput2), - Arrays.<TupleTag<String>>asList()))); - - outputs.get(mainOutputTag).apply(ParDo.of(new EmbeddedCollector())); - outputs.get(sideOutputTag).setCoder(VoidCoder.of()); - ApexRunnerResult result = (ApexRunnerResult) pipeline.run(); - - HashSet<String> expected = Sets.newHashSet("processing: 3: [11, 222]", - "processing: -42: [11, 222]", "processing: 666: [11, 222]"); - long timeout = System.currentTimeMillis() + TIMEOUT_MILLIS; - while (System.currentTimeMillis() < timeout) { - if (EmbeddedCollector.RESULTS.containsAll(expected)) { - break; - } - LOG.info("Waiting for expected results."); - Thread.sleep(SLEEP_MILLIS); - } - result.cancel(); - Assert.assertEquals(Sets.newHashSet(expected), EmbeddedCollector.RESULTS); - } - - private static class TestMultiOutputWithSideInputsFn extends DoFn<Integer, String> { - private static final long serialVersionUID = 1L; - - final List<PCollectionView<Integer>> sideInputViews = new ArrayList<>(); - final List<TupleTag<String>> sideOutputTupleTags = new ArrayList<>(); - - public TestMultiOutputWithSideInputsFn(List<PCollectionView<Integer>> sideInputViews, - List<TupleTag<String>> sideOutputTupleTags) { - this.sideInputViews.addAll(sideInputViews); - this.sideOutputTupleTags.addAll(sideOutputTupleTags); - } - - @ProcessElement - public void processElement(ProcessContext c) throws Exception { - outputToAllWithSideInputs(c, "processing: " + c.element()); - } - - private void outputToAllWithSideInputs(ProcessContext c, String value) { - if (!sideInputViews.isEmpty()) { - List<Integer> sideInputValues = new ArrayList<>(); - for (PCollectionView<Integer> sideInputView : sideInputViews) { - sideInputValues.add(c.sideInput(sideInputView)); - } - value += ": " + sideInputValues; - } - c.output(value); - for (TupleTag<String> sideOutputTupleTag : sideOutputTupleTags) { - c.sideOutput(sideOutputTupleTag, - sideOutputTupleTag.getId() + ": " + value); - } - } - - } - -} http://git-wip-us.apache.org/repos/asf/beam/blob/6253abaa/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ParDoTranslatorTest.java ---------------------------------------------------------------------- diff --git a/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ParDoTranslatorTest.java b/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ParDoTranslatorTest.java new file mode 100644 index 0000000..83e68f7 --- /dev/null +++ b/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ParDoTranslatorTest.java @@ -0,0 +1,344 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.apex.translation; + +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import com.datatorrent.api.DAG; +import com.datatorrent.api.Sink; +import com.datatorrent.lib.util.KryoCloneUtils; +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.regex.Pattern; +import org.apache.beam.runners.apex.ApexPipelineOptions; +import org.apache.beam.runners.apex.ApexRunner; +import org.apache.beam.runners.apex.ApexRunnerResult; +import org.apache.beam.runners.apex.TestApexRunner; +import org.apache.beam.runners.apex.translation.operators.ApexParDoOperator; +import org.apache.beam.runners.apex.translation.operators.ApexReadUnboundedInputOperator; +import org.apache.beam.runners.apex.translation.utils.ApexStateInternals; +import org.apache.beam.runners.apex.translation.utils.ApexStreamTuple; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.SerializableCoder; +import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.coders.VoidCoder; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.Sum; +import org.apache.beam.sdk.transforms.View; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.util.WindowingStrategy; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TupleTagList; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * integration test for {@link ParDoTranslator}. 
+ */ +@RunWith(JUnit4.class) +public class ParDoTranslatorTest { + private static final Logger LOG = LoggerFactory.getLogger(ParDoTranslatorTest.class); + private static final long SLEEP_MILLIS = 500; + private static final long TIMEOUT_MILLIS = 30000; + + @Test + public void test() throws Exception { + ApexPipelineOptions options = PipelineOptionsFactory.create() + .as(ApexPipelineOptions.class); + options.setApplicationName("ParDoBound"); + options.setRunner(ApexRunner.class); + + Pipeline p = Pipeline.create(options); + + List<Integer> collection = Lists.newArrayList(1, 2, 3, 4, 5); + List<Integer> expected = Lists.newArrayList(6, 7, 8, 9, 10); + p.apply(Create.of(collection).withCoder(SerializableCoder.of(Integer.class))) + .apply(ParDo.of(new Add(5))) + .apply(ParDo.of(new EmbeddedCollector())); + + ApexRunnerResult result = (ApexRunnerResult) p.run(); + DAG dag = result.getApexDAG(); + + DAG.OperatorMeta om = dag.getOperatorMeta("Create.Values"); + Assert.assertNotNull(om); + Assert.assertEquals(om.getOperator().getClass(), ApexReadUnboundedInputOperator.class); + + om = dag.getOperatorMeta("ParDo(Add)/ParMultiDo(Add)"); + Assert.assertNotNull(om); + Assert.assertEquals(om.getOperator().getClass(), ApexParDoOperator.class); + + long timeout = System.currentTimeMillis() + TIMEOUT_MILLIS; + while (System.currentTimeMillis() < timeout) { + if (EmbeddedCollector.RESULTS.containsAll(expected)) { + break; + } + LOG.info("Waiting for expected results."); + Thread.sleep(SLEEP_MILLIS); + } + Assert.assertEquals(Sets.newHashSet(expected), EmbeddedCollector.RESULTS); + } + + private static class Add extends DoFn<Integer, Integer> { + private static final long serialVersionUID = 1L; + private Integer number; + private PCollectionView<Integer> sideInputView; + + private Add(Integer number) { + this.number = number; + } + + private Add(PCollectionView<Integer> sideInputView) { + this.sideInputView = sideInputView; + } + + @ProcessElement + public void processElement(ProcessContext c) throws Exception { + if (sideInputView != null) { + number = c.sideInput(sideInputView); + } + c.output(c.element() + number); + } + } + + private static class EmbeddedCollector extends DoFn<Object, Void> { + private static final long serialVersionUID = 1L; + private static final Set<Object> RESULTS = Collections.synchronizedSet(new HashSet<>()); + + public EmbeddedCollector() { + RESULTS.clear(); + } + + @ProcessElement + public void processElement(ProcessContext c) throws Exception { + RESULTS.add(c.element()); + } + } + + private static Throwable runExpectingAssertionFailure(Pipeline pipeline) { + // We cannot use thrown.expect(AssertionError.class) because the AssertionError + // is first caught by JUnit and causes a test failure. 
+ try { + pipeline.run(); + } catch (AssertionError exc) { + return exc; + } + fail("assertion should have failed"); + throw new RuntimeException("unreachable"); + } + + @Test + public void testAssertionFailure() throws Exception { + ApexPipelineOptions options = PipelineOptionsFactory.create() + .as(ApexPipelineOptions.class); + options.setRunner(TestApexRunner.class); + Pipeline pipeline = Pipeline.create(options); + + PCollection<Integer> pcollection = pipeline + .apply(Create.of(1, 2, 3, 4)); + PAssert.that(pcollection).containsInAnyOrder(2, 1, 4, 3, 7); + + Throwable exc = runExpectingAssertionFailure(pipeline); + Pattern expectedPattern = Pattern.compile( + "Expected: iterable over \\[((<4>|<7>|<3>|<2>|<1>)(, )?){5}\\] in any order"); + // A loose pattern, but should get the job done. + assertTrue( + "Expected error message from PAssert with substring matching " + + expectedPattern + + " but the message was \"" + + exc.getMessage() + + "\"", + expectedPattern.matcher(exc.getMessage()).find()); + } + + @Test + public void testContainsInAnyOrder() throws Exception { + ApexPipelineOptions options = PipelineOptionsFactory.create().as(ApexPipelineOptions.class); + options.setRunner(TestApexRunner.class); + Pipeline pipeline = Pipeline.create(options); + PCollection<Integer> pcollection = pipeline.apply(Create.of(1, 2, 3, 4)); + PAssert.that(pcollection).containsInAnyOrder(2, 1, 4, 3); + // TODO: terminate faster based on processed assertion vs. auto-shutdown + pipeline.run(); + } + + @Test + public void testSerialization() throws Exception { + ApexPipelineOptions options = PipelineOptionsFactory.create() + .as(ApexPipelineOptions.class); + options.setRunner(TestApexRunner.class); + Pipeline pipeline = Pipeline.create(options); + Coder<WindowedValue<Integer>> coder = WindowedValue.getValueOnlyCoder(VarIntCoder.of()); + + PCollectionView<Integer> singletonView = pipeline.apply(Create.of(1)) + .apply(Sum.integersGlobally().asSingletonView()); + + ApexParDoOperator<Integer, Integer> operator = + new ApexParDoOperator<>( + options, + new Add(singletonView), + new TupleTag<Integer>(), + TupleTagList.empty().getAll(), + WindowingStrategy.globalDefault(), + Collections.<PCollectionView<?>>singletonList(singletonView), + coder, + new ApexStateInternals.ApexStateInternalsFactory<Void>()); + operator.setup(null); + operator.beginWindow(0); + WindowedValue<Integer> wv1 = WindowedValue.valueInGlobalWindow(1); + WindowedValue<Iterable<?>> sideInput = WindowedValue.<Iterable<?>>valueInGlobalWindow( + Lists.<Integer>newArrayList(22)); + operator.input.process(ApexStreamTuple.DataTuple.of(wv1)); // pushed back input + + final List<Object> results = Lists.newArrayList(); + Sink<Object> sink = new Sink<Object>() { + @Override + public void put(Object tuple) { + results.add(tuple); + } + @Override + public int getCount(boolean reset) { + return 0; + } + }; + + // verify pushed back input checkpointing + Assert.assertNotNull("Serialization", operator = KryoCloneUtils.cloneObject(operator)); + operator.output.setSink(sink); + operator.setup(null); + operator.beginWindow(1); + WindowedValue<Integer> wv2 = WindowedValue.valueInGlobalWindow(2); + operator.sideInput1.process(ApexStreamTuple.DataTuple.of(sideInput)); + Assert.assertEquals("number outputs", 1, results.size()); + Assert.assertEquals("result", WindowedValue.valueInGlobalWindow(23), + ((ApexStreamTuple.DataTuple<?>) results.get(0)).getValue()); + + // verify side input checkpointing + results.clear(); + Assert.assertNotNull("Serialization", operator = 
KryoCloneUtils.cloneObject(operator)); + operator.output.setSink(sink); + operator.setup(null); + operator.beginWindow(2); + operator.input.process(ApexStreamTuple.DataTuple.of(wv2)); + Assert.assertEquals("number outputs", 1, results.size()); + Assert.assertEquals("result", WindowedValue.valueInGlobalWindow(24), + ((ApexStreamTuple.DataTuple<?>) results.get(0)).getValue()); + } + + @Test + public void testMultiOutputParDoWithSideInputs() throws Exception { + ApexPipelineOptions options = PipelineOptionsFactory.create().as(ApexPipelineOptions.class); + options.setRunner(ApexRunner.class); // non-blocking run + Pipeline pipeline = Pipeline.create(options); + + List<Integer> inputs = Arrays.asList(3, -42, 666); + final TupleTag<String> mainOutputTag = new TupleTag<>("main"); + final TupleTag<Void> sideOutputTag = new TupleTag<>("sideOutput"); + + PCollectionView<Integer> sideInput1 = pipeline + .apply("CreateSideInput1", Create.of(11)) + .apply("ViewSideInput1", View.<Integer>asSingleton()); + PCollectionView<Integer> sideInputUnread = pipeline + .apply("CreateSideInputUnread", Create.of(-3333)) + .apply("ViewSideInputUnread", View.<Integer>asSingleton()); + PCollectionView<Integer> sideInput2 = pipeline + .apply("CreateSideInput2", Create.of(222)) + .apply("ViewSideInput2", View.<Integer>asSingleton()); + + PCollectionTuple outputs = pipeline + .apply(Create.of(inputs)) + .apply(ParDo.withSideInputs(sideInput1) + .withSideInputs(sideInputUnread) + .withSideInputs(sideInput2) + .withOutputTags(mainOutputTag, TupleTagList.of(sideOutputTag)) + .of(new TestMultiOutputWithSideInputsFn( + Arrays.asList(sideInput1, sideInput2), + Arrays.<TupleTag<String>>asList()))); + + outputs.get(mainOutputTag).apply(ParDo.of(new EmbeddedCollector())); + outputs.get(sideOutputTag).setCoder(VoidCoder.of()); + ApexRunnerResult result = (ApexRunnerResult) pipeline.run(); + + HashSet<String> expected = Sets.newHashSet("processing: 3: [11, 222]", + "processing: -42: [11, 222]", "processing: 666: [11, 222]"); + long timeout = System.currentTimeMillis() + TIMEOUT_MILLIS; + while (System.currentTimeMillis() < timeout) { + if (EmbeddedCollector.RESULTS.containsAll(expected)) { + break; + } + LOG.info("Waiting for expected results."); + Thread.sleep(SLEEP_MILLIS); + } + result.cancel(); + Assert.assertEquals(Sets.newHashSet(expected), EmbeddedCollector.RESULTS); + } + + private static class TestMultiOutputWithSideInputsFn extends DoFn<Integer, String> { + private static final long serialVersionUID = 1L; + + final List<PCollectionView<Integer>> sideInputViews = new ArrayList<>(); + final List<TupleTag<String>> sideOutputTupleTags = new ArrayList<>(); + + public TestMultiOutputWithSideInputsFn(List<PCollectionView<Integer>> sideInputViews, + List<TupleTag<String>> sideOutputTupleTags) { + this.sideInputViews.addAll(sideInputViews); + this.sideOutputTupleTags.addAll(sideOutputTupleTags); + } + + @ProcessElement + public void processElement(ProcessContext c) throws Exception { + outputToAllWithSideInputs(c, "processing: " + c.element()); + } + + private void outputToAllWithSideInputs(ProcessContext c, String value) { + if (!sideInputViews.isEmpty()) { + List<Integer> sideInputValues = new ArrayList<>(); + for (PCollectionView<Integer> sideInputView : sideInputViews) { + sideInputValues.add(c.sideInput(sideInputView)); + } + value += ": " + sideInputValues; + } + c.output(value); + for (TupleTag<String> sideOutputTupleTag : sideOutputTupleTags) { + c.sideOutput(sideOutputTupleTag, + sideOutputTupleTag.getId() + ": " + value); 
+ } + } + + } + +} http://git-wip-us.apache.org/repos/asf/beam/blob/6253abaa/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java index f56d225..4601262 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java @@ -89,24 +89,10 @@ public class DirectRunner extends PipelineRunner<DirectPipelineResult> { .put( PTransformMatchers.classEqualTo(TestStream.class), new DirectTestStreamFactory()) /* primitive */ - /* Single-output ParDos are implemented in terms of Multi-output ParDos. Any override - that is applied to a multi-output ParDo must first have all matching Single-output ParDos - converted to match. - */ - .put(PTransformMatchers.splittableParDoSingle(), new ParDoSingleViaMultiOverrideFactory()) - .put( - PTransformMatchers.stateOrTimerParDoSingle(), - new ParDoSingleViaMultiOverrideFactory()) - // SplittableParMultiDo is implemented in terms of nonsplittable single ParDos - .put(PTransformMatchers.splittableParDoMulti(), new ParDoMultiOverrideFactory()) - // state and timer pardos are implemented in terms of nonsplittable single ParDos - .put(PTransformMatchers.stateOrTimerParDoMulti(), new ParDoMultiOverrideFactory()) - .put( - PTransformMatchers.classEqualTo(ParDo.Bound.class), - new ParDoSingleViaMultiOverrideFactory()) /* returns a BoundMulti */ .put( PTransformMatchers.classEqualTo(BoundMulti.class), - /* returns one of two primitives; SplittableParDos are replaced above. */ + /* returns one of two primitives; SplittableParDos and ParDos with state and timers + are replaced appropriately by the override factory. */ new ParDoMultiOverrideFactory()) .put( PTransformMatchers.classEqualTo(GBKIntoKeyedWorkItems.class), http://git-wip-us.apache.org/repos/asf/beam/blob/6253abaa/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactory.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactory.java deleted file mode 100644 index f859729..0000000 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactory.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.runners.direct; - -import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory; -import org.apache.beam.sdk.runners.PTransformOverrideFactory; -import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.transforms.ParDo.Bound; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PCollectionTuple; -import org.apache.beam.sdk.values.TupleTag; -import org.apache.beam.sdk.values.TupleTagList; - -/** - * A {@link PTransformOverrideFactory} that overrides single-output {@link ParDo} to implement - * it in terms of multi-output {@link ParDo}. - */ -class ParDoSingleViaMultiOverrideFactory<InputT, OutputT> - extends SingleInputOutputOverrideFactory< - PCollection<? extends InputT>, PCollection<OutputT>, Bound<InputT, OutputT>> { - @Override - public PTransform<PCollection<? extends InputT>, PCollection<OutputT>> getReplacementTransform( - Bound<InputT, OutputT> transform) { - return new ParDoSingleViaMulti<>(transform); - } - - static class ParDoSingleViaMulti<InputT, OutputT> - extends PTransform<PCollection<? extends InputT>, PCollection<OutputT>> { - private static final String MAIN_OUTPUT_TAG = "main"; - - private final ParDo.Bound<InputT, OutputT> underlyingParDo; - - public ParDoSingleViaMulti(ParDo.Bound<InputT, OutputT> underlyingParDo) { - this.underlyingParDo = underlyingParDo; - } - - @Override - public PCollection<OutputT> expand(PCollection<? extends InputT> input) { - - // Output tags for ParDo need only be unique up to applied transform - TupleTag<OutputT> mainOutputTag = new TupleTag<OutputT>(MAIN_OUTPUT_TAG); - - PCollectionTuple outputs = - input.apply( - ParDo.of(underlyingParDo.getFn()) - .withSideInputs(underlyingParDo.getSideInputs()) - .withOutputTags(mainOutputTag, TupleTagList.empty())); - PCollection<OutputT> output = outputs.get(mainOutputTag); - - output.setTypeDescriptor(underlyingParDo.getFn().getOutputTypeDescriptor()); - return output; - } - } -} http://git-wip-us.apache.org/repos/asf/beam/blob/6253abaa/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactoryTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactoryTest.java deleted file mode 100644 index 59577a8..0000000 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactoryTest.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.beam.runners.direct; - -import static org.junit.Assert.assertThat; - -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.values.PCollection; -import org.hamcrest.Matchers; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** - * Tests for {@link ParDoSingleViaMultiOverrideFactory}. - */ -@RunWith(JUnit4.class) -public class ParDoSingleViaMultiOverrideFactoryTest { - private ParDoSingleViaMultiOverrideFactory<Integer, Integer> factory = - new ParDoSingleViaMultiOverrideFactory<>(); - - @Test - public void getInputSucceeds() { - TestPipeline p = TestPipeline.create(); - PCollection<Integer> input = p.apply(Create.of(1, 2, 3)); - PCollection<?> reconstructed = factory.getInput(input.expand(), p); - assertThat(reconstructed, Matchers.<PCollection<?>>equalTo(input)); - } -} http://git-wip-us.apache.org/repos/asf/beam/blob/6253abaa/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java index f043c90..31a6bda 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java @@ -112,8 +112,7 @@ class FlinkBatchTransformTranslators { TRANSLATORS.put(Window.Assign.class, new WindowAssignTranslatorBatch()); - TRANSLATORS.put(ParDo.Bound.class, new ParDoBoundTranslatorBatch()); - TRANSLATORS.put(ParDo.BoundMulti.class, new ParDoBoundMultiTranslatorBatch()); + TRANSLATORS.put(ParDo.BoundMulti.class, new ParDoTranslatorBatch()); TRANSLATORS.put(Read.Bounded.class, new ReadSourceTranslatorBatch()); } @@ -498,80 +497,7 @@ class FlinkBatchTransformTranslators { } } - private static class ParDoBoundTranslatorBatch<InputT, OutputT> - implements FlinkBatchPipelineTranslator.BatchTransformTranslator< - ParDo.Bound<InputT, OutputT>> { - - @Override - @SuppressWarnings("unchecked") - public void translateNode( - ParDo.Bound<InputT, OutputT> transform, - - FlinkBatchTranslationContext context) { - DoFn<InputT, OutputT> doFn = transform.getFn(); - rejectSplittable(doFn); - - DataSet<WindowedValue<InputT>> inputDataSet = - context.getInputDataSet(context.getInput(transform)); - - TypeInformation<WindowedValue<OutputT>> typeInformation = - context.getTypeInfo(context.getOutput(transform)); - - List<PCollectionView<?>> sideInputs = transform.getSideInputs(); - - // construct a map from side input to WindowingStrategy so that - // the DoFn runner can map main-input windows to side input windows - Map<PCollectionView<?>, WindowingStrategy<?, ?>> sideInputStrategies = new HashMap<>(); - for (PCollectionView<?> sideInput: sideInputs) { - sideInputStrategies.put(sideInput, sideInput.getWindowingStrategyInternal()); - } - - WindowingStrategy<?, ?> windowingStrategy = - context.getOutput(transform).getWindowingStrategy(); - - SingleInputUdfOperator<WindowedValue<InputT>, WindowedValue<OutputT>, ?> outputDataSet; - DoFnSignature signature = DoFnSignatures.getSignature(transform.getFn().getClass()); - if (signature.stateDeclarations().size() > 0 - || signature.timerDeclarations().size() > 0) { - - // 
Based on the fact that the signature is stateful, DoFnSignatures ensures - // that it is also keyed - KvCoder<?, InputT> inputCoder = - (KvCoder<?, InputT>) context.getInput(transform).getCoder(); - - FlinkStatefulDoFnFunction<?, ?, OutputT> doFnWrapper = new FlinkStatefulDoFnFunction<>( - (DoFn) doFn, windowingStrategy, sideInputStrategies, context.getPipelineOptions(), - null, new TupleTag<OutputT>() - ); - - Grouping<WindowedValue<InputT>> grouping = - inputDataSet.groupBy(new KvKeySelector(inputCoder.getKeyCoder())); - - outputDataSet = new GroupReduceOperator( - grouping, typeInformation, doFnWrapper, transform.getName()); - - } else { - FlinkDoFnFunction<InputT, OutputT> doFnWrapper = - new FlinkDoFnFunction<>( - doFn, - windowingStrategy, - sideInputStrategies, - context.getPipelineOptions(), - null, new TupleTag<OutputT>()); - - outputDataSet = new MapPartitionOperator<>(inputDataSet, typeInformation, doFnWrapper, - transform.getName()); - - } - - transformSideInputs(sideInputs, outputDataSet, context); - - context.setOutputDataSet(context.getOutput(transform), outputDataSet); - - } - } - - private static class ParDoBoundMultiTranslatorBatch<InputT, OutputT> + private static class ParDoTranslatorBatch<InputT, OutputT> implements FlinkBatchPipelineTranslator.BatchTransformTranslator< ParDo.BoundMulti<InputT, OutputT>> { http://git-wip-us.apache.org/repos/asf/beam/blob/6253abaa/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java index c7df91d..7227dce 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java @@ -121,8 +121,7 @@ class FlinkStreamingTransformTranslators { TRANSLATORS.put(Write.class, new WriteSinkStreamingTranslator()); TRANSLATORS.put(TextIO.Write.Bound.class, new TextIOWriteBoundStreamingTranslator()); - TRANSLATORS.put(ParDo.Bound.class, new ParDoBoundStreamingTranslator()); - TRANSLATORS.put(ParDo.BoundMulti.class, new ParDoBoundMultiStreamingTranslator()); + TRANSLATORS.put(ParDo.BoundMulti.class, new ParDoStreamingTranslator()); TRANSLATORS.put(Window.Assign.class, new WindowAssignTranslator()); TRANSLATORS.put(Flatten.PCollections.class, new FlattenPCollectionTranslator()); @@ -320,114 +319,6 @@ class FlinkStreamingTransformTranslators { } } - private static class ParDoBoundStreamingTranslator<InputT, OutputT> - extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< - ParDo.Bound<InputT, OutputT>> { - - @Override - public void translateNode( - ParDo.Bound<InputT, OutputT> transform, - FlinkStreamingTranslationContext context) { - - DoFn<InputT, OutputT> doFn = transform.getFn(); - rejectSplittable(doFn); - - WindowingStrategy<?, ?> windowingStrategy = - context.getOutput(transform).getWindowingStrategy(); - - TypeInformation<WindowedValue<OutputT>> typeInfo = - context.getTypeInfo(context.getOutput(transform)); - - List<PCollectionView<?>> sideInputs = transform.getSideInputs(); - - @SuppressWarnings("unchecked") - PCollection<InputT> inputPCollection = (PCollection<InputT>) context.getInput(transform); - - Coder<WindowedValue<InputT>> 
inputCoder = context.getCoder(inputPCollection); - - DataStream<WindowedValue<InputT>> inputDataStream = - context.getInputDataStream(context.getInput(transform)); - Coder keyCoder = null; - boolean stateful = false; - DoFnSignature signature = DoFnSignatures.getSignature(transform.getFn().getClass()); - if (signature.stateDeclarations().size() > 0 - || signature.timerDeclarations().size() > 0) { - // Based on the fact that the signature is stateful, DoFnSignatures ensures - // that it is also keyed - keyCoder = ((KvCoder) inputPCollection.getCoder()).getKeyCoder(); - inputDataStream = inputDataStream.keyBy(new KvToByteBufferKeySelector(keyCoder)); - stateful = true; - } - - if (sideInputs.isEmpty()) { - DoFnOperator<InputT, OutputT, WindowedValue<OutputT>> doFnOperator = - new DoFnOperator<>( - transform.getFn(), - inputCoder, - new TupleTag<OutputT>("main output"), - Collections.<TupleTag<?>>emptyList(), - new DoFnOperator.DefaultOutputManagerFactory<WindowedValue<OutputT>>(), - windowingStrategy, - new HashMap<Integer, PCollectionView<?>>(), /* side-input mapping */ - Collections.<PCollectionView<?>>emptyList(), /* side inputs */ - context.getPipelineOptions(), - keyCoder); - - SingleOutputStreamOperator<WindowedValue<OutputT>> outDataStream = inputDataStream - .transform(transform.getName(), typeInfo, doFnOperator); - - context.setOutputDataStream(context.getOutput(transform), outDataStream); - } else { - Tuple2<Map<Integer, PCollectionView<?>>, DataStream<RawUnionValue>> transformedSideInputs = - transformSideInputs(sideInputs, context); - - DoFnOperator<InputT, OutputT, WindowedValue<OutputT>> doFnOperator = - new DoFnOperator<>( - transform.getFn(), - inputCoder, - new TupleTag<OutputT>("main output"), - Collections.<TupleTag<?>>emptyList(), - new DoFnOperator.DefaultOutputManagerFactory<WindowedValue<OutputT>>(), - windowingStrategy, - transformedSideInputs.f0, - sideInputs, - context.getPipelineOptions(), - keyCoder); - - SingleOutputStreamOperator<WindowedValue<OutputT>> outDataStream; - if (stateful) { - // we have to manually construct the two-input transform because we're not - // allowed to have only one input keyed, normally. - KeyedStream keyedStream = (KeyedStream<?, InputT>) inputDataStream; - TwoInputTransformation< - WindowedValue<KV<?, InputT>>, - RawUnionValue, - WindowedValue<OutputT>> rawFlinkTransform = new TwoInputTransformation<>( - keyedStream.getTransformation(), - transformedSideInputs.f1.broadcast().getTransformation(), - transform.getName(), - (TwoInputStreamOperator) doFnOperator, - typeInfo, - keyedStream.getParallelism()); - - rawFlinkTransform.setStateKeyType(keyedStream.getKeyType()); - rawFlinkTransform.setStateKeySelectors(keyedStream.getKeySelector(), null); - - outDataStream = new SingleOutputStreamOperator( - keyedStream.getExecutionEnvironment(), - rawFlinkTransform) {}; // we have to cheat around the ctor being protected - - keyedStream.getExecutionEnvironment().addOperator(rawFlinkTransform); - } else { - outDataStream = inputDataStream - .connect(transformedSideInputs.f1.broadcast()) - .transform(transform.getName(), typeInfo, doFnOperator); - } - context.setOutputDataStream(context.getOutput(transform), outDataStream); - } - } - } - /** * Wraps each element in a {@link RawUnionValue} with the given tag id.
*/ @@ -505,7 +396,7 @@ class FlinkStreamingTransformTranslators { } - private static class ParDoBoundMultiStreamingTranslator<InputT, OutputT> + private static class ParDoStreamingTranslator<InputT, OutputT> extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< ParDo.BoundMulti<InputT, OutputT>> { http://git-wip-us.apache.org/repos/asf/beam/blob/6253abaa/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java ---------------------------------------------------------------------- diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java index 06e5048..ab4cb9c 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java @@ -45,7 +45,6 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Supplier; import com.google.common.collect.BiMap; import com.google.common.collect.ImmutableBiMap; -import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import java.io.IOException; import java.util.ArrayList; @@ -848,34 +847,6 @@ public class DataflowPipelineTranslator { }); registerTransformTranslator( - ParDo.Bound.class, - new TransformTranslator<ParDo.Bound>() { - @Override - public void translate(ParDo.Bound transform, TranslationContext context) { - translateSingleHelper(transform, context); - } - - private <InputT, OutputT> void translateSingleHelper( - ParDo.Bound<InputT, OutputT> transform, TranslationContext context) { - - StepTranslationContext stepContext = context.addStep(transform, "ParallelDo"); - translateInputs( - stepContext, context.getInput(transform), transform.getSideInputs(), context); - long mainOutput = stepContext.addOutput(context.getOutput(transform)); - translateFn( - stepContext, - transform.getFn(), - context.getInput(transform).getWindowingStrategy(), - transform.getSideInputs(), - context.getInput(transform).getCoder(), - context, - mainOutput, - ImmutableMap.<Long, TupleTag<?>>of( - mainOutput, new TupleTag<>(PropertyNames.OUTPUT))); - } - }); - - registerTransformTranslator( Window.Assign.class, new TransformTranslator<Window.Assign>() { @Override http://git-wip-us.apache.org/repos/asf/beam/blob/6253abaa/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java ---------------------------------------------------------------------- diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java index d4271e5..ccb185c 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java @@ -525,7 +525,8 @@ public class DataflowPipelineTranslatorTest implements Serializable { assertEquals(13, job.getSteps().size()); Step step = job.getSteps().get(1); - assertEquals(stepName, getString(step.getProperties(), PropertyNames.USER_NAME)); + 
assertEquals( + stepName + "/ParMultiDo(NoOp)", getString(step.getProperties(), PropertyNames.USER_NAME)); assertAllStepOutputsHaveUniqueIds(job); return step; } @@ -971,7 +972,7 @@ public class DataflowPipelineTranslatorTest implements Serializable { .put("type", "JAVA_CLASS") .put("value", fn1.getClass().getName()) .put("shortValue", fn1.getClass().getSimpleName()) - .put("namespace", parDo1.getClass().getName()) + .put("namespace", ParDo.BoundMulti.class.getName()) .build(), ImmutableMap.<String, Object>builder() .put("key", "foo2") @@ -991,7 +992,7 @@ public class DataflowPipelineTranslatorTest implements Serializable { .put("type", "JAVA_CLASS") .put("value", fn2.getClass().getName()) .put("shortValue", fn2.getClass().getSimpleName()) - .put("namespace", parDo2.getClass().getName()) + .put("namespace", ParDo.BoundMulti.class.getName()) .build(), ImmutableMap.<String, Object>builder() .put("key", "foo3") http://git-wip-us.apache.org/repos/asf/beam/blob/6253abaa/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java index 8ebb496..a4939b9 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java @@ -27,6 +27,7 @@ import static org.apache.beam.runners.spark.io.hadoop.ShardNameBuilder.replaceSh import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectSplittable; import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectStateAndTimers; +import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import java.io.IOException; @@ -73,6 +74,7 @@ import org.apache.beam.sdk.util.WindowingStrategy; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.TaggedPValue; import org.apache.beam.sdk.values.TupleTag; import org.apache.hadoop.conf.Configuration; @@ -331,38 +333,19 @@ public final class TransformTranslator { }; } - private static <InputT, OutputT> TransformEvaluator<ParDo.Bound<InputT, OutputT>> parDo() { - return new TransformEvaluator<ParDo.Bound<InputT, OutputT>>() { - @Override - public void evaluate(ParDo.Bound<InputT, OutputT> transform, EvaluationContext context) { - String stepName = context.getCurrentTransform().getFullName(); - DoFn<InputT, OutputT> doFn = transform.getFn(); - rejectSplittable(doFn); - rejectStateAndTimers(doFn); - @SuppressWarnings("unchecked") - JavaRDD<WindowedValue<InputT>> inRDD = - ((BoundedDataset<InputT>) context.borrowDataset(transform)).getRDD(); - WindowingStrategy<?, ?> windowingStrategy = - context.getInput(transform).getWindowingStrategy(); - JavaSparkContext jsc = context.getSparkContext(); - Accumulator<NamedAggregators> aggAccum = - SparkAggregators.getNamedAggregators(jsc); - Accumulator<SparkMetricsContainer> metricsAccum = - MetricsAccumulator.getInstance(); - Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs = - TranslationUtils.getSideInputs(transform.getSideInputs(), context); - 
context.putDataset(transform, - new BoundedDataset<>(inRDD.mapPartitions(new DoFnFunction<>(aggAccum, metricsAccum, - stepName, doFn, context.getRuntimeContext(), sideInputs, windowingStrategy)))); - } - }; - } - - private static <InputT, OutputT> TransformEvaluator<ParDo.BoundMulti<InputT, OutputT>> - multiDo() { + private static <InputT, OutputT> TransformEvaluator<ParDo.BoundMulti<InputT, OutputT>> parDo() { return new TransformEvaluator<ParDo.BoundMulti<InputT, OutputT>>() { @Override public void evaluate(ParDo.BoundMulti<InputT, OutputT> transform, EvaluationContext context) { + if (transform.getSideOutputTags().size() == 0) { + evaluateSingle(transform, context); + } else { + evaluateMulti(transform, context); + } + } + + private void evaluateMulti( + ParDo.BoundMulti<InputT, OutputT> transform, EvaluationContext context) { String stepName = context.getCurrentTransform().getFullName(); DoFn<InputT, OutputT> doFn = transform.getFn(); rejectSplittable(doFn); @@ -373,16 +356,21 @@ public final class TransformTranslator { WindowingStrategy<?, ?> windowingStrategy = context.getInput(transform).getWindowingStrategy(); JavaSparkContext jsc = context.getSparkContext(); - Accumulator<NamedAggregators> aggAccum = - SparkAggregators.getNamedAggregators(jsc); - Accumulator<SparkMetricsContainer> metricsAccum = - MetricsAccumulator.getInstance(); - JavaPairRDD<TupleTag<?>, WindowedValue<?>> all = inRDD - .mapPartitionsToPair( - new MultiDoFnFunction<>(aggAccum, metricsAccum, stepName, doFn, - context.getRuntimeContext(), transform.getMainOutputTag(), - TranslationUtils.getSideInputs(transform.getSideInputs(), context), - windowingStrategy)).cache(); + Accumulator<NamedAggregators> aggAccum = SparkAggregators.getNamedAggregators(jsc); + Accumulator<SparkMetricsContainer> metricsAccum = MetricsAccumulator.getInstance(); + JavaPairRDD<TupleTag<?>, WindowedValue<?>> all = + inRDD + .mapPartitionsToPair( + new MultiDoFnFunction<>( + aggAccum, + metricsAccum, + stepName, + doFn, + context.getRuntimeContext(), + transform.getMainOutputTag(), + TranslationUtils.getSideInputs(transform.getSideInputs(), context), + windowingStrategy)) + .cache(); List<TaggedPValue> pct = context.getOutputs(transform); for (TaggedPValue e : pct) { @SuppressWarnings("unchecked") @@ -395,6 +383,37 @@ public final class TransformTranslator { context.putDataset(e.getValue(), new BoundedDataset<>(values)); } } + + private void evaluateSingle( + ParDo.BoundMulti<InputT, OutputT> transform, EvaluationContext context) { + String stepName = context.getCurrentTransform().getFullName(); + DoFn<InputT, OutputT> doFn = transform.getFn(); + rejectSplittable(doFn); + rejectStateAndTimers(doFn); + @SuppressWarnings("unchecked") + JavaRDD<WindowedValue<InputT>> inRDD = + ((BoundedDataset<InputT>) context.borrowDataset(transform)).getRDD(); + WindowingStrategy<?, ?> windowingStrategy = + context.getInput(transform).getWindowingStrategy(); + JavaSparkContext jsc = context.getSparkContext(); + Accumulator<NamedAggregators> aggAccum = SparkAggregators.getNamedAggregators(jsc); + Accumulator<SparkMetricsContainer> metricsAccum = MetricsAccumulator.getInstance(); + Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs = + TranslationUtils.getSideInputs(transform.getSideInputs(), context); + PValue onlyOutput = Iterables.getOnlyElement(context.getOutputs(transform)).getValue(); + context.putDataset( + onlyOutput, + new BoundedDataset<>( + inRDD.mapPartitions( + new DoFnFunction<>( + aggAccum, + metricsAccum, + stepName, + 
doFn, + context.getRuntimeContext(), + sideInputs, + windowingStrategy)))); + } }; } @@ -723,8 +742,7 @@ public final class TransformTranslator { EVALUATORS.put(Read.Bounded.class, readBounded()); EVALUATORS.put(HadoopIO.Read.Bound.class, readHadoop()); EVALUATORS.put(HadoopIO.Write.Bound.class, writeHadoop()); - EVALUATORS.put(ParDo.Bound.class, parDo()); - EVALUATORS.put(ParDo.BoundMulti.class, multiDo()); + EVALUATORS.put(ParDo.BoundMulti.class, parDo()); EVALUATORS.put(GroupByKey.class, groupByKey()); EVALUATORS.put(Combine.GroupedValues.class, combineGrouped()); EVALUATORS.put(Combine.Globally.class, combineGlobally());
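----------------------------------------------------------------------

The Spark change above folds the old parDo()/multiDo() pair into a single evaluator that picks an execution path by inspecting the transform: a BoundMulti that declares no side output tags takes the cheaper single-output route (one plain mapPartitions), while anything else goes through the cached, tagged-union multi-output route. Below is a minimal, self-contained sketch of that dispatch; all of the stand-in types are hypothetical, and only the empty-side-output-tags check mirrors the real evaluateSingle/evaluateMulti split.

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

// Hypothetical stand-ins; only the dispatch rule itself mirrors the
// TransformTranslator.parDo() evaluator in the diff above.
public final class ParDoDispatchSketch {

  /** Stand-in for ParDo.BoundMulti: one main output tag plus declared side output tags. */
  static final class MultiParDo {
    final String mainOutputTag;
    final List<String> sideOutputTags;

    MultiParDo(String mainOutputTag, List<String> sideOutputTags) {
      this.mainOutputTag = mainOutputTag;
      this.sideOutputTags = sideOutputTags;
    }
  }

  static void evaluate(MultiParDo transform) {
    if (transform.sideOutputTags.isEmpty()) {
      // evaluateSingle: one mapPartitions with a plain DoFnFunction,
      // producing a single dataset for the transform's only output.
      System.out.println("single-output path -> " + transform.mainOutputTag);
    } else {
      // evaluateMulti: mapPartitionsToPair into a cached (tag, value) RDD,
      // then one filter per declared output.
      System.out.println("multi-output path -> main + " + transform.sideOutputTags);
    }
  }

  public static void main(String[] args) {
    evaluate(new MultiParDo("main", Collections.<String>emptyList())); // single path
    evaluate(new MultiParDo("main", Arrays.asList("errors", "late"))); // multi path
  }
}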

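----------------------------------------------------------------------

The Flink translators in this diff (batch, and the deleted streaming single-output one) choose a keyed execution path the same way: they reflect over the DoFn for state or timer declarations, and, as the diff's own comment notes, a stateful signature is guaranteed keyed, so the translator can safely treat the input coder as a KvCoder and group or key by its key coder. A sketch of just that check, assuming the DoFnSignatures calls exactly as they appear in the diff; the pass-through DoFn here is a hypothetical example.

import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;

public final class StatefulSignatureCheck {

  /** A hypothetical stateless DoFn; a fn with @StateId or @TimerId fields would flip the check. */
  static final class PassThroughFn extends DoFn<String, String> {
    @ProcessElement
    public void processElement(ProcessContext c) {
      c.output(c.element());
    }
  }

  public static void main(String[] args) {
    // The same check both Flink translators perform before picking a keyed path.
    DoFnSignature signature = DoFnSignatures.getSignature(PassThroughFn.class);
    boolean stateful =
        signature.stateDeclarations().size() > 0 || signature.timerDeclarations().size() > 0;

    // When stateful is true the input is necessarily a KV PCollection, so the
    // batch translator groups by the KvCoder's key coder and the streaming
    // translator keys the DataStream the same way before applying the operator.
    System.out.println("requires keyed execution: " + stateful);
  }
}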