Support for @Setup and @Teardown in DoFnTester
Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/bef0e9de Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/bef0e9de Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/bef0e9de Branch: refs/heads/master Commit: bef0e9de02be051411f20b298168e8477ed1a0da Parents: 9009802 Author: Eugene Kirpichov <[email protected]> Authored: Mon Sep 26 16:58:20 2016 -0700 Committer: Dan Halperin <[email protected]> Committed: Tue Sep 27 14:57:49 2016 -0700 ---------------------------------------------------------------------- .../apache/beam/sdk/transforms/DoFnTester.java | 120 +++-- .../beam/sdk/transforms/DoFnTesterTest.java | 456 +++++++++++-------- 2 files changed, 338 insertions(+), 238 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/bef0e9de/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java index 0e018ba..9adb806 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java @@ -78,10 +78,10 @@ import org.joda.time.Instant; * @param <InputT> the type of the {@link DoFn}'s (main) input elements * @param <OutputT> the type of the {@link DoFn}'s (main) output elements */ -public class DoFnTester<InputT, OutputT> { +public class DoFnTester<InputT, OutputT> implements AutoCloseable { /** * Returns a {@code DoFnTester} supporting unit-testing of the given - * {@link DoFn}. + * {@link DoFn}. By default, uses {@link CloningBehavior#CLONE_ONCE}. */ @SuppressWarnings("unchecked") public static <InputT, OutputT> DoFnTester<InputT, OutputT> of(DoFn<InputT, OutputT> fn) { @@ -91,6 +91,8 @@ public class DoFnTester<InputT, OutputT> { /** * Returns a {@code DoFnTester} supporting unit-testing of the given * {@link OldDoFn}. + * + * @see #of(DoFn) */ @SuppressWarnings("unchecked") public static <InputT, OutputT> DoFnTester<InputT, OutputT> @@ -108,8 +110,11 @@ public class DoFnTester<InputT, OutputT> { * {@link DoFn} takes no side inputs. */ public void setSideInputs(Map<PCollectionView<?>, Map<BoundedWindow, ?>> sideInputs) { + checkState( + state == State.UNINITIALIZED, + "Can't add side inputs: DoFnTester is already initialized, in state %s", + state); this.sideInputs = sideInputs; - resetState(); } /** @@ -123,6 +128,10 @@ public class DoFnTester<InputT, OutputT> { * that is used. */ public <T> void setSideInput(PCollectionView<T> sideInput, BoundedWindow window, T value) { + checkState( + state == State.UNINITIALIZED, + "Can't add side inputs: DoFnTester is already initialized, in state %s", + state); Map<BoundedWindow, T> windowValues = (Map<BoundedWindow, T>) sideInputs.get(sideInput); if (windowValues == null) { windowValues = new HashMap<>(); @@ -132,10 +141,24 @@ public class DoFnTester<InputT, OutputT> { } /** - * Whether or not a {@link DoFnTester} should clone the {@link DoFn} under test. + * When a {@link DoFnTester} should clone the {@link DoFn} under test and how it should manage + * the lifecycle of the {@link DoFn}. */ public enum CloningBehavior { - CLONE, + /** + * Clone the {@link DoFn} and call {@link DoFn.Setup} every time a bundle starts; call {@link + * DoFn.Teardown} every time a bundle finishes. + */ + CLONE_PER_BUNDLE, + /** + * Clone the {@link DoFn} and call {@link DoFn.Setup} on the first access; call {@link + * DoFn.Teardown} only explicitly. + */ + CLONE_ONCE, + /** + * Do not clone the {@link DoFn}; call {@link DoFn.Setup} on the first access; call {@link + * DoFn.Teardown} only explicitly. + */ DO_NOT_CLONE } @@ -143,6 +166,7 @@ public class DoFnTester<InputT, OutputT> { * Instruct this {@link DoFnTester} whether or not to clone the {@link DoFn} under test. */ public void setCloningBehavior(CloningBehavior newValue) { + checkState(state == State.UNINITIALIZED, "Wrong state: %s", state); this.cloningBehavior = newValue; } @@ -187,11 +211,17 @@ public class DoFnTester<InputT, OutputT> { /** * Calls the {@link DoFn.StartBundle} method on the {@link DoFn} under test. * - * <p>If needed, first creates a fresh instance of the {@link DoFn} under test. + * <p>If needed, first creates a fresh instance of the {@link DoFn} under test and calls + * {@link DoFn.Setup}. */ public void startBundle() throws Exception { - resetState(); - initializeState(); + checkState( + state == State.UNINITIALIZED || state == State.BUNDLE_FINISHED, + "Wrong state during startBundle: %s", + state); + if (state == State.UNINITIALIZED) { + initializeState(); + } TestContext<InputT, OutputT> context = createContext(fn); context.setupDelegateAggregators(); try { @@ -199,7 +229,7 @@ public class DoFnTester<InputT, OutputT> { } catch (UserCodeException e) { unwrapUserCodeException(e); } - state = State.STARTED; + state = State.BUNDLE_STARTED; } private static void unwrapUserCodeException(UserCodeException e) throws Exception { @@ -236,15 +266,10 @@ public class DoFnTester<InputT, OutputT> { * already been called. * * <p>If the input timestamp is {@literal null}, the minimum timestamp will be used. - * - * @throws IllegalStateException if the {@code OldDoFn} under test has already - * been finished */ public void processTimestampedElement(TimestampedValue<InputT> element) throws Exception { checkNotNull(element, "Timestamped element cannot be null"); - checkState(state != State.FINISHED, "finishBundle() has already been called"); - - if (state == State.UNSTARTED) { + if (state != State.BUNDLE_STARTED) { startBundle(); } try { @@ -257,25 +282,30 @@ public class DoFnTester<InputT, OutputT> { /** * Calls the {@link DoFn.FinishBundle} method of the {@link DoFn} under test. * - * <p>Will call {@link #startBundle} automatically, if it hasn't - * already been called. + * <p>If {@link #setCloningBehavior} was called with {@link CloningBehavior#CLONE_PER_BUNDLE}, + * then also calls {@link DoFn.Teardown} on the {@link DoFn}, and it will be cloned and + * {@link DoFn.Setup} again when processing the next bundle. * - * @throws IllegalStateException if the {@link DoFn} under test has already - * been finished + * @throws IllegalStateException if {@link DoFn.FinishBundle} has already been called + * for this bundle. */ public void finishBundle() throws Exception { - if (state == State.FINISHED) { - throw new IllegalStateException("finishBundle() has already been called"); - } - if (state == State.UNSTARTED) { - startBundle(); - } + checkState( + state == State.BUNDLE_STARTED, + "Must be inside bundle to call finishBundle, but was: %s", + state); try { fn.finishBundle(createContext(fn)); } catch (UserCodeException e) { unwrapUserCodeException(e); } - state = State.FINISHED; + if (cloningBehavior == CloningBehavior.CLONE_PER_BUNDLE) { + fn.teardown(); + fn = null; + state = State.UNINITIALIZED; + } else { + state = State.BUNDLE_FINISHED; + } } /** @@ -695,13 +725,26 @@ public class DoFnTester<InputT, OutputT> { } } + @Override + public void close() throws Exception { + if (state == State.BUNDLE_STARTED) { + finishBundle(); + } + if (state == State.BUNDLE_FINISHED) { + fn.teardown(); + fn = null; + } + state = State.TORN_DOWN; + } + ///////////////////////////////////////////////////////////////////////////// /** The possible states of processing a {@link DoFn}. */ - enum State { - UNSTARTED, - STARTED, - FINISHED + private enum State { + UNINITIALIZED, + BUNDLE_STARTED, + BUNDLE_FINISHED, + TORN_DOWN } private final PipelineOptions options = PipelineOptionsFactory.create(); @@ -714,7 +757,7 @@ public class DoFnTester<InputT, OutputT> { * * <p>Worker-side {@link DoFn DoFns} may not be serializable, and are not required to be. */ - private CloningBehavior cloningBehavior = CloningBehavior.CLONE; + private CloningBehavior cloningBehavior = CloningBehavior.CLONE_ONCE; /** The side input values to provide to the {@link DoFn} under test. */ private Map<PCollectionView<?>, Map<BoundedWindow, ?>> sideInputs = @@ -732,22 +775,16 @@ public class DoFnTester<InputT, OutputT> { private Map<TupleTag<?>, List<WindowedValue<?>>> outputs; /** The state of processing of the {@link DoFn} under test. */ - private State state; + private State state = State.UNINITIALIZED; private DoFnTester(OldDoFn<InputT, OutputT> origFn) { this.origFn = origFn; - resetState(); - } - - private void resetState() { - fn = null; - outputs = null; - accumulators = null; - state = State.UNSTARTED; } @SuppressWarnings("unchecked") - private void initializeState() { + private void initializeState() throws Exception { + checkState(state == State.UNINITIALIZED, "Already initialized"); + checkState(fn == null, "Uninitialized but fn != null"); if (cloningBehavior.equals(CloningBehavior.DO_NOT_CLONE)) { fn = origFn; } else { @@ -756,6 +793,7 @@ public class DoFnTester<InputT, OutputT> { SerializableUtils.serializeToByteArray(origFn), origFn.toString()); } + fn.setup(); outputs = new HashMap<>(); accumulators = new HashMap<>(); } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/bef0e9de/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTesterTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTesterTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTesterTest.java index 3ed30fd..f208488 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTesterTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTesterTest.java @@ -17,15 +17,17 @@ */ package org.apache.beam.sdk.transforms; +import static com.google.common.base.Preconditions.checkState; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasItems; -import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; @@ -51,122 +53,180 @@ public class DoFnTesterTest { @Test public void processElement() throws Exception { - CounterDoFn counterDoFn = new CounterDoFn(); - DoFnTester<Long, String> tester = DoFnTester.of(counterDoFn); - - tester.processElement(1L); - - List<String> take = tester.takeOutputElements(); + for (DoFnTester.CloningBehavior cloning : DoFnTester.CloningBehavior.values()) { + try (DoFnTester<Long, String> tester = DoFnTester.of(new CounterDoFn())) { + tester.setCloningBehavior(cloning); + tester.processElement(1L); - assertThat(take, hasItems("1")); + List<String> take = tester.takeOutputElements(); - // Following takeOutputElements(), neither takeOutputElements() - // nor peekOutputElements() return anything. - assertTrue(tester.takeOutputElements().isEmpty()); - assertTrue(tester.peekOutputElements().isEmpty()); + assertThat(take, hasItems("1")); - // processElement() caused startBundle() to be called, but finishBundle() was never called. - CounterDoFn deserializedDoFn = (CounterDoFn) tester.fn; - assertTrue(deserializedDoFn.wasStartBundleCalled()); - assertFalse(deserializedDoFn.wasFinishBundleCalled()); + // Following takeOutputElements(), neither takeOutputElements() + // nor peekOutputElements() return anything. + assertTrue(tester.takeOutputElements().isEmpty()); + assertTrue(tester.peekOutputElements().isEmpty()); + } + } } @Test public void processElementsWithPeeks() throws Exception { - CounterDoFn counterDoFn = new CounterDoFn(); - DoFnTester<Long, String> tester = DoFnTester.of(counterDoFn); - - // Explicitly call startBundle(). - tester.startBundle(); - - // verify startBundle() was called but not finishBundle(). - CounterDoFn deserializedDoFn = (CounterDoFn) tester.fn; - assertTrue(deserializedDoFn.wasStartBundleCalled()); - assertFalse(deserializedDoFn.wasFinishBundleCalled()); - - // process a couple of elements. - tester.processElement(1L); - tester.processElement(2L); - - // peek the first 2 outputs. - List<String> peek = tester.peekOutputElements(); - assertThat(peek, hasItems("1", "2")); - - // process a couple more. - tester.processElement(3L); - tester.processElement(4L); - - // peek all the outputs so far. - peek = tester.peekOutputElements(); - assertThat(peek, hasItems("1", "2", "3", "4")); - // take the outputs. - List<String> take = tester.takeOutputElements(); - assertThat(take, hasItems("1", "2", "3", "4")); - - // Following takeOutputElements(), neither takeOutputElements() - // nor peekOutputElements() return anything. - assertTrue(tester.peekOutputElements().isEmpty()); - assertTrue(tester.takeOutputElements().isEmpty()); - - // verify finishBundle() hasn't been called yet. - assertTrue(deserializedDoFn.wasStartBundleCalled()); - assertFalse(deserializedDoFn.wasFinishBundleCalled()); - - // process a couple more. - tester.processElement(5L); - tester.processElement(6L); - - // peek and take now have only the 2 last outputs. - peek = tester.peekOutputElements(); - assertThat(peek, hasItems("5", "6")); - take = tester.takeOutputElements(); - assertThat(take, hasItems("5", "6")); - - tester.finishBundle(); - - // verify finishBundle() was called. - assertTrue(deserializedDoFn.wasStartBundleCalled()); - assertTrue(deserializedDoFn.wasFinishBundleCalled()); + for (DoFnTester.CloningBehavior cloning : DoFnTester.CloningBehavior.values()) { + try (DoFnTester<Long, String> tester = DoFnTester.of(new CounterDoFn())) { + tester.setCloningBehavior(cloning); + // Explicitly call startBundle(). + tester.startBundle(); + + // process a couple of elements. + tester.processElement(1L); + tester.processElement(2L); + + // peek the first 2 outputs. + List<String> peek = tester.peekOutputElements(); + assertThat(peek, hasItems("1", "2")); + + // process a couple more. + tester.processElement(3L); + tester.processElement(4L); + + // peek all the outputs so far. + peek = tester.peekOutputElements(); + assertThat(peek, hasItems("1", "2", "3", "4")); + // take the outputs. + List<String> take = tester.takeOutputElements(); + assertThat(take, hasItems("1", "2", "3", "4")); + + // Following takeOutputElements(), neither takeOutputElements() + // nor peekOutputElements() return anything. + assertTrue(tester.peekOutputElements().isEmpty()); + assertTrue(tester.takeOutputElements().isEmpty()); + + // process a couple more. + tester.processElement(5L); + tester.processElement(6L); + + // peek and take now have only the 2 last outputs. + peek = tester.peekOutputElements(); + assertThat(peek, hasItems("5", "6")); + take = tester.takeOutputElements(); + assertThat(take, hasItems("5", "6")); + + tester.finishBundle(); + } + } } @Test - public void processElementAfterFinish() throws Exception { - DoFnTester<Long, String> tester = DoFnTester.of(new CounterDoFn()); - tester.finishBundle(); + public void processBundle() throws Exception { + for (DoFnTester.CloningBehavior cloning : DoFnTester.CloningBehavior.values()) { + try (DoFnTester<Long, String> tester = DoFnTester.of(new CounterDoFn())) { + tester.setCloningBehavior(cloning); + // processBundle() returns all the output like takeOutputElements(). + assertThat(tester.processBundle(1L, 2L, 3L, 4L), hasItems("1", "2", "3", "4")); + + // peek now returns nothing. + assertTrue(tester.peekOutputElements().isEmpty()); + } + } + } - thrown.expect(IllegalStateException.class); - thrown.expectMessage("finishBundle() has already been called"); - tester.processElement(1L); + @Test + public void processMultipleBundles() throws Exception { + for (DoFnTester.CloningBehavior cloning : DoFnTester.CloningBehavior.values()) { + try (DoFnTester<Long, String> tester = DoFnTester.of(new CounterDoFn())) { + tester.setCloningBehavior(cloning); + // processBundle() returns all the output like takeOutputElements(). + assertThat(tester.processBundle(1L, 2L, 3L, 4L), hasItems("1", "2", "3", "4")); + assertThat(tester.processBundle(5L, 6L, 7L), hasItems("5", "6", "7")); + assertThat(tester.processBundle(8L, 9L), hasItems("8", "9")); + + // peek now returns nothing. + assertTrue(tester.peekOutputElements().isEmpty()); + } + } } @Test - public void processBatch() throws Exception { - CounterDoFn counterDoFn = new CounterDoFn(); - DoFnTester<Long, String> tester = DoFnTester.of(counterDoFn); + public void doNotClone() throws Exception { + final AtomicInteger numSetupCalls = new AtomicInteger(); + final AtomicInteger numTeardownCalls = new AtomicInteger(); + DoFn<Long, String> fn = + new DoFn<Long, String>() { + @ProcessElement + public void process(ProcessContext context) {} + + @Setup + public void setup() { + numSetupCalls.addAndGet(1); + } + + @Teardown + public void teardown() { + numTeardownCalls.addAndGet(1); + } + }; + + try (DoFnTester<Long, String> tester = DoFnTester.of(fn)) { + tester.setCloningBehavior(DoFnTester.CloningBehavior.DO_NOT_CLONE); + + tester.processBundle(1L, 2L, 3L); + tester.processBundle(4L, 5L); + tester.processBundle(6L); + } + assertEquals(1, numSetupCalls.get()); + assertEquals(1, numTeardownCalls.get()); + } - // processBundle() returns all the output like takeOutputElements(). - List<String> take = tester.processBundle(1L, 2L, 3L, 4L); + private static class CountBundleCallsFn extends DoFn<Long, String> { + private int numStartBundleCalls = 0; + private int numFinishBundleCalls = 0; - assertThat(take, hasItems("1", "2", "3", "4")); + @ProcessElement + public void process(ProcessContext context) { + context.output(numStartBundleCalls + "/" + numFinishBundleCalls); + } - // peek now returns nothing. - assertTrue(tester.peekOutputElements().isEmpty()); + @StartBundle + public void startBundle(Context context) { + ++numStartBundleCalls; + } - // verify startBundle() and finishBundle() were both called. - CounterDoFn deserializedDoFn = (CounterDoFn) tester.fn; - assertTrue(deserializedDoFn.wasStartBundleCalled()); - assertTrue(deserializedDoFn.wasFinishBundleCalled()); + @FinishBundle + public void finishBundle(Context context) { + ++numFinishBundleCalls; + } } @Test - public void processTimestampedElement() throws Exception { - DoFn<Long, TimestampedValue<Long>> reifyTimestamps = new ReifyTimestamps(); + public void cloneOnce() throws Exception { + try (DoFnTester<Long, String> tester = DoFnTester.of(new CountBundleCallsFn())) { + tester.setCloningBehavior(DoFnTester.CloningBehavior.CLONE_ONCE); + + assertThat(tester.processBundle(1L, 2L, 3L), contains("1/0", "1/0", "1/0")); + assertThat(tester.processBundle(4L, 5L), contains("2/1", "2/1")); + assertThat(tester.processBundle(6L), contains("3/2")); + } + } + + @Test + public void clonePerBundle() throws Exception { + try (DoFnTester<Long, String> tester = DoFnTester.of(new CountBundleCallsFn())) { + tester.setCloningBehavior(DoFnTester.CloningBehavior.CLONE_PER_BUNDLE); - DoFnTester<Long, TimestampedValue<Long>> tester = DoFnTester.of(reifyTimestamps); + assertThat(tester.processBundle(1L, 2L, 3L), contains("1/0", "1/0", "1/0")); + assertThat(tester.processBundle(4L, 5L), contains("1/0", "1/0")); + assertThat(tester.processBundle(6L), contains("1/0")); + } + } - TimestampedValue<Long> input = TimestampedValue.of(1L, new Instant(100)); - tester.processTimestampedElement(input); - assertThat(tester.takeOutputElements(), contains(input)); + @Test + public void processTimestampedElement() throws Exception { + try (DoFnTester<Long, TimestampedValue<Long>> tester = DoFnTester.of(new ReifyTimestamps())) { + TimestampedValue<Long> input = TimestampedValue.of(1L, new Instant(100)); + tester.processTimestampedElement(input); + assertThat(tester.takeOutputElements(), contains(input)); + } } static class ReifyTimestamps extends DoFn<Long, TimestampedValue<Long>> { @@ -178,86 +238,83 @@ public class DoFnTesterTest { @Test public void processElementWithOutputTimestamp() throws Exception { - CounterDoFn counterDoFn = new CounterDoFn(); - DoFnTester<Long, String> tester = DoFnTester.of(counterDoFn); - - tester.processElement(1L); - tester.processElement(2L); - - List<TimestampedValue<String>> peek = tester.peekOutputElementsWithTimestamp(); - TimestampedValue<String> one = TimestampedValue.of("1", new Instant(1000L)); - TimestampedValue<String> two = TimestampedValue.of("2", new Instant(2000L)); - assertThat(peek, hasItems(one, two)); - - tester.processElement(3L); - tester.processElement(4L); - - TimestampedValue<String> three = TimestampedValue.of("3", new Instant(3000L)); - TimestampedValue<String> four = TimestampedValue.of("4", new Instant(4000L)); - peek = tester.peekOutputElementsWithTimestamp(); - assertThat(peek, hasItems(one, two, three, four)); - List<TimestampedValue<String>> take = tester.takeOutputElementsWithTimestamp(); - assertThat(take, hasItems(one, two, three, four)); - - // Following takeOutputElementsWithTimestamp(), neither takeOutputElementsWithTimestamp() - // nor peekOutputElementsWithTimestamp() return anything. - assertTrue(tester.takeOutputElementsWithTimestamp().isEmpty()); - assertTrue(tester.peekOutputElementsWithTimestamp().isEmpty()); - - // peekOutputElements() and takeOutputElements() also return nothing. - assertTrue(tester.peekOutputElements().isEmpty()); - assertTrue(tester.takeOutputElements().isEmpty()); + try (DoFnTester<Long, String> tester = DoFnTester.of(new CounterDoFn())) { + tester.processElement(1L); + tester.processElement(2L); + + List<TimestampedValue<String>> peek = tester.peekOutputElementsWithTimestamp(); + TimestampedValue<String> one = TimestampedValue.of("1", new Instant(1000L)); + TimestampedValue<String> two = TimestampedValue.of("2", new Instant(2000L)); + assertThat(peek, hasItems(one, two)); + + tester.processElement(3L); + tester.processElement(4L); + + TimestampedValue<String> three = TimestampedValue.of("3", new Instant(3000L)); + TimestampedValue<String> four = TimestampedValue.of("4", new Instant(4000L)); + peek = tester.peekOutputElementsWithTimestamp(); + assertThat(peek, hasItems(one, two, three, four)); + List<TimestampedValue<String>> take = tester.takeOutputElementsWithTimestamp(); + assertThat(take, hasItems(one, two, three, four)); + + // Following takeOutputElementsWithTimestamp(), neither takeOutputElementsWithTimestamp() + // nor peekOutputElementsWithTimestamp() return anything. + assertTrue(tester.takeOutputElementsWithTimestamp().isEmpty()); + assertTrue(tester.peekOutputElementsWithTimestamp().isEmpty()); + + // peekOutputElements() and takeOutputElements() also return nothing. + assertTrue(tester.peekOutputElements().isEmpty()); + assertTrue(tester.takeOutputElements().isEmpty()); + } } @Test public void getAggregatorValuesShouldGetValueOfCounter() throws Exception { CounterDoFn counterDoFn = new CounterDoFn(); - DoFnTester<Long, String> tester = DoFnTester.of(counterDoFn); - tester.processBundle(1L, 2L, 4L, 8L); - - Long aggregatorVal = tester.getAggregatorValue(counterDoFn.agg); - - assertThat(aggregatorVal, equalTo(15L)); + try (DoFnTester<Long, String> tester = DoFnTester.of(counterDoFn)) { + tester.processBundle(1L, 2L, 4L, 8L); + assertThat(tester.getAggregatorValue(counterDoFn.agg), equalTo(15L)); + } } @Test public void getAggregatorValuesWithEmptyCounterShouldSucceed() throws Exception { CounterDoFn counterDoFn = new CounterDoFn(); - DoFnTester<Long, String> tester = DoFnTester.of(counterDoFn); - tester.processBundle(); - Long aggregatorVal = tester.getAggregatorValue(counterDoFn.agg); - // empty bundle - assertThat(aggregatorVal, equalTo(0L)); + try (DoFnTester<Long, String> tester = DoFnTester.of(counterDoFn)) { + tester.processBundle(); + // empty bundle + assertThat(tester.getAggregatorValue(counterDoFn.agg), equalTo(0L)); + } } @Test public void getAggregatorValuesInStartFinishBundleShouldGetValues() throws Exception { - CounterDoFn fn = new CounterDoFn(1L, 2L); - DoFnTester<Long, String> tester = DoFnTester.of(fn); - tester.processBundle(0L, 0L); + CounterDoFn fn = new CounterDoFn(); + try (DoFnTester<Long, String> tester = DoFnTester.of(fn)) { + tester.processBundle(1L, 2L, 3L, 4L); - Long aggValue = tester.getAggregatorValue(fn.agg); - assertThat(aggValue, equalTo(1L + 2L)); + assertThat(tester.getAggregatorValue(fn.startBundleCalls), equalTo(1L)); + assertThat(tester.getAggregatorValue(fn.finishBundleCalls), equalTo(1L)); + } } @Test public void peekValuesInWindow() throws Exception { - CounterDoFn fn = new CounterDoFn(1L, 2L); - DoFnTester<Long, String> tester = DoFnTester.of(fn); - - tester.startBundle(); - tester.processElement(1L); - tester.processElement(2L); - tester.finishBundle(); - - assertThat( - tester.peekOutputElementsInWindow(GlobalWindow.INSTANCE), - containsInAnyOrder( - TimestampedValue.of("1", new Instant(1000L)), - TimestampedValue.of("2", new Instant(2000L)))); - assertThat( - tester.peekOutputElementsInWindow(new IntervalWindow(new Instant(0L), new Instant(10L))), - Matchers.<TimestampedValue<String>>emptyIterable()); + try (DoFnTester<Long, String> tester = DoFnTester.of(new CounterDoFn())) { + tester.startBundle(); + tester.processElement(1L); + tester.processElement(2L); + tester.finishBundle(); + + assertThat( + tester.peekOutputElementsInWindow(GlobalWindow.INSTANCE), + containsInAnyOrder( + TimestampedValue.of("1", new Instant(1000L)), + TimestampedValue.of("2", new Instant(2000L)))); + assertThat( + tester.peekOutputElementsInWindow(new IntervalWindow(new Instant(0L), new Instant(10L))), + Matchers.<TimestampedValue<String>>emptyIterable()); + } } @Test @@ -265,15 +322,14 @@ public class DoFnTesterTest { final PCollectionView<Integer> value = PCollectionViews.singletonView( TestPipeline.create(), WindowingStrategy.globalDefault(), true, 0, VarIntCoder.of()); - OldDoFn<Integer, Integer> fn = new SideInputDoFn(value); - - DoFnTester<Integer, Integer> tester = DoFnTester.of(fn); - tester.processElement(1); - tester.processElement(2); - tester.processElement(4); - tester.processElement(8); - assertThat(tester.peekOutputElements(), containsInAnyOrder(0, 0, 0, 0)); + try (DoFnTester<Integer, Integer> tester = DoFnTester.of(new SideInputDoFn(value))) { + tester.processElement(1); + tester.processElement(2); + tester.processElement(4); + tester.processElement(8); + assertThat(tester.peekOutputElements(), containsInAnyOrder(0, 0, 0, 0)); + } } @Test @@ -281,17 +337,17 @@ public class DoFnTesterTest { final PCollectionView<Integer> value = PCollectionViews.singletonView( TestPipeline.create(), WindowingStrategy.globalDefault(), true, 0, VarIntCoder.of()); - OldDoFn<Integer, Integer> fn = new SideInputDoFn(value); - DoFnTester<Integer, Integer> tester = DoFnTester.of(fn); - tester.setSideInput(value, GlobalWindow.INSTANCE, -2); - tester.processElement(16); - tester.processElement(32); - tester.processElement(64); - tester.processElement(128); - tester.finishBundle(); + try (DoFnTester<Integer, Integer> tester = DoFnTester.of(new SideInputDoFn(value))) { + tester.setSideInput(value, GlobalWindow.INSTANCE, -2); + tester.processElement(16); + tester.processElement(32); + tester.processElement(64); + tester.processElement(128); + tester.finishBundle(); - assertThat(tester.peekOutputElements(), containsInAnyOrder(-2, -2, -2, -2)); + assertThat(tester.peekOutputElements(), containsInAnyOrder(-2, -2, -2, -2)); + } } private static class SideInputDoFn extends OldDoFn<Integer, Integer> { @@ -308,50 +364,56 @@ public class DoFnTesterTest { } /** - * An {@link OldDoFn} that adds values to an aggregator and converts input to String in + * A {@link DoFn} that adds values to an aggregator and converts input to String in * {@link OldDoFn#processElement). */ - private static class CounterDoFn extends OldDoFn<Long, String> { + private static class CounterDoFn extends DoFn<Long, String> { Aggregator<Long, Long> agg = createAggregator("ctr", new Sum.SumLongFn()); - private final long startBundleVal; - private final long finishBundleVal; - private boolean startBundleCalled; - private boolean finishBundleCalled; - - public CounterDoFn() { - this(0L, 0L); + Aggregator<Long, Long> startBundleCalls = + createAggregator("startBundleCalls", new Sum.SumLongFn()); + Aggregator<Long, Long> finishBundleCalls = + createAggregator("finishBundleCalls", new Sum.SumLongFn()); + + private enum LifecycleState { + UNINITIALIZED, + SET_UP, + INSIDE_BUNDLE, + TORN_DOWN } + private LifecycleState state = LifecycleState.UNINITIALIZED; - public CounterDoFn(long start, long finish) { - this.startBundleVal = start; - this.finishBundleVal = finish; + @Setup + public void setup() { + checkState(state == LifecycleState.UNINITIALIZED, "Wrong state: %s", state); + state = LifecycleState.SET_UP; } - @Override + @StartBundle public void startBundle(Context c) { - agg.addValue(startBundleVal); - startBundleCalled = true; + checkState(state == LifecycleState.SET_UP, "Wrong state: %s", state); + state = LifecycleState.INSIDE_BUNDLE; + startBundleCalls.addValue(1L); } - @Override + @ProcessElement public void processElement(ProcessContext c) throws Exception { + checkState(state == LifecycleState.INSIDE_BUNDLE, "Wrong state: %s", state); agg.addValue(c.element()); Instant instant = new Instant(1000L * c.element()); c.outputWithTimestamp(c.element().toString(), instant); } - @Override + @FinishBundle public void finishBundle(Context c) { - agg.addValue(finishBundleVal); - finishBundleCalled = true; - } - - boolean wasStartBundleCalled() { - return startBundleCalled; + checkState(state == LifecycleState.INSIDE_BUNDLE, "Wrong state: %s", state); + state = LifecycleState.SET_UP; + finishBundleCalls.addValue(1L); } - boolean wasFinishBundleCalled() { - return finishBundleCalled; + @Teardown + public void teardown() { + checkState(state == LifecycleState.SET_UP, "Wrong state: %s", state); + state = LifecycleState.TORN_DOWN; } } }
