Add State parameter support to SimpleDoFnRunner
Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/e17dc4af Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/e17dc4af Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/e17dc4af Branch: refs/heads/master Commit: e17dc4af9f7de717872d6c6f0ab52e0498f3b782 Parents: 1b7b065 Author: Kenneth Knowles <k...@google.com> Authored: Wed Nov 9 21:10:51 2016 -0800 Committer: Kenneth Knowles <k...@google.com> Committed: Mon Nov 28 11:43:21 2016 -0800 ---------------------------------------------------------------------- .../beam/runners/core/SimpleDoFnRunner.java | 60 +++++++++++++-- .../org/apache/beam/sdk/transforms/ParDo.java | 10 --- .../sdk/transforms/reflect/DoFnSignature.java | 1 + .../apache/beam/sdk/transforms/ParDoTest.java | 79 ++++++++++++++++---- 4 files changed, 118 insertions(+), 32 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/e17dc4af/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java ---------------------------------------------------------------------- diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java index f611c0a..68751f0 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java @@ -25,7 +25,9 @@ import java.util.Collection; import java.util.Iterator; import java.util.List; import java.util.Set; +import javax.annotation.Nullable; import org.apache.beam.runners.core.DoFnRunners.OutputManager; +import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.transforms.Aggregator; import org.apache.beam.sdk.transforms.Aggregator.AggregatorFactory; @@ -37,6 +39,7 @@ import org.apache.beam.sdk.transforms.DoFn.OutputReceiver; import org.apache.beam.sdk.transforms.DoFn.ProcessContext; import org.apache.beam.sdk.transforms.reflect.DoFnInvoker; import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature; import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; @@ -55,6 +58,10 @@ import org.apache.beam.sdk.util.WindowingInternals; import org.apache.beam.sdk.util.WindowingStrategy; import org.apache.beam.sdk.util.state.State; import org.apache.beam.sdk.util.state.StateInternals; +import org.apache.beam.sdk.util.state.StateNamespace; +import org.apache.beam.sdk.util.state.StateNamespaces; +import org.apache.beam.sdk.util.state.StateSpec; +import org.apache.beam.sdk.util.state.StateTags; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TupleTag; import org.joda.time.Instant; @@ -87,6 +94,13 @@ public class SimpleDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Out private final boolean observesWindow; + private final DoFnSignature signature; + + private final Coder<BoundedWindow> windowCoder; + + // Because of setKey(Object), we really must refresh stateInternals() at each access + private final StepContext stepContext; + public SimpleDoFnRunner( PipelineOptions options, DoFn<InputT, OutputT> fn, @@ -98,11 +112,20 @@ public class SimpleDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Out AggregatorFactory aggregatorFactory, WindowingStrategy<?, ?> windowingStrategy) { this.fn = fn; - this.observesWindow = - DoFnSignatures.getSignature(fn.getClass()).processElement().observesWindow(); + this.signature = DoFnSignatures.getSignature(fn.getClass()); + this.observesWindow = signature.processElement().observesWindow(); this.invoker = DoFnInvokers.invokerFor(fn); this.outputManager = outputManager; this.mainOutputTag = mainOutputTag; + this.stepContext = stepContext; + + // This is a cast of an _invariant_ coder. But we are assured by pipeline validation + // that it really is the coder for whatever BoundedWindow subclass is provided + @SuppressWarnings("unchecked") + Coder<BoundedWindow> untypedCoder = + (Coder<BoundedWindow>) windowingStrategy.getWindowFn().windowCoder(); + this.windowCoder = untypedCoder; + this.context = new DoFnContext<>( options, @@ -113,7 +136,7 @@ public class SimpleDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Out sideOutputTags, stepContext, aggregatorFactory, - windowingStrategy == null ? null : windowingStrategy.getWindowFn()); + windowingStrategy.getWindowFn()); } @Override @@ -427,6 +450,23 @@ public class SimpleDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Out final DoFnContext<InputT, OutputT> context; final WindowedValue<InputT> windowedValue; + /** Lazily initialized; should only be accessed via {@link #getNamespace()}. */ + @Nullable private StateNamespace namespace; + + /** + * The state namespace for this context. + * + * <p>Any call to {@link #getNamespace()} when more than one window is present will crash; this + * represents a bug in the runner or the {@link DoFnSignature}, since values must be in exactly + * one window when state or timers are relevant. + */ + private StateNamespace getNamespace() { + if (namespace == null) { + namespace = StateNamespaces.window(windowCoder, window()); + } + return namespace; + } + private DoFnProcessContext( DoFn<InputT, OutputT> fn, DoFnContext<InputT, OutputT> context, @@ -564,8 +604,16 @@ public class SimpleDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Out } @Override - public State state(String timerId) { - throw new UnsupportedOperationException("State parameters are not supported."); + public State state(String stateId) { + try { + StateSpec<?, ?> spec = + (StateSpec<?, ?>) signature.stateDeclarations().get(stateId).field().get(fn); + return stepContext + .stateInternals() + .state(getNamespace(), StateTags.tagForSpec(stateId, (StateSpec) spec)); + } catch (IllegalAccessException e) { + throw new RuntimeException(e); + } } @Override @@ -593,7 +641,7 @@ public class SimpleDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Out @Override public StateInternals<?> stateInternals() { - return context.stepContext.stateInternals(); + return stepContext.stateInternals(); } @Override http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/e17dc4af/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java index 215ae6a..9453294 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java @@ -596,16 +596,6 @@ public class ParDo { // To be removed when the features are complete and runners have their own adequate // rejection logic - if (!signature.stateDeclarations().isEmpty()) { - throw new UnsupportedOperationException( - String.format("Found %s annotations on %s, but %s cannot yet be used with state.", - DoFn.StateId.class.getSimpleName(), - fn.getClass().getName(), - DoFn.class.getSimpleName())); - } - - // To be removed when the features are complete and runners have their own adequate - // rejection logic if (!signature.timerDeclarations().isEmpty()) { throw new UnsupportedOperationException( String.format("Found %s annotations on %s, but %s cannot yet be used with timers.", http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/e17dc4af/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java index 1c16030..cd93583 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java @@ -523,6 +523,7 @@ public abstract class DoFnSignature { static StateDeclaration create( String id, Field field, TypeDescriptor<? extends State> stateType) { + field.setAccessible(true); return new AutoValue_DoFnSignature_StateDeclaration(id, field, stateType); } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/e17dc4af/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java index 3c3e266..be1eaa4 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java @@ -36,6 +36,9 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; import com.fasterxml.jackson.annotation.JsonCreator; +import com.google.common.base.MoreObjects; +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; @@ -68,6 +71,7 @@ import org.apache.beam.sdk.util.TimeDomain; import org.apache.beam.sdk.util.TimerSpec; import org.apache.beam.sdk.util.TimerSpecs; import org.apache.beam.sdk.util.common.ElementByteSizeObserver; +import org.apache.beam.sdk.util.state.BagState; import org.apache.beam.sdk.util.state.StateSpec; import org.apache.beam.sdk.util.state.StateSpecs; import org.apache.beam.sdk.util.state.ValueState; @@ -1459,27 +1463,70 @@ public class ParDoTest implements Serializable { assertThat(displayData, hasDisplayItem("fn", fn.getClass())); } - /** - * A test that we properly reject {@link DoFn} implementations that - * include {@link DoFn.StateId} annotations, for now. - */ @Test - public void testUnsupportedState() { - thrown.expect(UnsupportedOperationException.class); - thrown.expectMessage("cannot yet be used with state"); + @Category(RunnableOnService.class) + public void testValueState() { + final String stateId = "foo"; + + DoFn<KV<String, Integer>, Integer> fn = + new DoFn<KV<String, Integer>, Integer>() { + + @StateId(stateId) + private final StateSpec<Object, ValueState<Integer>> intState = + StateSpecs.value(VarIntCoder.of()); + + @ProcessElement + public void processElement( + ProcessContext c, @StateId(stateId) ValueState<Integer> state) { + Integer currentValue = MoreObjects.firstNonNull(state.read(), 0); + c.output(currentValue); + state.write(currentValue + 1); + } + }; - DoFn<KV<String, String>, KV<String, String>> fn = - new DoFn<KV<String, String>, KV<String, String>>() { + Pipeline p = TestPipeline.create(); + PCollection<Integer> output = + p.apply(Create.of(KV.of("hello", 42), KV.of("hello", 97), KV.of("hello", 84))) + .apply(ParDo.of(fn)); - @StateId("foo") - private final StateSpec<Object, ValueState<Integer>> intState = - StateSpecs.value(VarIntCoder.of()); + PAssert.that(output).containsInAnyOrder(0, 1, 2); + p.run(); + } - @ProcessElement - public void processElement(ProcessContext c) { } - }; + @Test + @Category(RunnableOnService.class) + public void testBagSTate() { + final String stateId = "foo"; + + DoFn<KV<String, Integer>, List<Integer>> fn = + new DoFn<KV<String, Integer>, List<Integer>>() { + + @StateId(stateId) + private final StateSpec<Object, BagState<Integer>> bufferState = + StateSpecs.bag(VarIntCoder.of()); + + @ProcessElement + public void processElement( + ProcessContext c, @StateId(stateId) BagState<Integer> state) { + Iterable<Integer> currentValue = state.read(); + state.add(c.element().getValue()); + if (Iterables.size(state.read()) >= 4) { + List<Integer> sorted = Lists.newArrayList(currentValue); + Collections.sort(sorted); + c.output(sorted); + } + } + }; - ParDo.of(fn); + Pipeline p = TestPipeline.create(); + PCollection<List<Integer>> output = + p.apply( + Create.of( + KV.of("hello", 97), KV.of("hello", 42), KV.of("hello", 84), KV.of("hello", 12))) + .apply(ParDo.of(fn)); + + PAssert.that(output).containsInAnyOrder(Lists.newArrayList(12, 42, 84, 97)); + p.run(); } @Test