Repository: incubator-beam Updated Branches: refs/heads/master 79c26d9c1 -> cb0356932
Fix DoFnTester side inputs. The side inputs were being stored as iterables, but were being returned as the raw type. Store the side input values directly instead. Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/1c1af625 Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/1c1af625 Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/1c1af625 Branch: refs/heads/master Commit: 1c1af62586db36212ebf76eb8307d1993666afa5 Parents: f0119b2 Author: Thomas Groh <[email protected]> Authored: Thu Jul 14 10:33:22 2016 -0700 Committer: Kenneth Knowles <[email protected]> Committed: Thu Jul 14 17:13:10 2016 -0700 ---------------------------------------------------------------------- .../apache/beam/sdk/transforms/DoFnTester.java | 70 ++++++++---------- .../beam/sdk/transforms/DoFnTesterTest.java | 74 ++++++++++++++++---- 2 files changed, 91 insertions(+), 53 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1c1af625/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java index 8cfb550..a638feb 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java @@ -103,50 +103,35 @@ public class DoFnTester<InputT, OutputT> { * Registers the tuple of values of the side input {@link PCollectionView}s to * pass to the {@link DoFn} under test. * - * <p>If needed, first creates a fresh instance of the {@link DoFn} - * under test. + * <p>Resets the state of this {@link DoFnTester}. 
* * <p>If this isn't called, {@code DoFnTester} assumes the * {@link DoFn} takes no side inputs. */ - public void setSideInputs(Map<PCollectionView<?>, Iterable<WindowedValue<?>>> sideInputs) { + public void setSideInputs(Map<PCollectionView<?>, Map<BoundedWindow, ?>> sideInputs) { this.sideInputs = sideInputs; resetState(); } /** - * Registers the values of a side input {@link PCollectionView} to - * pass to the {@link DoFn} under test. + * Registers the values of a side input {@link PCollectionView} to pass to the {@link DoFn} under + * test. * - * <p>If needed, first creates a fresh instance of the {@code DoFn} - * under test. + * <p>The provided value is the final value of the side input in the specified window, not + * the value of the input PCollection in that window. * - * <p>If this isn't called, {@code DoFnTester} assumes the - * {@code DoFn} takes no side inputs. + * <p>If this isn't called, {@code DoFnTester} will return the default value for any side input + * that is used. */ - public void setSideInput(PCollectionView<?> sideInput, Iterable<WindowedValue<?>> value) { - sideInputs.put(sideInput, value); - } - - /** - * Registers the values for a side input {@link PCollectionView} to - * pass to the {@link DoFn} under test. All values are placed - * in the global window. 
- */ - public void setSideInputInGlobalWindow( - PCollectionView<?> sideInput, - Iterable<?> value) { - sideInputs.put( - sideInput, - Iterables.transform(value, new Function<Object, WindowedValue<?>>() { - @Override - public WindowedValue<?> apply(Object input) { - return WindowedValue.valueInGlobalWindow(input); - } - })); + public <T> void setSideInput(PCollectionView<T> sideInput, BoundedWindow window, T value) { + Map<BoundedWindow, T> windowValues = (Map<BoundedWindow, T>) sideInputs.get(sideInput); + if (windowValues == null) { + windowValues = new HashMap<>(); + sideInputs.put(sideInput, windowValues); + } + windowValues.put(window, value); } - /** * Registers the list of {@code TupleTag}s that can be used by the * {@code DoFn} under test to output to side output @@ -523,14 +508,14 @@ public class DoFnTester<InputT, OutputT> { private final TestContext<InT, OutT> context; private final TupleTag<OutT> mainOutputTag; private final WindowedValue<InT> element; - private final Map<PCollectionView<?>, ?> sideInputs; + private final Map<PCollectionView<?>, Map<BoundedWindow, ?>> sideInputs; private TestProcessContext( DoFn<InT, OutT> fn, TestContext<InT, OutT> context, WindowedValue<InT> element, TupleTag<OutT> mainOutputTag, - Map<PCollectionView<?>, ?> sideInputs) { + Map<PCollectionView<?>, Map<BoundedWindow, ?>> sideInputs) { fn.super(); this.context = context; this.element = element; @@ -545,9 +530,17 @@ public class DoFnTester<InputT, OutputT> { @Override public <T> T sideInput(PCollectionView<T> view) { - @SuppressWarnings("unchecked") - T sideInput = (T) sideInputs.get(view); - return sideInput; + Map<BoundedWindow, ?> viewValues = sideInputs.get(view); + if (viewValues != null) { + BoundedWindow sideInputWindow = + view.getWindowingStrategyInternal().getWindowFn().getSideInputWindow(window()); + @SuppressWarnings("unchecked") + T windowValue = (T) viewValues.get(sideInputWindow); + if (windowValue != null) { + return windowValue; + } + } + return 
view.fromIterableInternal(Collections.<WindowedValue<?>>emptyList()); } @Override @@ -668,7 +661,7 @@ public class DoFnTester<InputT, OutputT> { final DoFn<InputT, OutputT> origFn; /** The side input values to provide to the DoFn under test. */ - private Map<PCollectionView<?>, Iterable<WindowedValue<?>>> sideInputs = + private Map<PCollectionView<?>, Map<BoundedWindow, ?>> sideInputs = new HashMap<>(); private Map<String, Object> accumulators; @@ -703,11 +696,6 @@ public class DoFnTester<InputT, OutputT> { SerializableUtils.deserializeFromByteArray( SerializableUtils.serializeToByteArray(origFn), origFn.toString()); - PTuple runnerSideInputs = PTuple.empty(); - for (Map.Entry<PCollectionView<?>, Iterable<WindowedValue<?>>> entry - : sideInputs.entrySet()) { - runnerSideInputs = runnerSideInputs.and(entry.getKey().getTagInternal(), entry.getValue()); - } outputs = new HashMap<>(); accumulators = new HashMap<>(); } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1c1af625/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTesterTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTesterTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTesterTest.java index b391671..8460a7c 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTesterTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTesterTest.java @@ -24,8 +24,13 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; +import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +import org.apache.beam.sdk.util.PCollectionViews; +import org.apache.beam.sdk.util.WindowingStrategy; 
+import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TimestampedValue; import org.hamcrest.Matchers; @@ -150,19 +155,15 @@ public class DoFnTesterTest { tester.processElement(2L); List<TimestampedValue<String>> peek = tester.peekOutputElementsWithTimestamp(); - TimestampedValue<String> one = - TimestampedValue.of("1", new Instant(1000L)); - TimestampedValue<String> two = - TimestampedValue.of("2", new Instant(2000L)); + TimestampedValue<String> one = TimestampedValue.of("1", new Instant(1000L)); + TimestampedValue<String> two = TimestampedValue.of("2", new Instant(2000L)); assertThat(peek, hasItems(one, two)); tester.processElement(3L); tester.processElement(4L); - TimestampedValue<String> three = - TimestampedValue.of("3", new Instant(3000L)); - TimestampedValue<String> four = - TimestampedValue.of("4", new Instant(4000L)); + TimestampedValue<String> three = TimestampedValue.of("3", new Instant(3000L)); + TimestampedValue<String> four = TimestampedValue.of("4", new Instant(4000L)); peek = tester.peekOutputElementsWithTimestamp(); assertThat(peek, hasItems(one, two, three, four)); List<TimestampedValue<String>> take = tester.takeOutputElementsWithTimestamp(); @@ -219,14 +220,63 @@ public class DoFnTesterTest { tester.processElement(2L); tester.finishBundle(); - assertThat(tester.peekOutputElementsInWindow(GlobalWindow.INSTANCE), - containsInAnyOrder(TimestampedValue.of("1", new Instant(1000L)), + assertThat( + tester.peekOutputElementsInWindow(GlobalWindow.INSTANCE), + containsInAnyOrder( + TimestampedValue.of("1", new Instant(1000L)), TimestampedValue.of("2", new Instant(2000L)))); - assertThat(tester.peekOutputElementsInWindow( - new IntervalWindow(new Instant(0L), new Instant(10L))), + assertThat( + tester.peekOutputElementsInWindow(new IntervalWindow(new Instant(0L), new Instant(10L))), Matchers.<TimestampedValue<String>>emptyIterable()); } + @Test + public void fnWithSideInputDefault() throws Exception { + final 
PCollectionView<Integer> value = + PCollectionViews.singletonView( + TestPipeline.create(), WindowingStrategy.globalDefault(), true, 0, VarIntCoder.of()); + DoFn<Integer, Integer> fn = new SideInputDoFn(value); + + DoFnTester<Integer, Integer> tester = DoFnTester.of(fn); + + tester.processElement(1); + tester.processElement(2); + tester.processElement(4); + tester.processElement(8); + assertThat(tester.peekOutputElements(), containsInAnyOrder(0, 0, 0, 0)); + } + + @Test + public void fnWithSideInputExplicit() throws Exception { + final PCollectionView<Integer> value = + PCollectionViews.singletonView( + TestPipeline.create(), WindowingStrategy.globalDefault(), true, 0, VarIntCoder.of()); + DoFn<Integer, Integer> fn = new SideInputDoFn(value); + + DoFnTester<Integer, Integer> tester = DoFnTester.of(fn); + tester.setSideInput(value, GlobalWindow.INSTANCE, -2); + tester.processElement(16); + tester.processElement(32); + tester.processElement(64); + tester.processElement(128); + tester.finishBundle(); + + assertThat(tester.peekOutputElements(), containsInAnyOrder(-2, -2, -2, -2)); + } + + private static class SideInputDoFn extends DoFn<Integer, Integer> { + private final PCollectionView<Integer> value; + + private SideInputDoFn(PCollectionView<Integer> value) { + this.value = value; + } + + @Override + public void processElement(ProcessContext c) throws Exception { + c.output(c.sideInput(value)); + } + } + /** * A DoFn that adds values to an aggregator and converts input to String in processElement. */
