Repository: beam
Updated Branches:
  refs/heads/master 346a77fa8 -> c528fb2f7


Fix getAdditionalInputs, etc, for DirectRunner stateful ParDo override


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/81a72192
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/81a72192
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/81a72192

Branch: refs/heads/master
Commit: 81a72192dc4e792966de31c8eadda6a6c839a62c
Parents: 346a77f
Author: Kenneth Knowles <[email protected]>
Authored: Mon Jun 12 16:31:32 2017 -0700
Committer: Kenneth Knowles <[email protected]>
Committed: Thu Jun 15 16:47:53 2017 -0700

----------------------------------------------------------------------
 .../direct/ParDoMultiOverrideFactory.java       | 90 +++++++++++++++-----
 .../direct/StatefulParDoEvaluatorFactory.java   | 11 ++-
 .../StatefulParDoEvaluatorFactoryTest.java      | 65 +++++++-------
 3 files changed, 102 insertions(+), 64 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/81a72192/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
index 858ea34..b20113e 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
@@ -19,6 +19,8 @@ package org.apache.beam.runners.direct;
 
 import static com.google.common.base.Preconditions.checkState;
 
+import com.google.common.collect.ImmutableMap;
+import java.util.List;
 import java.util.Map;
 import org.apache.beam.runners.core.KeyedWorkItem;
 import org.apache.beam.runners.core.KeyedWorkItemCoder;
@@ -27,7 +29,6 @@ import org.apache.beam.runners.core.construction.PTransformReplacements;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
 import org.apache.beam.runners.core.construction.ReplacementOutputs;
 import org.apache.beam.runners.core.construction.SplittableParDo;
-import org.apache.beam.sdk.coders.CannotProvideCoderException;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.runners.AppliedPTransform;
@@ -48,6 +49,7 @@ import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionTuple;
+import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.PValue;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.sdk.values.TupleTagList;
@@ -82,12 +84,14 @@ class ParDoMultiOverrideFactory<InputT, OutputT>
       return new SplittableParDo(transform);
     } else if (signature.stateDeclarations().size() > 0
         || signature.timerDeclarations().size() > 0) {
+
       // Based on the fact that the signature is stateful, DoFnSignatures ensures
       // that it is also keyed
-      MultiOutput<KV<?, ?>, OutputT> keyedTransform =
-          (MultiOutput<KV<?, ?>, OutputT>) transform;
-
-      return new GbkThenStatefulParDo(keyedTransform);
+      return new GbkThenStatefulParDo(
+          fn,
+          transform.getMainOutputTag(),
+          transform.getAdditionalOutputTags(),
+          transform.getSideInputs());
     } else {
       return transform;
     }
@@ -101,10 +105,29 @@ class ParDoMultiOverrideFactory<InputT, OutputT>
 
   static class GbkThenStatefulParDo<K, InputT, OutputT>
       extends PTransform<PCollection<KV<K, InputT>>, PCollectionTuple> {
-    private final MultiOutput<KV<K, InputT>, OutputT> underlyingParDo;
+    private final transient DoFn<KV<K, InputT>, OutputT> doFn;
+    private final TupleTagList additionalOutputTags;
+    private final TupleTag<OutputT> mainOutputTag;
+    private final List<PCollectionView<?>> sideInputs;
+
+    public GbkThenStatefulParDo(
+        DoFn<KV<K, InputT>, OutputT> doFn,
+        TupleTag<OutputT> mainOutputTag,
+        TupleTagList additionalOutputTags,
+        List<PCollectionView<?>> sideInputs) {
+      this.doFn = doFn;
+      this.additionalOutputTags = additionalOutputTags;
+      this.mainOutputTag = mainOutputTag;
+      this.sideInputs = sideInputs;
+    }
 
-    public GbkThenStatefulParDo(MultiOutput<KV<K, InputT>, OutputT> underlyingParDo) {
-      this.underlyingParDo = underlyingParDo;
+    @Override
+    public Map<TupleTag<?>, PValue> getAdditionalInputs() {
+      ImmutableMap.Builder<TupleTag<?>, PValue> additionalInputs = ImmutableMap.builder();
+      for (PCollectionView<?> sideInput : sideInputs) {
+        additionalInputs.put(sideInput.getTagInternal(), sideInput.getPCollection());
+      }
+      return additionalInputs.build();
     }
 
     @Override
@@ -160,7 +183,9 @@ class ParDoMultiOverrideFactory<InputT, OutputT>
           adjustedInput
               // Explode the resulting iterable into elements that are exactly the ones from
               // the input
-              .apply("Stateful ParDo", new StatefulParDo<>(underlyingParDo, input));
+              .apply(
+              "Stateful ParDo",
+              new StatefulParDo<>(doFn, mainOutputTag, additionalOutputTags, sideInputs));
 
       return outputs;
     }
@@ -172,25 +197,45 @@ class ParDoMultiOverrideFactory<InputT, OutputT>
   static class StatefulParDo<K, InputT, OutputT>
       extends PTransformTranslation.RawPTransform<
           PCollection<? extends KeyedWorkItem<K, KV<K, InputT>>>, PCollectionTuple> {
-    private final transient MultiOutput<KV<K, InputT>, OutputT> underlyingParDo;
-    private final transient PCollection<KV<K, InputT>> originalInput;
+    private final transient DoFn<KV<K, InputT>, OutputT> doFn;
+    private final TupleTagList additionalOutputTags;
+    private final TupleTag<OutputT> mainOutputTag;
+    private final List<PCollectionView<?>> sideInputs;
 
     public StatefulParDo(
-        MultiOutput<KV<K, InputT>, OutputT> underlyingParDo,
-        PCollection<KV<K, InputT>> originalInput) {
-      this.underlyingParDo = underlyingParDo;
-      this.originalInput = originalInput;
+        DoFn<KV<K, InputT>, OutputT> doFn,
+        TupleTag<OutputT> mainOutputTag,
+        TupleTagList additionalOutputTags,
+        List<PCollectionView<?>> sideInputs) {
+      this.doFn = doFn;
+      this.mainOutputTag = mainOutputTag;
+      this.additionalOutputTags = additionalOutputTags;
+      this.sideInputs = sideInputs;
+    }
+
+    public DoFn<KV<K, InputT>, OutputT> getDoFn() {
+      return doFn;
+    }
+
+    public TupleTag<OutputT> getMainOutputTag() {
+      return mainOutputTag;
+    }
+
+    public List<PCollectionView<?>> getSideInputs() {
+      return sideInputs;
     }
 
-    public MultiOutput<KV<K, InputT>, OutputT> getUnderlyingParDo() {
-      return underlyingParDo;
+    public TupleTagList getAdditionalOutputTags() {
+      return additionalOutputTags;
     }
 
     @Override
-    public <T> Coder<T> getDefaultOutputCoder(
-        PCollection<? extends KeyedWorkItem<K, KV<K, InputT>>> input, PCollection<T> output)
-        throws CannotProvideCoderException {
-      return underlyingParDo.getDefaultOutputCoder(originalInput, output);
+    public Map<TupleTag<?>, PValue> getAdditionalInputs() {
+      ImmutableMap.Builder<TupleTag<?>, PValue> additionalInputs = ImmutableMap.builder();
+      for (PCollectionView<?> sideInput : sideInputs) {
+        additionalInputs.put(sideInput.getTagInternal(), sideInput.getPCollection());
+      }
+      return additionalInputs.build();
     }
 
     @Override
@@ -199,8 +244,7 @@ class ParDoMultiOverrideFactory<InputT, OutputT>
       PCollectionTuple outputs =
           PCollectionTuple.ofPrimitiveOutputsInternal(
               input.getPipeline(),
-              TupleTagList.of(underlyingParDo.getMainOutputTag())
-                  .and(underlyingParDo.getAdditionalOutputTags().getAll()),
+              TupleTagList.of(getMainOutputTag()).and(getAdditionalOutputTags().getAll()),
               input.getWindowingStrategy(),
               input.isBounded());
 

http://git-wip-us.apache.org/repos/asf/beam/blob/81a72192/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java
index 3619d05..bdec9c8 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java
@@ -98,7 +98,7 @@ final class StatefulParDoEvaluatorFactory<K, InputT, OutputT> implements Transfo
       throws Exception {
 
     final DoFn<KV<K, InputT>, OutputT> doFn =
-        application.getTransform().getUnderlyingParDo().getFn();
+        application.getTransform().getDoFn();
     final DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass());
 
     // If the DoFn is stateful, schedule state clearing.
@@ -120,9 +120,9 @@ final class StatefulParDoEvaluatorFactory<K, InputT, OutputT> implements Transfo
             (PCollection) inputBundle.getPCollection(),
             inputBundle.getKey(),
             doFn,
-            application.getTransform().getUnderlyingParDo().getSideInputs(),
-            application.getTransform().getUnderlyingParDo().getMainOutputTag(),
-            application.getTransform().getUnderlyingParDo().getAdditionalOutputTags().getAll());
+            application.getTransform().getSideInputs(),
+            application.getTransform().getMainOutputTag(),
+            application.getTransform().getAdditionalOutputTags().getAll());
 
     return new StatefulParDoEvaluator<>(delegateEvaluator);
   }
@@ -152,12 +152,11 @@ final class StatefulParDoEvaluatorFactory<K, InputT, OutputT> implements Transfo
                   transformOutputWindow
                       .getTransform()
                       .getTransform()
-                      .getUnderlyingParDo()
                       .getMainOutputTag());
       WindowingStrategy<?, ?> windowingStrategy = pc.getWindowingStrategy();
       BoundedWindow window = transformOutputWindow.getWindow();
       final DoFn<?, ?> doFn =
-          transformOutputWindow.getTransform().getTransform().getUnderlyingParDo().getFn();
+          transformOutputWindow.getTransform().getTransform().getDoFn();
       final DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass());
 
       final DirectStepContext stepContext =

http://git-wip-us.apache.org/repos/asf/beam/blob/81a72192/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java
index 9366b7c..fe0b743 100644
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java
@@ -41,6 +41,7 @@ import org.apache.beam.runners.core.StateNamespace;
 import org.apache.beam.runners.core.StateNamespaces;
 import org.apache.beam.runners.core.StateTag;
 import org.apache.beam.runners.core.StateTags;
+import org.apache.beam.runners.core.construction.TransformInputs;
 import org.apache.beam.runners.direct.ParDoMultiOverrideFactory.StatefulParDo;
 import org.apache.beam.runners.direct.WatermarkManager.TimerUpdate;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
@@ -52,7 +53,6 @@ import org.apache.beam.sdk.state.ValueState;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
-import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.View;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.FixedWindows;
@@ -128,16 +128,17 @@ public class StatefulParDoEvaluatorFactoryTest implements Serializable {
         input
             .apply(
                 new ParDoMultiOverrideFactory.GbkThenStatefulParDo<>(
-                    ParDo.of(
-                            new DoFn<KV<String, Integer>, Integer>() {
-                              @StateId(stateId)
-                              private final StateSpec<ValueState<String>> spec =
-                                  StateSpecs.value(StringUtf8Coder.of());
-
-                              @ProcessElement
-                              public void process(ProcessContext c) {}
-                            })
-                        .withOutputTags(mainOutput, TupleTagList.empty())))
+                    new DoFn<KV<String, Integer>, Integer>() {
+                      @StateId(stateId)
+                      private final StateSpec<ValueState<String>> spec =
+                          StateSpecs.value(StringUtf8Coder.of());
+
+                      @ProcessElement
+                      public void process(ProcessContext c) {}
+                    },
+                    mainOutput,
+                    TupleTagList.empty(),
+                    Collections.<PCollectionView<?>>emptyList()))
             .get(mainOutput)
             .setCoder(VarIntCoder.of());
 
@@ -153,8 +154,7 @@ public class StatefulParDoEvaluatorFactoryTest implements Serializable {
     when(mockEvaluationContext.getExecutionContext(
             eq(producingTransform), Mockito.<StructuralKey>any()))
         .thenReturn(mockExecutionContext);
-    when(mockExecutionContext.getStepContext(anyString()))
-        .thenReturn(mockStepContext);
+    when(mockExecutionContext.getStepContext(anyString())).thenReturn(mockStepContext);
 
     IntervalWindow firstWindow = new IntervalWindow(new Instant(0), new Instant(9));
     IntervalWindow secondWindow = new IntervalWindow(new Instant(10), new Instant(19));
@@ -241,18 +241,17 @@ public class StatefulParDoEvaluatorFactoryTest implements Serializable {
         mainInput
             .apply(
                 new ParDoMultiOverrideFactory.GbkThenStatefulParDo<>(
-                    ParDo
-                        .of(
-                            new DoFn<KV<String, Integer>, Integer>() {
-                              @StateId(stateId)
-                              private final StateSpec<ValueState<String>> spec =
-                                  StateSpecs.value(StringUtf8Coder.of());
-
-                              @ProcessElement
-                              public void process(ProcessContext c) {}
-                            })
-                        .withSideInputs(sideInput)
-                        .withOutputTags(mainOutput, TupleTagList.empty())))
+                    new DoFn<KV<String, Integer>, Integer>() {
+                      @StateId(stateId)
+                      private final StateSpec<ValueState<String>> spec =
+                          StateSpecs.value(StringUtf8Coder.of());
+
+                      @ProcessElement
+                      public void process(ProcessContext c) {}
+                    },
+                    mainOutput,
+                    TupleTagList.empty(),
+                    Collections.<PCollectionView<?>>singletonList(sideInput)))
             .get(mainOutput)
             .setCoder(VarIntCoder.of());
 
@@ -269,8 +268,7 @@ public class StatefulParDoEvaluatorFactoryTest implements Serializable {
     when(mockEvaluationContext.getExecutionContext(
             eq(producingTransform), Mockito.<StructuralKey>any()))
         .thenReturn(mockExecutionContext);
-    when(mockExecutionContext.getStepContext(anyString()))
-        .thenReturn(mockStepContext);
+    when(mockExecutionContext.getStepContext(anyString())).thenReturn(mockStepContext);
     when(mockEvaluationContext.createBundle(Matchers.<PCollection<Integer>>any()))
         .thenReturn(mockUncommittedBundle);
     when(mockStepContext.getTimerUpdate()).thenReturn(TimerUpdate.empty());
@@ -287,11 +285,8 @@ public class StatefulParDoEvaluatorFactoryTest implements Serializable {
     // global window state merely by having the evaluator created. The cleanup logic does not
     // depend on the window.
     String key = "hello";
-    WindowedValue<KV<String, Integer>> firstKv = WindowedValue.of(
-        KV.of(key, 1),
-        new Instant(3),
-        firstWindow,
-        PaneInfo.NO_FIRING);
+    WindowedValue<KV<String, Integer>> firstKv =
+        WindowedValue.of(KV.of(key, 1), new Instant(3), firstWindow, PaneInfo.NO_FIRING);
 
     WindowedValue<KeyedWorkItem<String, KV<String, Integer>>> gbkOutputElement =
         firstKv.withValue(
@@ -306,7 +301,8 @@ public class StatefulParDoEvaluatorFactoryTest implements Serializable {
         BUNDLE_FACTORY
             .createBundle(
                 (PCollection<KeyedWorkItem<String, KV<String, Integer>>>)
-                    Iterables.getOnlyElement(producingTransform.getInputs().values()))
+                    Iterables.getOnlyElement(
+                        TransformInputs.nonAdditionalInputs(producingTransform)))
             .add(gbkOutputElement)
             .commit(Instant.now());
     TransformEvaluator<KeyedWorkItem<String, KV<String, Integer>>> evaluator =
@@ -316,8 +312,7 @@ public class StatefulParDoEvaluatorFactoryTest implements Serializable {
 
     // This should push back every element as a KV<String, Iterable<Integer>>
     // in the appropriate window. Since the keys are equal they are single-threaded
-    TransformResult<KeyedWorkItem<String, KV<String, Integer>>> result =
-        evaluator.finishBundle();
+    TransformResult<KeyedWorkItem<String, KV<String, Integer>>> result = evaluator.finishBundle();
 
     List<Integer> pushedBackInts = new ArrayList<>();
 

Reply via email to