Repository: beam Updated Branches: refs/heads/master 346a77fa8 -> c528fb2f7
Fix getAdditionalInputs, etc, for DirectRunner stateful ParDo override Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/81a72192 Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/81a72192 Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/81a72192 Branch: refs/heads/master Commit: 81a72192dc4e792966de31c8eadda6a6c839a62c Parents: 346a77f Author: Kenneth Knowles <[email protected]> Authored: Mon Jun 12 16:31:32 2017 -0700 Committer: Kenneth Knowles <[email protected]> Committed: Thu Jun 15 16:47:53 2017 -0700 ---------------------------------------------------------------------- .../direct/ParDoMultiOverrideFactory.java | 90 +++++++++++++++----- .../direct/StatefulParDoEvaluatorFactory.java | 11 ++- .../StatefulParDoEvaluatorFactoryTest.java | 65 +++++++------- 3 files changed, 102 insertions(+), 64 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/81a72192/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java index 858ea34..b20113e 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java @@ -19,6 +19,8 @@ package org.apache.beam.runners.direct; import static com.google.common.base.Preconditions.checkState; +import com.google.common.collect.ImmutableMap; +import java.util.List; import java.util.Map; import org.apache.beam.runners.core.KeyedWorkItem; import org.apache.beam.runners.core.KeyedWorkItemCoder; @@ -27,7 +29,6 @@ import org.apache.beam.runners.core.construction.PTransformReplacements; import org.apache.beam.runners.core.construction.PTransformTranslation; import org.apache.beam.runners.core.construction.ReplacementOutputs; import org.apache.beam.runners.core.construction.SplittableParDo; -import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.runners.AppliedPTransform; @@ -48,6 +49,7 @@ import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TupleTagList; @@ -82,12 +84,14 @@ class ParDoMultiOverrideFactory<InputT, OutputT> return new SplittableParDo(transform); } else if (signature.stateDeclarations().size() > 0 || signature.timerDeclarations().size() > 0) { + // Based on the fact that the signature is stateful, DoFnSignatures ensures // that it is also keyed - MultiOutput<KV<?, ?>, OutputT> keyedTransform = - (MultiOutput<KV<?, ?>, OutputT>) transform; - - return new GbkThenStatefulParDo(keyedTransform); + return new GbkThenStatefulParDo( + fn, + transform.getMainOutputTag(), + transform.getAdditionalOutputTags(), + transform.getSideInputs()); } else { return transform; } @@ -101,10 +105,29 @@ class ParDoMultiOverrideFactory<InputT, OutputT> static class GbkThenStatefulParDo<K, InputT, OutputT> extends PTransform<PCollection<KV<K, InputT>>, PCollectionTuple> { - private final MultiOutput<KV<K, InputT>, OutputT> underlyingParDo; + private final transient DoFn<KV<K, InputT>, OutputT> doFn; + private final TupleTagList additionalOutputTags; + private final TupleTag<OutputT> mainOutputTag; + private final List<PCollectionView<?>> sideInputs; + + public GbkThenStatefulParDo( + DoFn<KV<K, InputT>, OutputT> doFn, + TupleTag<OutputT> mainOutputTag, + TupleTagList additionalOutputTags, + List<PCollectionView<?>> sideInputs) { + this.doFn = doFn; + this.additionalOutputTags = additionalOutputTags; + this.mainOutputTag = mainOutputTag; + this.sideInputs = sideInputs; + } - public GbkThenStatefulParDo(MultiOutput<KV<K, InputT>, OutputT> underlyingParDo) { - this.underlyingParDo = underlyingParDo; + @Override + public Map<TupleTag<?>, PValue> getAdditionalInputs() { + ImmutableMap.Builder<TupleTag<?>, PValue> additionalInputs = ImmutableMap.builder(); + for (PCollectionView<?> sideInput : sideInputs) { + additionalInputs.put(sideInput.getTagInternal(), sideInput.getPCollection()); + } + return additionalInputs.build(); } @Override @@ -160,7 +183,9 @@ class ParDoMultiOverrideFactory<InputT, OutputT> adjustedInput // Explode the resulting iterable into elements that are exactly the ones from // the input - .apply("Stateful ParDo", new StatefulParDo<>(underlyingParDo, input)); + .apply( + "Stateful ParDo", + new StatefulParDo<>(doFn, mainOutputTag, additionalOutputTags, sideInputs)); return outputs; } @@ -172,25 +197,45 @@ class ParDoMultiOverrideFactory<InputT, OutputT> static class StatefulParDo<K, InputT, OutputT> extends PTransformTranslation.RawPTransform< PCollection<? extends KeyedWorkItem<K, KV<K, InputT>>>, PCollectionTuple> { - private final transient MultiOutput<KV<K, InputT>, OutputT> underlyingParDo; - private final transient PCollection<KV<K, InputT>> originalInput; + private final transient DoFn<KV<K, InputT>, OutputT> doFn; + private final TupleTagList additionalOutputTags; + private final TupleTag<OutputT> mainOutputTag; + private final List<PCollectionView<?>> sideInputs; public StatefulParDo( - MultiOutput<KV<K, InputT>, OutputT> underlyingParDo, - PCollection<KV<K, InputT>> originalInput) { - this.underlyingParDo = underlyingParDo; - this.originalInput = originalInput; + DoFn<KV<K, InputT>, OutputT> doFn, + TupleTag<OutputT> mainOutputTag, + TupleTagList additionalOutputTags, + List<PCollectionView<?>> sideInputs) { + this.doFn = doFn; + this.mainOutputTag = mainOutputTag; + this.additionalOutputTags = additionalOutputTags; + this.sideInputs = sideInputs; + } + + public DoFn<KV<K, InputT>, OutputT> getDoFn() { + return doFn; + } + + public TupleTag<OutputT> getMainOutputTag() { + return mainOutputTag; + } + + public List<PCollectionView<?>> getSideInputs() { + return sideInputs; } - public MultiOutput<KV<K, InputT>, OutputT> getUnderlyingParDo() { - return underlyingParDo; + public TupleTagList getAdditionalOutputTags() { + return additionalOutputTags; } @Override - public <T> Coder<T> getDefaultOutputCoder( - PCollection<? extends KeyedWorkItem<K, KV<K, InputT>>> input, PCollection<T> output) - throws CannotProvideCoderException { - return underlyingParDo.getDefaultOutputCoder(originalInput, output); + public Map<TupleTag<?>, PValue> getAdditionalInputs() { + ImmutableMap.Builder<TupleTag<?>, PValue> additionalInputs = ImmutableMap.builder(); + for (PCollectionView<?> sideInput : sideInputs) { + additionalInputs.put(sideInput.getTagInternal(), sideInput.getPCollection()); + } + return additionalInputs.build(); } @Override @@ -199,8 +244,7 @@ class ParDoMultiOverrideFactory<InputT, OutputT> PCollectionTuple outputs = PCollectionTuple.ofPrimitiveOutputsInternal( input.getPipeline(), - TupleTagList.of(underlyingParDo.getMainOutputTag()) - .and(underlyingParDo.getAdditionalOutputTags().getAll()), + TupleTagList.of(getMainOutputTag()).and(getAdditionalOutputTags().getAll()), input.getWindowingStrategy(), input.isBounded()); http://git-wip-us.apache.org/repos/asf/beam/blob/81a72192/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java index 3619d05..bdec9c8 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java @@ -98,7 +98,7 @@ final class StatefulParDoEvaluatorFactory<K, InputT, OutputT> implements Transfo throws Exception { final DoFn<KV<K, InputT>, OutputT> doFn = - application.getTransform().getUnderlyingParDo().getFn(); + application.getTransform().getDoFn(); final DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass()); // If the DoFn is stateful, schedule state clearing. @@ -120,9 +120,9 @@ final class StatefulParDoEvaluatorFactory<K, InputT, OutputT> implements Transfo (PCollection) inputBundle.getPCollection(), inputBundle.getKey(), doFn, - application.getTransform().getUnderlyingParDo().getSideInputs(), - application.getTransform().getUnderlyingParDo().getMainOutputTag(), - application.getTransform().getUnderlyingParDo().getAdditionalOutputTags().getAll()); + application.getTransform().getSideInputs(), + application.getTransform().getMainOutputTag(), + application.getTransform().getAdditionalOutputTags().getAll()); return new StatefulParDoEvaluator<>(delegateEvaluator); } @@ -152,12 +152,11 @@ final class StatefulParDoEvaluatorFactory<K, InputT, OutputT> implements Transfo transformOutputWindow .getTransform() .getTransform() - .getUnderlyingParDo() .getMainOutputTag()); WindowingStrategy<?, ?> windowingStrategy = pc.getWindowingStrategy(); BoundedWindow window = transformOutputWindow.getWindow(); final DoFn<?, ?> doFn = - transformOutputWindow.getTransform().getTransform().getUnderlyingParDo().getFn(); + transformOutputWindow.getTransform().getTransform().getDoFn(); final DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass()); final DirectStepContext stepContext = http://git-wip-us.apache.org/repos/asf/beam/blob/81a72192/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java index 9366b7c..fe0b743 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java @@ -41,6 +41,7 @@ import org.apache.beam.runners.core.StateNamespace; import org.apache.beam.runners.core.StateNamespaces; import org.apache.beam.runners.core.StateTag; import org.apache.beam.runners.core.StateTags; +import org.apache.beam.runners.core.construction.TransformInputs; import org.apache.beam.runners.direct.ParDoMultiOverrideFactory.StatefulParDo; import org.apache.beam.runners.direct.WatermarkManager.TimerUpdate; import org.apache.beam.sdk.coders.StringUtf8Coder; @@ -52,7 +53,6 @@ import org.apache.beam.sdk.state.ValueState; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.View; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.FixedWindows; @@ -128,16 +128,17 @@ public class StatefulParDoEvaluatorFactoryTest implements Serializable { input .apply( new ParDoMultiOverrideFactory.GbkThenStatefulParDo<>( - ParDo.of( - new DoFn<KV<String, Integer>, Integer>() { - @StateId(stateId) - private final StateSpec<ValueState<String>> spec = - StateSpecs.value(StringUtf8Coder.of()); - - @ProcessElement - public void process(ProcessContext c) {} - }) - .withOutputTags(mainOutput, TupleTagList.empty()))) + new DoFn<KV<String, Integer>, Integer>() { + @StateId(stateId) + private final StateSpec<ValueState<String>> spec = + StateSpecs.value(StringUtf8Coder.of()); + + @ProcessElement + public void process(ProcessContext c) {} + }, + mainOutput, + TupleTagList.empty(), + Collections.<PCollectionView<?>>emptyList())) .get(mainOutput) .setCoder(VarIntCoder.of()); @@ -153,8 +154,7 @@ public class StatefulParDoEvaluatorFactoryTest implements Serializable { when(mockEvaluationContext.getExecutionContext( eq(producingTransform), Mockito.<StructuralKey>any())) .thenReturn(mockExecutionContext); - when(mockExecutionContext.getStepContext(anyString())) - .thenReturn(mockStepContext); + when(mockExecutionContext.getStepContext(anyString())).thenReturn(mockStepContext); IntervalWindow firstWindow = new IntervalWindow(new Instant(0), new Instant(9)); IntervalWindow secondWindow = new IntervalWindow(new Instant(10), new Instant(19)); @@ -241,18 +241,17 @@ public class StatefulParDoEvaluatorFactoryTest implements Serializable { mainInput .apply( new ParDoMultiOverrideFactory.GbkThenStatefulParDo<>( - ParDo - .of( - new DoFn<KV<String, Integer>, Integer>() { - @StateId(stateId) - private final StateSpec<ValueState<String>> spec = - StateSpecs.value(StringUtf8Coder.of()); - - @ProcessElement - public void process(ProcessContext c) {} - }) - .withSideInputs(sideInput) - .withOutputTags(mainOutput, TupleTagList.empty()))) + new DoFn<KV<String, Integer>, Integer>() { + @StateId(stateId) + private final StateSpec<ValueState<String>> spec = + StateSpecs.value(StringUtf8Coder.of()); + + @ProcessElement + public void process(ProcessContext c) {} + }, + mainOutput, + TupleTagList.empty(), + Collections.<PCollectionView<?>>singletonList(sideInput))) .get(mainOutput) .setCoder(VarIntCoder.of()); @@ -269,8 +268,7 @@ public class StatefulParDoEvaluatorFactoryTest implements Serializable { when(mockEvaluationContext.getExecutionContext( eq(producingTransform), Mockito.<StructuralKey>any())) .thenReturn(mockExecutionContext); - when(mockExecutionContext.getStepContext(anyString())) - .thenReturn(mockStepContext); + when(mockExecutionContext.getStepContext(anyString())).thenReturn(mockStepContext); when(mockEvaluationContext.createBundle(Matchers.<PCollection<Integer>>any())) .thenReturn(mockUncommittedBundle); when(mockStepContext.getTimerUpdate()).thenReturn(TimerUpdate.empty()); @@ -287,11 +285,8 @@ public class StatefulParDoEvaluatorFactoryTest implements Serializable { // global window state merely by having the evaluator created. The cleanup logic does not // depend on the window. String key = "hello"; - WindowedValue<KV<String, Integer>> firstKv = WindowedValue.of( - KV.of(key, 1), - new Instant(3), - firstWindow, - PaneInfo.NO_FIRING); + WindowedValue<KV<String, Integer>> firstKv = + WindowedValue.of(KV.of(key, 1), new Instant(3), firstWindow, PaneInfo.NO_FIRING); WindowedValue<KeyedWorkItem<String, KV<String, Integer>>> gbkOutputElement = firstKv.withValue( @@ -306,7 +301,8 @@ public class StatefulParDoEvaluatorFactoryTest implements Serializable { BUNDLE_FACTORY .createBundle( (PCollection<KeyedWorkItem<String, KV<String, Integer>>>) - Iterables.getOnlyElement(producingTransform.getInputs().values())) + Iterables.getOnlyElement( + TransformInputs.nonAdditionalInputs(producingTransform))) .add(gbkOutputElement) .commit(Instant.now()); TransformEvaluator<KeyedWorkItem<String, KV<String, Integer>>> evaluator = @@ -316,8 +312,7 @@ public class StatefulParDoEvaluatorFactoryTest implements Serializable { // This should push back every element as a KV<String, Iterable<Integer>> // in the appropriate window. Since the keys are equal they are single-threaded - TransformResult<KeyedWorkItem<String, KV<String, Integer>>> result = - evaluator.finishBundle(); + TransformResult<KeyedWorkItem<String, KV<String, Integer>>> result = evaluator.finishBundle(); List<Integer> pushedBackInts = new ArrayList<>();
