Repository: incubator-beam
Updated Branches:
  refs/heads/master 79c26d9c1 -> cb0356932


Fix DoFnTester side inputs

The side inputs were stored as iterables of WindowedValues, but
sideInput() returned that iterable, cast to the view's type, instead
of the side input's value.

Store the side input values directly instead, keyed by window.


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/1c1af625
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/1c1af625
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/1c1af625

Branch: refs/heads/master
Commit: 1c1af62586db36212ebf76eb8307d1993666afa5
Parents: f0119b2
Author: Thomas Groh <[email protected]>
Authored: Thu Jul 14 10:33:22 2016 -0700
Committer: Kenneth Knowles <[email protected]>
Committed: Thu Jul 14 17:13:10 2016 -0700

----------------------------------------------------------------------
 .../apache/beam/sdk/transforms/DoFnTester.java  | 70 ++++++++----------
 .../beam/sdk/transforms/DoFnTesterTest.java     | 74 ++++++++++++++++----
 2 files changed, 91 insertions(+), 53 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1c1af625/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java
index 8cfb550..a638feb 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java
@@ -103,50 +103,35 @@ public class DoFnTester<InputT, OutputT> {
    * Registers the tuple of values of the side input {@link PCollectionView}s to
    * pass to the {@link DoFn} under test.
    *
-   * <p>If needed, first creates a fresh instance of the {@link DoFn}
-   * under test.
+   * <p>Resets the state of this {@link DoFnTester}.
    *
    * <p>If this isn't called, {@code DoFnTester} assumes the
    * {@link DoFn} takes no side inputs.
    */
-  public void setSideInputs(Map<PCollectionView<?>, Iterable<WindowedValue<?>>> sideInputs) {
+  public void setSideInputs(Map<PCollectionView<?>, Map<BoundedWindow, ?>> sideInputs) {
     this.sideInputs = sideInputs;
     resetState();
   }
 
   /**
-   * Registers the values of a side input {@link PCollectionView} to
-   * pass to the {@link DoFn} under test.
+   * Registers the values of a side input {@link PCollectionView} to pass to the {@link DoFn} under
+   * test.
    *
-   * <p>If needed, first creates a fresh instance of the {@code DoFn}
-   * under test.
+   * <p>The provided value is the final value of the side input in the specified window, not
+   * the value of the input PCollection in that window.
    *
-   * <p>If this isn't called, {@code DoFnTester} assumes the
-   * {@code DoFn} takes no side inputs.
+   * <p>If this isn't called, {@code DoFnTester} will return the default value for any side input
+   * that is used.
    */
-  public void setSideInput(PCollectionView<?> sideInput, Iterable<WindowedValue<?>> value) {
-    sideInputs.put(sideInput, value);
-  }
-
-  /**
-   * Registers the values for a side input {@link PCollectionView} to
-   * pass to the {@link DoFn} under test. All values are placed
-   * in the global window.
-   */
-  public void setSideInputInGlobalWindow(
-      PCollectionView<?> sideInput,
-      Iterable<?> value) {
-    sideInputs.put(
-        sideInput,
-        Iterables.transform(value, new Function<Object, WindowedValue<?>>() {
-          @Override
-          public WindowedValue<?> apply(Object input) {
-            return WindowedValue.valueInGlobalWindow(input);
-          }
-        }));
+  public <T> void setSideInput(PCollectionView<T> sideInput, BoundedWindow window, T value) {
+    Map<BoundedWindow, T> windowValues = (Map<BoundedWindow, T>) sideInputs.get(sideInput);
+    if (windowValues == null) {
+      windowValues = new HashMap<>();
+      sideInputs.put(sideInput, windowValues);
+    }
+    windowValues.put(window, value);
   }
 
-
   /**
    * Registers the list of {@code TupleTag}s that can be used by the
    * {@code DoFn} under test to output to side output
@@ -523,14 +508,14 @@ public class DoFnTester<InputT, OutputT> {
     private final TestContext<InT, OutT> context;
     private final TupleTag<OutT> mainOutputTag;
     private final WindowedValue<InT> element;
-    private final Map<PCollectionView<?>, ?> sideInputs;
+    private final Map<PCollectionView<?>, Map<BoundedWindow, ?>> sideInputs;
 
     private TestProcessContext(
         DoFn<InT, OutT> fn,
         TestContext<InT, OutT> context,
         WindowedValue<InT> element,
         TupleTag<OutT> mainOutputTag,
-        Map<PCollectionView<?>, ?> sideInputs) {
+        Map<PCollectionView<?>, Map<BoundedWindow, ?>> sideInputs) {
       fn.super();
       this.context = context;
       this.element = element;
@@ -545,9 +530,17 @@ public class DoFnTester<InputT, OutputT> {
 
     @Override
     public <T> T sideInput(PCollectionView<T> view) {
-      @SuppressWarnings("unchecked")
-      T sideInput = (T) sideInputs.get(view);
-      return sideInput;
+      Map<BoundedWindow, ?> viewValues = sideInputs.get(view);
+      if (viewValues != null) {
+        BoundedWindow sideInputWindow =
+            view.getWindowingStrategyInternal().getWindowFn().getSideInputWindow(window());
+        @SuppressWarnings("unchecked")
+        T windowValue = (T) viewValues.get(sideInputWindow);
+        if (windowValue != null) {
+          return windowValue;
+        }
+      }
+      return view.fromIterableInternal(Collections.<WindowedValue<?>>emptyList());
     }
 
     @Override
@@ -668,7 +661,7 @@ public class DoFnTester<InputT, OutputT> {
   final DoFn<InputT, OutputT> origFn;
 
   /** The side input values to provide to the DoFn under test. */
-  private Map<PCollectionView<?>, Iterable<WindowedValue<?>>> sideInputs =
+  private Map<PCollectionView<?>, Map<BoundedWindow, ?>> sideInputs =
       new HashMap<>();
 
   private Map<String, Object> accumulators;
@@ -703,11 +696,6 @@ public class DoFnTester<InputT, OutputT> {
         SerializableUtils.deserializeFromByteArray(
             SerializableUtils.serializeToByteArray(origFn),
             origFn.toString());
-    PTuple runnerSideInputs = PTuple.empty();
-    for (Map.Entry<PCollectionView<?>, Iterable<WindowedValue<?>>> entry
-        : sideInputs.entrySet()) {
-      runnerSideInputs = runnerSideInputs.and(entry.getKey().getTagInternal(), entry.getValue());
-    }
     outputs = new HashMap<>();
     accumulators = new HashMap<>();
   }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1c1af625/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTesterTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTesterTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTesterTest.java
index b391671..8460a7c 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTesterTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTesterTest.java
@@ -24,8 +24,13 @@ import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertThat;
 import static org.junit.Assert.assertTrue;
 
+import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
 import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
+import org.apache.beam.sdk.util.PCollectionViews;
+import org.apache.beam.sdk.util.WindowingStrategy;
+import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.TimestampedValue;
 
 import org.hamcrest.Matchers;
@@ -150,19 +155,15 @@ public class DoFnTesterTest {
     tester.processElement(2L);
 
     List<TimestampedValue<String>> peek = tester.peekOutputElementsWithTimestamp();
-    TimestampedValue<String> one =
-        TimestampedValue.of("1", new Instant(1000L));
-    TimestampedValue<String> two =
-        TimestampedValue.of("2", new Instant(2000L));
+    TimestampedValue<String> one = TimestampedValue.of("1", new Instant(1000L));
+    TimestampedValue<String> two = TimestampedValue.of("2", new Instant(2000L));
     assertThat(peek, hasItems(one, two));
 
     tester.processElement(3L);
     tester.processElement(4L);
 
-    TimestampedValue<String> three =
-        TimestampedValue.of("3", new Instant(3000L));
-    TimestampedValue<String> four =
-        TimestampedValue.of("4", new Instant(4000L));
+    TimestampedValue<String> three = TimestampedValue.of("3", new Instant(3000L));
+    TimestampedValue<String> four = TimestampedValue.of("4", new Instant(4000L));
     peek = tester.peekOutputElementsWithTimestamp();
     assertThat(peek, hasItems(one, two, three, four));
     List<TimestampedValue<String>> take = tester.takeOutputElementsWithTimestamp();
@@ -219,14 +220,63 @@ public class DoFnTesterTest {
     tester.processElement(2L);
     tester.finishBundle();
 
-    assertThat(tester.peekOutputElementsInWindow(GlobalWindow.INSTANCE),
-        containsInAnyOrder(TimestampedValue.of("1", new Instant(1000L)),
+    assertThat(
+        tester.peekOutputElementsInWindow(GlobalWindow.INSTANCE),
+        containsInAnyOrder(
+            TimestampedValue.of("1", new Instant(1000L)),
             TimestampedValue.of("2", new Instant(2000L))));
-    assertThat(tester.peekOutputElementsInWindow(
-        new IntervalWindow(new Instant(0L), new Instant(10L))),
+    assertThat(
+        tester.peekOutputElementsInWindow(new IntervalWindow(new Instant(0L), new Instant(10L))),
         Matchers.<TimestampedValue<String>>emptyIterable());
   }
 
+  @Test
+  public void fnWithSideInputDefault() throws Exception {
+    final PCollectionView<Integer> value =
+        PCollectionViews.singletonView(
+            TestPipeline.create(), WindowingStrategy.globalDefault(), true, 0, VarIntCoder.of());
+    DoFn<Integer, Integer> fn = new SideInputDoFn(value);
+
+    DoFnTester<Integer, Integer> tester = DoFnTester.of(fn);
+
+    tester.processElement(1);
+    tester.processElement(2);
+    tester.processElement(4);
+    tester.processElement(8);
+    assertThat(tester.peekOutputElements(), containsInAnyOrder(0, 0, 0, 0));
+  }
+
+  @Test
+  public void fnWithSideInputExplicit() throws Exception {
+    final PCollectionView<Integer> value =
+        PCollectionViews.singletonView(
+            TestPipeline.create(), WindowingStrategy.globalDefault(), true, 0, VarIntCoder.of());
+    DoFn<Integer, Integer> fn = new SideInputDoFn(value);
+
+    DoFnTester<Integer, Integer> tester = DoFnTester.of(fn);
+    tester.setSideInput(value, GlobalWindow.INSTANCE, -2);
+    tester.processElement(16);
+    tester.processElement(32);
+    tester.processElement(64);
+    tester.processElement(128);
+    tester.finishBundle();
+
+    assertThat(tester.peekOutputElements(), containsInAnyOrder(-2, -2, -2, -2));
+  }
+
+  private static class SideInputDoFn extends DoFn<Integer, Integer> {
+    private final PCollectionView<Integer> value;
+
+    private SideInputDoFn(PCollectionView<Integer> value) {
+      this.value = value;
+    }
+
+    @Override
+    public void processElement(ProcessContext c) throws Exception {
+      c.output(c.sideInput(value));
+    }
+  }
+
   /**
    * A DoFn that adds values to an aggregator and converts input to String in processElement.
    */

Reply via email to