Perform a multi-step combine in the DirectRunner. This exercises the entire CombineFn lifecycle for simple combine fns, expressed as a collection of DoFns.
Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/6ea2eda2 Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/6ea2eda2 Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/6ea2eda2 Branch: refs/heads/master Commit: 6ea2eda2b5dbe4fc0cc2a84b64b76a07c7d0eda8 Parents: a94d680 Author: Thomas Groh <[email protected]> Authored: Thu Jun 15 15:53:46 2017 -0700 Committer: Thomas Groh <[email protected]> Committed: Fri Jul 28 15:25:11 2017 -0700 ---------------------------------------------------------------------- .../beam/runners/direct/DirectRunner.java | 65 +-- .../beam/runners/direct/MultiStepCombine.java | 423 +++++++++++++++++++ .../direct/TransformEvaluatorRegistry.java | 4 + .../runners/direct/MultiStepCombineTest.java | 228 ++++++++++ 4 files changed, 690 insertions(+), 30 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/6ea2eda2/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java index c5f29e5..642ce8f 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java @@ -233,36 +233,41 @@ public class DirectRunner extends PipelineRunner<DirectPipelineResult> { PTransformMatchers.writeWithRunnerDeterminedSharding(), new WriteWithShardingFactory())); /* Uses a view internally. 
*/ } - builder = builder.add( - PTransformOverride.of( - PTransformMatchers.urnEqualTo(PTransformTranslation.CREATE_VIEW_TRANSFORM_URN), - new ViewOverrideFactory())) /* Uses pardos and GBKs */ - .add( - PTransformOverride.of( - PTransformMatchers.urnEqualTo(PTransformTranslation.TEST_STREAM_TRANSFORM_URN), - new DirectTestStreamFactory(this))) /* primitive */ - // SplittableParMultiDo is implemented in terms of nonsplittable simple ParDos and extra - // primitives - .add( - PTransformOverride.of( - PTransformMatchers.splittableParDo(), new ParDoMultiOverrideFactory())) - // state and timer pardos are implemented in terms of simple ParDos and extra primitives - .add( - PTransformOverride.of( - PTransformMatchers.stateOrTimerParDo(), new ParDoMultiOverrideFactory())) - .add( - PTransformOverride.of( - PTransformMatchers.urnEqualTo( - SplittableParDo.SPLITTABLE_PROCESS_KEYED_ELEMENTS_URN), - new SplittableParDoViaKeyedWorkItems.OverrideFactory())) - .add( - PTransformOverride.of( - PTransformMatchers.urnEqualTo(SplittableParDo.SPLITTABLE_GBKIKWI_URN), - new DirectGBKIntoKeyedWorkItemsOverrideFactory())) /* Returns a GBKO */ - .add( - PTransformOverride.of( - PTransformMatchers.urnEqualTo(PTransformTranslation.GROUP_BY_KEY_TRANSFORM_URN), - new DirectGroupByKeyOverrideFactory())); /* returns two chained primitives. 
*/ + builder = + builder + .add( + PTransformOverride.of( + MultiStepCombine.matcher(), MultiStepCombine.Factory.create())) + .add( + PTransformOverride.of( + PTransformMatchers.urnEqualTo(PTransformTranslation.CREATE_VIEW_TRANSFORM_URN), + new ViewOverrideFactory())) /* Uses pardos and GBKs */ + .add( + PTransformOverride.of( + PTransformMatchers.urnEqualTo(PTransformTranslation.TEST_STREAM_TRANSFORM_URN), + new DirectTestStreamFactory(this))) /* primitive */ + // SplittableParMultiDo is implemented in terms of nonsplittable simple ParDos and extra + // primitives + .add( + PTransformOverride.of( + PTransformMatchers.splittableParDo(), new ParDoMultiOverrideFactory())) + // state and timer pardos are implemented in terms of simple ParDos and extra primitives + .add( + PTransformOverride.of( + PTransformMatchers.stateOrTimerParDo(), new ParDoMultiOverrideFactory())) + .add( + PTransformOverride.of( + PTransformMatchers.urnEqualTo( + SplittableParDo.SPLITTABLE_PROCESS_KEYED_ELEMENTS_URN), + new SplittableParDoViaKeyedWorkItems.OverrideFactory())) + .add( + PTransformOverride.of( + PTransformMatchers.urnEqualTo(SplittableParDo.SPLITTABLE_GBKIKWI_URN), + new DirectGBKIntoKeyedWorkItemsOverrideFactory())) /* Returns a GBKO */ + .add( + PTransformOverride.of( + PTransformMatchers.urnEqualTo(PTransformTranslation.GROUP_BY_KEY_TRANSFORM_URN), + new DirectGroupByKeyOverrideFactory())); /* returns two chained primitives. 
*/ return builder.build(); } http://git-wip-us.apache.org/repos/asf/beam/blob/6ea2eda2/runners/direct-java/src/main/java/org/apache/beam/runners/direct/MultiStepCombine.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/MultiStepCombine.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/MultiStepCombine.java new file mode 100644 index 0000000..6f49e94 --- /dev/null +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/MultiStepCombine.java @@ -0,0 +1,423 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.beam.runners.direct; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +import com.google.common.collect.Iterables; +import java.io.IOException; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Objects; +import javax.annotation.Nullable; +import org.apache.beam.runners.core.construction.CombineTranslation; +import org.apache.beam.runners.core.construction.PTransformTranslation; +import org.apache.beam.runners.core.construction.PTransformTranslation.RawPTransform; +import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory; +import org.apache.beam.sdk.coders.CannotProvideCoderException; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.runners.AppliedPTransform; +import org.apache.beam.sdk.runners.PTransformMatcher; +import org.apache.beam.sdk.runners.PTransformOverrideFactory; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.Combine.CombineFn; +import org.apache.beam.sdk.transforms.Combine.PerKey; +import org.apache.beam.sdk.transforms.CombineFnBase.GlobalCombineFn; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.GroupByKey; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.DefaultTrigger; +import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; +import org.apache.beam.sdk.util.UserCodeException; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PValue; +import 
org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.joda.time.Instant; + +/** A {@link Combine} that performs the combine in multiple steps. */ +class MultiStepCombine<K, InputT, AccumT, OutputT> + extends RawPTransform<PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>> { + public static PTransformMatcher matcher() { + return new PTransformMatcher() { + @Override + public boolean matches(AppliedPTransform<?, ?, ?> application) { + if (PTransformTranslation.COMBINE_TRANSFORM_URN.equals( + PTransformTranslation.urnForTransformOrNull(application.getTransform()))) { + try { + GlobalCombineFn fn = CombineTranslation.getCombineFn(application); + return isApplicable(application.getInputs(), fn); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + return false; + } + + private <K, InputT> boolean isApplicable( + Map<TupleTag<?>, PValue> inputs, GlobalCombineFn<InputT, ?, ?> fn) { + if (!(fn instanceof CombineFn)) { + return false; + } + if (inputs.size() == 1) { + PCollection<KV<K, InputT>> input = + (PCollection<KV<K, InputT>>) Iterables.getOnlyElement(inputs.values()); + WindowingStrategy<?, ?> windowingStrategy = input.getWindowingStrategy(); + boolean windowFnApplicable = windowingStrategy.getWindowFn().isNonMerging(); + // Triggering with count based triggers is not appropriately handled here. Disabling + // most triggers is safe, though more broad than is technically required. 
+ boolean triggerApplicable = DefaultTrigger.of().equals(windowingStrategy.getTrigger()); + boolean accumulatorCoderAvailable; + try { + if (input.getCoder() instanceof KvCoder) { + KvCoder<K, InputT> kvCoder = (KvCoder<K, InputT>) input.getCoder(); + Coder<?> accumulatorCoder = + fn.getAccumulatorCoder( + input.getPipeline().getCoderRegistry(), kvCoder.getValueCoder()); + accumulatorCoderAvailable = accumulatorCoder != null; + } else { + accumulatorCoderAvailable = false; + } + } catch (CannotProvideCoderException e) { + throw new RuntimeException( + String.format( + "Could not construct an accumulator %s for %s. Accumulator %s for a %s may be" + + " null, but may not throw an exception", + Coder.class.getSimpleName(), + fn, + Coder.class.getSimpleName(), + Combine.class.getSimpleName()), + e); + } + return windowFnApplicable && triggerApplicable && accumulatorCoderAvailable; + } + return false; + } + }; + } + + static class Factory<K, InputT, AccumT, OutputT> + extends SingleInputOutputOverrideFactory< + PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>, + PTransform<PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>>> { + public static PTransformOverrideFactory create() { + return new Factory<>(); + } + + private Factory() {} + + @Override + public PTransformReplacement<PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>> + getReplacementTransform( + AppliedPTransform< + PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>, + PTransform<PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>>> + transform) { + try { + GlobalCombineFn<?, ?, ?> globalFn = CombineTranslation.getCombineFn(transform); + checkState( + globalFn instanceof CombineFn, + "%s.matcher() should only match %s instances using %s, got %s", + MultiStepCombine.class.getSimpleName(), + PerKey.class.getSimpleName(), + CombineFn.class.getSimpleName(), + globalFn.getClass().getName()); + @SuppressWarnings("unchecked") + CombineFn<InputT, AccumT, OutputT> fn = (CombineFn<InputT, 
AccumT, OutputT>) globalFn; + @SuppressWarnings("unchecked") + PCollection<KV<K, InputT>> input = + (PCollection<KV<K, InputT>>) Iterables.getOnlyElement(transform.getInputs().values()); + @SuppressWarnings("unchecked") + PCollection<KV<K, OutputT>> output = + (PCollection<KV<K, OutputT>>) Iterables.getOnlyElement(transform.getOutputs().values()); + return PTransformReplacement.of(input, new MultiStepCombine<>(fn, output.getCoder())); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + } + + // =========================================================================================== + + private final CombineFn<InputT, AccumT, OutputT> combineFn; + private final Coder<KV<K, OutputT>> outputCoder; + + public static <K, InputT, AccumT, OutputT> MultiStepCombine<K, InputT, AccumT, OutputT> of( + CombineFn<InputT, AccumT, OutputT> combineFn, Coder<KV<K, OutputT>> outputCoder) { + return new MultiStepCombine<>(combineFn, outputCoder); + } + + private MultiStepCombine( + CombineFn<InputT, AccumT, OutputT> combineFn, Coder<KV<K, OutputT>> outputCoder) { + this.combineFn = combineFn; + this.outputCoder = outputCoder; + } + + @Nullable + @Override + public String getUrn() { + return "urn:beam:directrunner:transforms:multistepcombine:v1"; + } + + @Override + public PCollection<KV<K, OutputT>> expand(PCollection<KV<K, InputT>> input) { + checkArgument( + input.getCoder() instanceof KvCoder, + "Expected input to have a %s of type %s, got %s", + Coder.class.getSimpleName(), + KvCoder.class.getSimpleName(), + input.getCoder()); + KvCoder<K, InputT> inputCoder = (KvCoder<K, InputT>) input.getCoder(); + Coder<InputT> inputValueCoder = inputCoder.getValueCoder(); + Coder<AccumT> accumulatorCoder; + try { + accumulatorCoder = + combineFn.getAccumulatorCoder(input.getPipeline().getCoderRegistry(), inputValueCoder); + } catch (CannotProvideCoderException e) { + throw new IllegalStateException( + String.format( + "Could not construct an Accumulator Coder with the 
provided %s %s", + CombineFn.class.getSimpleName(), combineFn), + e); + } + return input + .apply( + ParDo.of( + new CombineInputs<>( + combineFn, + input.getWindowingStrategy().getTimestampCombiner(), + inputCoder.getKeyCoder()))) + .setCoder(KvCoder.of(inputCoder.getKeyCoder(), accumulatorCoder)) + .apply(GroupByKey.<K, AccumT>create()) + .apply(new MergeAndExtractAccumulatorOutput<K, AccumT, OutputT>(combineFn)) + .setCoder(outputCoder); + } + + private static class CombineInputs<K, InputT, AccumT> extends DoFn<KV<K, InputT>, KV<K, AccumT>> { + private final CombineFn<InputT, AccumT, ?> combineFn; + private final TimestampCombiner timestampCombiner; + private final Coder<K> keyCoder; + + /** + * Per-bundle state. Accumulators and output timestamps should only be tracked while a bundle + * is being processed, and must be cleared when a bundle is completed. + */ + private transient Map<WindowedStructuralKey<K>, AccumT> accumulators; + private transient Map<WindowedStructuralKey<K>, Instant> timestamps; + + private CombineInputs( + CombineFn<InputT, AccumT, ?> combineFn, + TimestampCombiner timestampCombiner, + Coder<K> keyCoder) { + this.combineFn = combineFn; + this.timestampCombiner = timestampCombiner; + this.keyCoder = keyCoder; + } + + @StartBundle + public void startBundle() { + accumulators = new LinkedHashMap<>(); + timestamps = new LinkedHashMap<>(); + } + + @ProcessElement + public void processElement(ProcessContext context, BoundedWindow window) { + WindowedStructuralKey<K> + key = WindowedStructuralKey.create(keyCoder, context.element().getKey(), window); + AccumT accumulator = accumulators.get(key); + Instant assignedTs = timestampCombiner.assign(window, context.timestamp()); + if (accumulator == null) { + accumulator = combineFn.createAccumulator(); + accumulators.put(key, accumulator); + timestamps.put(key, assignedTs); + } + accumulators.put(key, combineFn.addInput(accumulator, context.element().getValue())); + timestamps.put(key, 
timestampCombiner.combine(assignedTs, timestamps.get(key))); + } + + @FinishBundle + public void outputAccumulators(FinishBundleContext context) { + for (Map.Entry<WindowedStructuralKey<K>, AccumT> preCombineEntry : accumulators.entrySet()) { + context.output( + KV.of(preCombineEntry.getKey().getKey(), combineFn.compact(preCombineEntry.getValue())), + timestamps.get(preCombineEntry.getKey()), + preCombineEntry.getKey().getWindow()); + } + accumulators = null; + timestamps = null; + } + } + + static class WindowedStructuralKey<K> { + public static <K> WindowedStructuralKey<K> create( + Coder<K> keyCoder, K key, BoundedWindow window) { + return new WindowedStructuralKey<>(StructuralKey.of(key, keyCoder), window); + } + + private final StructuralKey<K> key; + private final BoundedWindow window; + + private WindowedStructuralKey(StructuralKey<K> key, BoundedWindow window) { + this.key = checkNotNull(key, "key cannot be null"); + this.window = checkNotNull(window, "Window cannot be null"); + } + + public K getKey() { + return key.getKey(); + } + + public BoundedWindow getWindow() { + return window; + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof MultiStepCombine.WindowedStructuralKey)) { + return false; + } + WindowedStructuralKey that = (WindowedStructuralKey<?>) other; + return this.window.equals(that.window) && this.key.equals(that.key); + } + + @Override + public int hashCode() { + return Objects.hash(window, key); + } + } + + static final String DIRECT_MERGE_ACCUMULATORS_EXTRACT_OUTPUT_URN = + "urn:beam:directrunner:transforms:merge_accumulators_extract_output:v1"; + /** + * A primitive {@link PTransform} that merges iterables of accumulators and extracts the output. + * + * <p>Required to ensure that Immutability Enforcement is not applied. Accumulators + * are explicitly mutable. 
+ */ + static class MergeAndExtractAccumulatorOutput<K, AccumT, OutputT> + extends RawPTransform<PCollection<KV<K, Iterable<AccumT>>>, PCollection<KV<K, OutputT>>> { + private final CombineFn<?, AccumT, OutputT> combineFn; + + private MergeAndExtractAccumulatorOutput(CombineFn<?, AccumT, OutputT> combineFn) { + this.combineFn = combineFn; + } + + CombineFn<?, AccumT, OutputT> getCombineFn() { + return combineFn; + } + + @Override + public PCollection<KV<K, OutputT>> expand(PCollection<KV<K, Iterable<AccumT>>> input) { + return PCollection.createPrimitiveOutputInternal( + input.getPipeline(), input.getWindowingStrategy(), input.isBounded()); + } + + @Nullable + @Override + public String getUrn() { + return DIRECT_MERGE_ACCUMULATORS_EXTRACT_OUTPUT_URN; + } + } + + static class MergeAndExtractAccumulatorOutputEvaluatorFactory + implements TransformEvaluatorFactory { + private final EvaluationContext ctxt; + + public MergeAndExtractAccumulatorOutputEvaluatorFactory(EvaluationContext ctxt) { + this.ctxt = ctxt; + } + + @Nullable + @Override + public <InputT> TransformEvaluator<InputT> forApplication( + AppliedPTransform<?, ?, ?> application, CommittedBundle<?> inputBundle) throws Exception { + return createEvaluator((AppliedPTransform) application, (CommittedBundle) inputBundle); + } + + private <K, AccumT, OutputT> TransformEvaluator<KV<K, Iterable<AccumT>>> createEvaluator( + AppliedPTransform< + PCollection<KV<K, Iterable<AccumT>>>, PCollection<KV<K, OutputT>>, + MergeAndExtractAccumulatorOutput<K, AccumT, OutputT>> + application, + CommittedBundle<KV<K, Iterable<AccumT>>> inputBundle) { + return new MergeAccumulatorsAndExtractOutputEvaluator<>(ctxt, application); + } + + @Override + public void cleanup() throws Exception {} + } + + private static class MergeAccumulatorsAndExtractOutputEvaluator<K, AccumT, OutputT> + implements TransformEvaluator<KV<K, Iterable<AccumT>>> { + private final AppliedPTransform< + PCollection<KV<K, Iterable<AccumT>>>, PCollection<KV<K, 
OutputT>>, + MergeAndExtractAccumulatorOutput<K, AccumT, OutputT>> + application; + private final CombineFn<?, AccumT, OutputT> combineFn; + private final UncommittedBundle<KV<K, OutputT>> output; + + public MergeAccumulatorsAndExtractOutputEvaluator( + EvaluationContext ctxt, + AppliedPTransform< + PCollection<KV<K, Iterable<AccumT>>>, PCollection<KV<K, OutputT>>, + MergeAndExtractAccumulatorOutput<K, AccumT, OutputT>> + application) { + this.application = application; + this.combineFn = application.getTransform().getCombineFn(); + this.output = + ctxt.createBundle( + (PCollection<KV<K, OutputT>>) + Iterables.getOnlyElement(application.getOutputs().values())); + } + + @Override + public void processElement(WindowedValue<KV<K, Iterable<AccumT>>> element) throws Exception { + checkState( + element.getWindows().size() == 1, + "Expected inputs to %s to be in exactly one window. Got %s", + MergeAccumulatorsAndExtractOutputEvaluator.class.getSimpleName(), + element.getWindows().size()); + Iterable<AccumT> inputAccumulators = element.getValue().getValue(); + try { + AccumT first = combineFn.createAccumulator(); + AccumT merged = combineFn.mergeAccumulators(Iterables.concat(Collections.singleton(first), + inputAccumulators, + Collections.singleton(combineFn.createAccumulator()))); + OutputT extracted = combineFn.extractOutput(merged); + output.add(element.withValue(KV.of(element.getValue().getKey(), extracted))); + } catch (Exception e) { + throw UserCodeException.wrap(e); + } + } + + @Override + public TransformResult<KV<K, Iterable<AccumT>>> finishBundle() throws Exception { + return StepTransformResult.<KV<K, Iterable<AccumT>>>withoutHold(application) + .addOutput(output) + .build(); + } + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/6ea2eda2/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java ---------------------------------------------------------------------- diff --git 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java index 0c907df..30666db 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java @@ -26,6 +26,7 @@ import static org.apache.beam.runners.core.construction.PTransformTranslation.WI import static org.apache.beam.runners.core.construction.SplittableParDo.SPLITTABLE_PROCESS_URN; import static org.apache.beam.runners.direct.DirectGroupByKey.DIRECT_GABW_URN; import static org.apache.beam.runners.direct.DirectGroupByKey.DIRECT_GBKO_URN; +import static org.apache.beam.runners.direct.MultiStepCombine.DIRECT_MERGE_ACCUMULATORS_EXTRACT_OUTPUT_URN; import static org.apache.beam.runners.direct.ParDoMultiOverrideFactory.DIRECT_STATEFUL_PAR_DO_URN; import static org.apache.beam.runners.direct.TestStreamEvaluatorFactory.DirectTestStreamFactory.DIRECT_TEST_STREAM_URN; import static org.apache.beam.runners.direct.ViewOverrideFactory.DIRECT_WRITE_VIEW_URN; @@ -73,6 +74,9 @@ class TransformEvaluatorRegistry implements TransformEvaluatorFactory { .put(DIRECT_GBKO_URN, new GroupByKeyOnlyEvaluatorFactory(ctxt)) .put(DIRECT_GABW_URN, new GroupAlsoByWindowEvaluatorFactory(ctxt)) .put(DIRECT_TEST_STREAM_URN, new TestStreamEvaluatorFactory(ctxt)) + .put( + DIRECT_MERGE_ACCUMULATORS_EXTRACT_OUTPUT_URN, + new MultiStepCombine.MergeAndExtractAccumulatorOutputEvaluatorFactory(ctxt)) // Runners-core primitives .put(SPLITTABLE_PROCESS_URN, new SplittableProcessElementsEvaluatorFactory<>(ctxt)) http://git-wip-us.apache.org/repos/asf/beam/blob/6ea2eda2/runners/direct-java/src/test/java/org/apache/beam/runners/direct/MultiStepCombineTest.java ---------------------------------------------------------------------- diff --git 
a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/MultiStepCombineTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/MultiStepCombineTest.java new file mode 100644 index 0000000..0c11a8a --- /dev/null +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/MultiStepCombineTest.java @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.beam.runners.direct; + +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertThat; + +import com.google.auto.value.AutoValue; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; +import org.apache.beam.sdk.coders.CannotProvideCoderException; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderException; +import org.apache.beam.sdk.coders.CoderRegistry; +import org.apache.beam.sdk.coders.CustomCoder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarLongCoder; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.Combine.CombineFn; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.windowing.FixedWindows; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +import org.apache.beam.sdk.transforms.windowing.SlidingWindows; +import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.util.VarInt; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TimestampedValue; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for {@link MultiStepCombine}. 
+ */ +@RunWith(JUnit4.class) +public class MultiStepCombineTest implements Serializable { + @Rule public transient TestPipeline pipeline = TestPipeline.create(); + + private transient KvCoder<String, Long> combinedCoder = + KvCoder.of(StringUtf8Coder.of(), VarLongCoder.of()); + + @Test + public void testMultiStepCombine() { + PCollection<KV<String, Long>> combined = + pipeline + .apply( + Create.of( + KV.of("foo", 1L), + KV.of("bar", 2L), + KV.of("bizzle", 3L), + KV.of("bar", 4L), + KV.of("bizzle", 11L))) + .apply(Combine.<String, Long, Long>perKey(new MultiStepCombineFn())); + + PAssert.that(combined) + .containsInAnyOrder(KV.of("foo", 1L), KV.of("bar", 6L), KV.of("bizzle", 14L)); + pipeline.run(); + } + + @Test + public void testMultiStepCombineWindowed() { + SlidingWindows windowFn = SlidingWindows.of(Duration.millis(6L)).every(Duration.millis(3L)); + PCollection<KV<String, Long>> combined = + pipeline + .apply( + Create.timestamped( + TimestampedValue.of(KV.of("foo", 1L), new Instant(1L)), + TimestampedValue.of(KV.of("bar", 2L), new Instant(2L)), + TimestampedValue.of(KV.of("bizzle", 3L), new Instant(3L)), + TimestampedValue.of(KV.of("bar", 4L), new Instant(4L)), + TimestampedValue.of(KV.of("bizzle", 11L), new Instant(11L)))) + .apply(Window.<KV<String, Long>>into(windowFn)) + .apply(Combine.<String, Long, Long>perKey(new MultiStepCombineFn())); + + PAssert.that("Windows should combine only elements in their windows", combined) + .inWindow(new IntervalWindow(new Instant(0L), Duration.millis(6L))) + .containsInAnyOrder(KV.of("foo", 1L), KV.of("bar", 6L), KV.of("bizzle", 3L)); + PAssert.that("Elements should appear in all the windows they are assigned to", combined) + .inWindow(new IntervalWindow(new Instant(-3L), Duration.millis(6L))) + .containsInAnyOrder(KV.of("foo", 1L), KV.of("bar", 2L)); + PAssert.that(combined) + .inWindow(new IntervalWindow(new Instant(6L), Duration.millis(6L))) + .containsInAnyOrder(KV.of("bizzle", 11L)); + PAssert.that(combined) + 
.containsInAnyOrder( + KV.of("foo", 1L), + KV.of("foo", 1L), + KV.of("bar", 6L), + KV.of("bar", 2L), + KV.of("bar", 4L), + KV.of("bizzle", 11L), + KV.of("bizzle", 11L), + KV.of("bizzle", 3L), + KV.of("bizzle", 3L)); + pipeline.run(); + } + + @Test + public void testMultiStepCombineTimestampCombiner() { + TimestampCombiner combiner = TimestampCombiner.LATEST; + combinedCoder = KvCoder.of(StringUtf8Coder.of(), VarLongCoder.of()); + PCollection<KV<String, Long>> combined = + pipeline + .apply( + Create.timestamped( + TimestampedValue.of(KV.of("foo", 4L), new Instant(1L)), + TimestampedValue.of(KV.of("foo", 1L), new Instant(4L)), + TimestampedValue.of(KV.of("bazzle", 4L), new Instant(4L)), + TimestampedValue.of(KV.of("foo", 12L), new Instant(12L)))) + .apply( + Window.<KV<String, Long>>into(FixedWindows.of(Duration.millis(5L))) + .withTimestampCombiner(combiner)) + .apply(Combine.<String, Long, Long>perKey(new MultiStepCombineFn())); + PCollection<KV<String, TimestampedValue<Long>>> reified = + combined.apply( + ParDo.of( + new DoFn<KV<String, Long>, KV<String, TimestampedValue<Long>>>() { + @ProcessElement + public void reifyTimestamp(ProcessContext context) { + context.output( + KV.of( + context.element().getKey(), + TimestampedValue.of( + context.element().getValue(), context.timestamp()))); + } + })); + + PAssert.that(reified) + .containsInAnyOrder( + KV.of("foo", TimestampedValue.of(5L, new Instant(4L))), + KV.of("bazzle", TimestampedValue.of(4L, new Instant(4L))), + KV.of("foo", TimestampedValue.of(12L, new Instant(12L)))); + pipeline.run(); + } + + private static class MultiStepCombineFn extends CombineFn<Long, MultiStepAccumulator, Long> { + @Override + public Coder<MultiStepAccumulator> getAccumulatorCoder( + CoderRegistry registry, Coder<Long> inputCoder) throws CannotProvideCoderException { + return new MultiStepAccumulatorCoder(); + } + + @Override + public MultiStepAccumulator createAccumulator() { + return MultiStepAccumulator.of(0L, false); + } + + 
@Override + public MultiStepAccumulator addInput(MultiStepAccumulator accumulator, Long input) { + return MultiStepAccumulator.of(accumulator.getValue() + input, accumulator.isDeserialized()); + } + + @Override + public MultiStepAccumulator mergeAccumulators(Iterable<MultiStepAccumulator> accumulators) { + MultiStepAccumulator result = MultiStepAccumulator.of(0L, false); + for (MultiStepAccumulator accumulator : accumulators) { + result = result.merge(accumulator); + } + return result; + } + + @Override + public Long extractOutput(MultiStepAccumulator accumulator) { + assertThat( + "Accumulators should have been serialized and deserialized within the Pipeline", + accumulator.isDeserialized(), + is(true)); + return accumulator.getValue(); + } + } + + @AutoValue + abstract static class MultiStepAccumulator { + private static MultiStepAccumulator of(long value, boolean deserialized) { + return new AutoValue_MultiStepCombineTest_MultiStepAccumulator(value, deserialized); + } + + MultiStepAccumulator merge(MultiStepAccumulator other) { + return MultiStepAccumulator.of( + this.getValue() + other.getValue(), this.isDeserialized() || other.isDeserialized()); + } + + abstract long getValue(); + + abstract boolean isDeserialized(); + } + + private static class MultiStepAccumulatorCoder extends CustomCoder<MultiStepAccumulator> { + @Override + public void encode(MultiStepAccumulator value, OutputStream outStream) + throws CoderException, IOException { + VarInt.encode(value.getValue(), outStream); + } + + @Override + public MultiStepAccumulator decode(InputStream inStream) throws CoderException, IOException { + return MultiStepAccumulator.of(VarInt.decodeLong(inStream), true); + } + } +}
