Add analysis and validation of State parameters to DoFn.ProcessElement
Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/9e6246e0 Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/9e6246e0 Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/9e6246e0 Branch: refs/heads/master Commit: 9e6246e0ce6afaef542c214b610d39bfd758e797 Parents: 00c7587 Author: Kenneth Knowles <[email protected]> Authored: Fri Oct 14 11:07:55 2016 -0700 Committer: Kenneth Knowles <[email protected]> Committed: Thu Oct 20 11:47:40 2016 -0700 ---------------------------------------------------------------------- .../beam/sdk/transforms/DoFnAdapters.java | 2 +- .../sdk/transforms/reflect/DoFnInvokers.java | 21 +- .../sdk/transforms/reflect/DoFnSignature.java | 205 ++++++++++++++++- .../sdk/transforms/reflect/DoFnSignatures.java | 93 +++++++- .../DoFnSignaturesSplittableDoFnTest.java | 6 +- .../transforms/reflect/DoFnSignaturesTest.java | 228 ++++++++++++++++++- .../reflect/DoFnSignaturesTestUtils.java | 4 +- 7 files changed, 522 insertions(+), 37 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/9e6246e0/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnAdapters.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnAdapters.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnAdapters.java index 12d4824..e45679e 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnAdapters.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnAdapters.java @@ -62,7 +62,7 @@ public class DoFnAdapters { @SuppressWarnings({"unchecked", "rawtypes"}) public static <InputT, OutputT> OldDoFn<InputT, OutputT> toOldDoFn(DoFn<InputT, OutputT> fn) { DoFnSignature signature = DoFnSignatures.INSTANCE.getSignature((Class) 
fn.getClass()); - if (signature.processElement().usesSingleWindow()) { + if (signature.processElement().observesWindow()) { return new WindowDoFnAdapter<>(fn); } else { return new SimpleDoFnAdapter<>(fn); http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/9e6246e0/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokers.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokers.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokers.java index 8eb6145..dd134b7 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokers.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokers.java @@ -25,7 +25,7 @@ import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.util.ArrayList; import java.util.Collections; -import java.util.EnumMap; +import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -508,22 +508,21 @@ public class DoFnInvokers { static { try { - Map<DoFnSignature.Parameter, MethodDescription> methods = - new EnumMap<>(DoFnSignature.Parameter.class); + Map<DoFnSignature.Parameter, MethodDescription> methods = new HashMap<>(); methods.put( - DoFnSignature.Parameter.BOUNDED_WINDOW, + DoFnSignature.Parameter.boundedWindow(), new MethodDescription.ForLoadedMethod( DoFn.ExtraContextFactory.class.getMethod("window"))); methods.put( - DoFnSignature.Parameter.INPUT_PROVIDER, + DoFnSignature.Parameter.inputProvider(), new MethodDescription.ForLoadedMethod( DoFn.ExtraContextFactory.class.getMethod("inputProvider"))); methods.put( - DoFnSignature.Parameter.OUTPUT_RECEIVER, + DoFnSignature.Parameter.outputReceiver(), new MethodDescription.ForLoadedMethod( DoFn.ExtraContextFactory.class.getMethod("outputReceiver"))); methods.put( - 
DoFnSignature.Parameter.RESTRICTION_TRACKER, + DoFnSignature.Parameter.restrictionTracker(), new MethodDescription.ForLoadedMethod( DoFn.ExtraContextFactory.class.getMethod("restrictionTracker"))); EXTRA_CONTEXT_FACTORY_METHODS = Collections.unmodifiableMap(methods); @@ -539,6 +538,10 @@ public class DoFnInvokers { } } + private static MethodDescription getExtraContextFactoryMethod(DoFnSignature.Parameter param) { + return EXTRA_CONTEXT_FACTORY_METHODS.get(param); + } + private final DoFnSignature.ProcessElementMethod signature; /** Implementation of {@link MethodDelegation} for the {@link ProcessElement} method. */ @@ -562,11 +565,11 @@ public class DoFnInvokers { parameters.add( new StackManipulation.Compound( pushExtraContextFactory, - MethodInvocation.invoke(EXTRA_CONTEXT_FACTORY_METHODS.get(param)), + MethodInvocation.invoke(getExtraContextFactoryMethod(param)), // ExtraContextFactory.restrictionTracker() returns a RestrictionTracker, // but the @ProcessElement method expects a concrete subtype of it. // Insert a downcast. - (param == DoFnSignature.Parameter.RESTRICTION_TRACKER) + DoFnSignature.Parameter.restrictionTracker().equals(param) ? 
TypeCasting.to( new TypeDescription.ForLoadedType(signature.trackerT().getRawType())) : StackManipulation.Trivial.INSTANCE)); http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/9e6246e0/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java index 0d503d2..1dc1fe3 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java @@ -18,6 +18,8 @@ package org.apache.beam.sdk.transforms.reflect; import com.google.auto.value.AutoValue; +import com.google.common.base.Predicates; +import com.google.common.collect.Iterables; import com.google.common.reflect.TypeToken; import java.lang.reflect.Field; import java.lang.reflect.Method; @@ -27,8 +29,15 @@ import java.util.Map; import javax.annotation.Nullable; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFn.InputProvider; +import org.apache.beam.sdk.transforms.DoFn.OutputReceiver; import org.apache.beam.sdk.transforms.DoFn.ProcessContinuation; +import org.apache.beam.sdk.transforms.DoFn.StateId; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.BoundedWindowParameter; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.StateParameter; import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.util.Timer; import org.apache.beam.sdk.util.TimerSpec; import org.apache.beam.sdk.util.state.State; import org.apache.beam.sdk.util.state.StateSpec; @@ -123,12 +132,178 @@ public abstract class 
DoFnSignature { Method targetMethod(); } - /** A type of optional parameter of the {@link DoFn.ProcessElement} method. */ - public enum Parameter { - BOUNDED_WINDOW, - INPUT_PROVIDER, - OUTPUT_RECEIVER, - RESTRICTION_TRACKER + /** A descriptor for an optional parameter of the {@link DoFn.ProcessElement} method. */ + public abstract static class Parameter { + + // Private as no extensions other than those nested here are permitted + private Parameter() {} + + /** + * Performs case analysis on this {@link Parameter}, processing it with the appropriate + * {@link Cases#dispatch} case of the provided {@link Cases} object. + */ + public <ResultT> ResultT match(Cases<ResultT> cases) { + // This could be done with reflection, but since the number of cases is small and known, + // they are simply inlined. + if (this instanceof BoundedWindowParameter) { + return cases.dispatch((BoundedWindowParameter) this); + } else if (this instanceof RestrictionTrackerParameter) { + return cases.dispatch((RestrictionTrackerParameter) this); + } else if (this instanceof InputProviderParameter) { + return cases.dispatch((InputProviderParameter) this); + } else if (this instanceof OutputReceiverParameter) { + return cases.dispatch((OutputReceiverParameter) this); + } else if (this instanceof StateParameter) { + return cases.dispatch((StateParameter) this); + } else { + throw new IllegalStateException( + String.format("Attempt to case match on unknown %s subclass %s", + Parameter.class.getCanonicalName(), this.getClass().getCanonicalName())); + } + } + + /** + * An interface for destructuring a {@link Parameter}. + */ + public interface Cases<ResultT> { + ResultT dispatch(BoundedWindowParameter p); + ResultT dispatch(InputProviderParameter p); + ResultT dispatch(OutputReceiverParameter p); + ResultT dispatch(RestrictionTrackerParameter p); + ResultT dispatch(StateParameter p); + + /** + * A base class for a visitor with a default method for cases it is not interested in. 
+ */ + public abstract static class WithDefault<ResultT> implements Cases<ResultT> { + + protected abstract ResultT dispatchDefault(Parameter p); + + @Override + public ResultT dispatch(BoundedWindowParameter p) { + return dispatchDefault(p); + } + + @Override + public ResultT dispatch(InputProviderParameter p) { + return dispatchDefault(p); + } + + @Override + public ResultT dispatch(OutputReceiverParameter p) { + return dispatchDefault(p); + } + + @Override + public ResultT dispatch(RestrictionTrackerParameter p) { + return dispatchDefault(p); + } + + @Override + public ResultT dispatch(StateParameter p) { + return dispatchDefault(p); + } + } + } + + // These parameter descriptors are constant + private static final BoundedWindowParameter BOUNDED_WINDOW_PARAMETER = + new AutoValue_DoFnSignature_Parameter_BoundedWindowParameter(); + private static final RestrictionTrackerParameter RESTRICTION_TRACKER_PARAMETER = + new AutoValue_DoFnSignature_Parameter_RestrictionTrackerParameter(); + private static final InputProviderParameter INPUT_PROVIDER_PARAMETER = + new AutoValue_DoFnSignature_Parameter_InputProviderParameter(); + private static final OutputReceiverParameter OUTPUT_RECEIVER_PARAMETER = + new AutoValue_DoFnSignature_Parameter_OutputReceiverParameter(); + + /** + * Returns a {@link BoundedWindowParameter}. + */ + public static BoundedWindowParameter boundedWindow() { + return BOUNDED_WINDOW_PARAMETER; + } + + /** + * Returns an {@link InputProviderParameter}. + */ + public static InputProviderParameter inputProvider() { + return INPUT_PROVIDER_PARAMETER; + } + + /** + * Returns an {@link OutputReceiverParameter}. + */ + public static OutputReceiverParameter outputReceiver() { + return OUTPUT_RECEIVER_PARAMETER; + } + + /** + * Returns a {@link RestrictionTrackerParameter}. 
+ */ + public static RestrictionTrackerParameter restrictionTracker() { + return RESTRICTION_TRACKER_PARAMETER; + } + + /** + * Returns a {@link StateParameter} referring to the given {@link StateDeclaration}. + */ + public static StateParameter stateParameter(StateDeclaration decl) { + return new AutoValue_DoFnSignature_Parameter_StateParameter(decl); + } + + /** + * Descriptor for a {@link Parameter} of type {@link BoundedWindow}. + * + * <p>All such descriptors are equal. + */ + @AutoValue + public abstract static class BoundedWindowParameter extends Parameter { + BoundedWindowParameter() {} + } + + /** + * Descriptor for a {@link Parameter} of type {@link InputProvider}. + * + * <p>All such descriptors are equal. + */ + @AutoValue + public abstract static class InputProviderParameter extends Parameter { + InputProviderParameter() {} + } + + /** + * Descriptor for a {@link Parameter} of type {@link OutputReceiver}. + * + * <p>All such descriptors are equal. + */ + @AutoValue + public abstract static class OutputReceiverParameter extends Parameter { + OutputReceiverParameter() {} + } + + /** + * Descriptor for a {@link Parameter} of a subclass of {@link RestrictionTracker}. + * + * <p>All such descriptors are equal. + */ + @AutoValue + public abstract static class RestrictionTrackerParameter extends Parameter { + // Package visible for AutoValue + RestrictionTrackerParameter() {} + } + + /** + * Descriptor for a {@link Parameter} of a subclass of {@link State}, with an id indicated by + * its {@link StateId} annotation. + * + * <p>All descriptors for the same declared state are equal. + */ + @AutoValue + public abstract static class StateParameter extends Parameter { + // Package visible for AutoValue + StateParameter() {} + public abstract StateDeclaration referent(); + } } /** Describes a {@link DoFn.ProcessElement} method. 
*/ @@ -157,16 +332,26 @@ public abstract class DoFnSignature { targetMethod, Collections.unmodifiableList(extraParameters), trackerT, hasReturnValue); } - /** Whether this {@link DoFn} uses a Single Window. */ - public boolean usesSingleWindow() { - return extraParameters().contains(Parameter.BOUNDED_WINDOW); + /** + * Whether this {@link DoFn} observes - directly or indirectly - the window that an element + * resides in. + * + * <p>{@link State} and {@link Timer} parameters indirectly observe the window, because + * they are each scoped to a single window. + */ + public boolean observesWindow() { + return Iterables.any( + extraParameters(), + Predicates.or( + Predicates.instanceOf(BoundedWindowParameter.class), + Predicates.instanceOf(StateParameter.class))); } /** * Whether this {@link DoFn} is <a href="https://s.apache.org/splittable-do-fn">splittable</a>. */ public boolean isSplittable() { - return extraParameters().contains(Parameter.RESTRICTION_TRACKER); + return extraParameters().contains(Parameter.restrictionTracker()); } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/9e6246e0/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java index 04f50d3..038b55d 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java @@ -44,6 +44,7 @@ import javax.annotation.Nullable; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.OnTimerMethod; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter; import 
org.apache.beam.sdk.transforms.reflect.DoFnSignature.StateDeclaration; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.TimerDeclaration; import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; @@ -165,7 +166,8 @@ public class DoFnSignatures { errors.forMethod(DoFn.ProcessElement.class, processElementMethod); DoFnSignature.ProcessElementMethod processElement = analyzeProcessElementMethod( - processElementErrors, fnToken, processElementMethod, inputT, outputT); + processElementErrors, fnToken, processElementMethod, inputT, outputT, + stateDeclarations); builder.setProcessElement(processElement); if (startBundleMethod != null) { @@ -444,7 +446,8 @@ public class DoFnSignatures { TypeToken<? extends DoFn<?, ?>> fnClass, Method m, TypeToken<?> inputT, - TypeToken<?> outputT) { + TypeToken<?> outputT, + Map<String, StateDeclaration> stateDeclarations) { errors.checkArgument( void.class.equals(m.getReturnType()) || DoFn.ProcessContinuation.class.equals(m.getReturnType()), @@ -464,6 +467,7 @@ public class DoFnSignatures { formatType(processContextToken)); List<DoFnSignature.Parameter> extraParameters = new ArrayList<>(); + Map<String, DoFnSignature.Parameter> stateParameters = new HashMap<>(); TypeToken<?> trackerT = null; TypeToken<?> expectedInputProviderT = inputProviderTypeOf(inputT); @@ -473,13 +477,13 @@ public class DoFnSignatures { Class<?> rawType = paramT.getRawType(); if (rawType.equals(BoundedWindow.class)) { errors.checkArgument( - !extraParameters.contains(DoFnSignature.Parameter.BOUNDED_WINDOW), + !extraParameters.contains(DoFnSignature.Parameter.boundedWindow()), "Multiple %s parameters", BoundedWindow.class.getSimpleName()); - extraParameters.add(DoFnSignature.Parameter.BOUNDED_WINDOW); + extraParameters.add(DoFnSignature.Parameter.boundedWindow()); } else if (rawType.equals(DoFn.InputProvider.class)) { errors.checkArgument( - !extraParameters.contains(DoFnSignature.Parameter.INPUT_PROVIDER), + 
!extraParameters.contains(DoFnSignature.Parameter.inputProvider()), "Multiple %s parameters", DoFn.InputProvider.class.getSimpleName()); errors.checkArgument( @@ -488,10 +492,10 @@ public class DoFnSignatures { DoFn.InputProvider.class.getSimpleName(), formatType(paramT), formatType(expectedInputProviderT)); - extraParameters.add(DoFnSignature.Parameter.INPUT_PROVIDER); + extraParameters.add(DoFnSignature.Parameter.inputProvider()); } else if (rawType.equals(DoFn.OutputReceiver.class)) { errors.checkArgument( - !extraParameters.contains(DoFnSignature.Parameter.OUTPUT_RECEIVER), + !extraParameters.contains(DoFnSignature.Parameter.outputReceiver()), "Multiple %s parameters", DoFn.OutputReceiver.class.getSimpleName()); errors.checkArgument( @@ -500,14 +504,79 @@ public class DoFnSignatures { DoFn.OutputReceiver.class.getSimpleName(), formatType(paramT), formatType(expectedOutputReceiverT)); - extraParameters.add(DoFnSignature.Parameter.OUTPUT_RECEIVER); + extraParameters.add(DoFnSignature.Parameter.outputReceiver()); } else if (RestrictionTracker.class.isAssignableFrom(rawType)) { errors.checkArgument( - !extraParameters.contains(DoFnSignature.Parameter.RESTRICTION_TRACKER), + !extraParameters.contains(DoFnSignature.Parameter.restrictionTracker()), "Multiple %s parameters", RestrictionTracker.class.getSimpleName()); - extraParameters.add(DoFnSignature.Parameter.RESTRICTION_TRACKER); + extraParameters.add(DoFnSignature.Parameter.restrictionTracker()); trackerT = paramT; + } else if (State.class.isAssignableFrom(rawType)) { + // m.getParameters() is not available until Java 8 + Annotation[] annotations = m.getParameterAnnotations()[i]; + String id = null; + for (Annotation anno : annotations) { + if (anno.annotationType().equals(DoFn.StateId.class)) { + id = ((DoFn.StateId) anno).value(); + break; + } + } + errors.checkArgument( + id != null, + "%s parameter of type %s at index %s missing %s annotation", + fnClass.getRawType().getName(), + params[i], + i, + 
DoFn.StateId.class.getSimpleName()); + + errors.checkArgument( + !stateParameters.containsKey(id), + "%s parameter of type %s at index %s duplicates %s(\"%s\") on other parameter", + fnClass.getRawType().getName(), + params[i], + i, + DoFn.StateId.class.getSimpleName(), + id); + + // By static typing this is already a well-formed State subclass + TypeDescriptor<? extends State> stateType = + (TypeDescriptor<? extends State>) + TypeDescriptor.of(fnClass.getType()) + .resolveType(params[i]); + + StateDeclaration stateDecl = stateDeclarations.get(id); + errors.checkArgument( + stateDecl != null, + "%s parameter of type %s at index %s references undeclared %s \"%s\"", + fnClass.getRawType().getName(), + params[i], + i, + DoFn.StateId.class.getSimpleName(), + id); + + errors.checkArgument( + stateDecl.stateType().equals(stateType), + "%s parameter at index %s has type %s but is a reference to StateId %s of type %s", + fnClass.getRawType().getName(), + i, + params[i], + id, + stateDecl.stateType()); + + errors.checkArgument( + stateDecl.field().getDeclaringClass().equals(m.getDeclaringClass()), + "Method %s has State parameter at index %s for state %s" + + " declared in a different class %s." + + " State may be referenced only in the class where it is declared.", + m, + i, + id, + stateDecl.field().getDeclaringClass().getName()); + + DoFnSignature.Parameter.StateParameter stateParameter = Parameter.stateParameter(stateDecl); + stateParameters.put(id, stateParameter); + extraParameters.add(stateParameter); } else { List<String> allowedParamTypes = Arrays.asList( @@ -520,10 +589,10 @@ public class DoFnSignatures { } // A splittable DoFn can not have any other extra context parameters. 
- if (extraParameters.contains(DoFnSignature.Parameter.RESTRICTION_TRACKER)) { + if (extraParameters.contains(DoFnSignature.Parameter.restrictionTracker())) { errors.checkArgument( extraParameters.size() == 1, - "Splittable DoFn must not have any extra context arguments apart from %s, but has: %s", + "Splittable DoFn must not have any extra arguments apart from %s, but has: %s", trackerT, extraParameters); } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/9e6246e0/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java index 84b909f..68278c5 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java @@ -80,15 +80,15 @@ public class DoFnSignaturesSplittableDoFnTest { }); assertTrue(signature.isSplittable()); - assertTrue(signature.extraParameters().contains(DoFnSignature.Parameter.RESTRICTION_TRACKER)); + assertTrue(signature.extraParameters().contains(DoFnSignature.Parameter.restrictionTracker())); assertEquals(SomeRestrictionTracker.class, signature.trackerT().getRawType()); } @Test public void testSplittableProcessElementMustNotHaveOtherParams() throws Exception { thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("must not have any extra context arguments"); - thrown.expectMessage("BOUNDED_WINDOW"); + thrown.expectMessage("must not have any extra arguments"); + thrown.expectMessage("BoundedWindow"); DoFnSignature.ProcessElementMethod signature = analyzeProcessElementMethod( 
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/9e6246e0/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java index 9813af5..0374775 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java @@ -20,12 +20,15 @@ package org.apache.beam.sdk.transforms.reflect; import static org.apache.beam.sdk.transforms.reflect.DoFnSignaturesTestUtils.errors; import static org.hamcrest.Matchers.equalTo; import static org.junit.Assert.assertThat; +import static org.junit.Assert.fail; import com.google.common.reflect.TypeToken; +import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.coders.VarLongCoder; import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.DoFn.TimerId; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.StateParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignaturesTestUtils.FakeDoFn; import org.apache.beam.sdk.util.TimeDomain; import org.apache.beam.sdk.util.TimerSpec; @@ -33,6 +36,7 @@ import org.apache.beam.sdk.util.TimerSpecs; import org.apache.beam.sdk.util.state.StateSpec; import org.apache.beam.sdk.util.state.StateSpecs; import org.apache.beam.sdk.util.state.ValueState; +import org.apache.beam.sdk.util.state.WatermarkHoldState; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.TypeDescriptor; import org.hamcrest.Matchers; @@ -378,6 +382,103 @@ public class DoFnSignaturesTest { } 
@Test + public void testStateParameterNoAnnotation() throws Exception { + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("missing StateId annotation"); + thrown.expectMessage("myProcessElement"); + thrown.expectMessage("index 1"); + DoFnSignature sig = + DoFnSignatures.INSTANCE.getSignature( + new DoFn<KV<String, Integer>, Long>() { + @ProcessElement + public void myProcessElement( + ProcessContext context, ValueState<Integer> noAnnotation) {} + }.getClass()); + } + + @Test + public void testStateParameterUndeclared() throws Exception { + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("undeclared"); + thrown.expectMessage("my-state-id"); + thrown.expectMessage("myProcessElement"); + thrown.expectMessage("index 1"); + DoFnSignature sig = + DoFnSignatures.INSTANCE.getSignature( + new DoFn<KV<String, Integer>, Long>() { + @ProcessElement + public void myProcessElement( + ProcessContext context, @StateId("my-state-id") ValueState<Integer> undeclared) {} + }.getClass()); + } + + @Test + public void testStateParameterDuplicate() throws Exception { + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("duplicates"); + thrown.expectMessage("my-state-id"); + thrown.expectMessage("myProcessElement"); + thrown.expectMessage("index 2"); + DoFnSignature sig = + DoFnSignatures.INSTANCE.getSignature( + new DoFn<KV<String, Integer>, Long>() { + @StateId("my-state-id") + private final StateSpec<Object, ValueState<Integer>> myfield = + StateSpecs.value(VarIntCoder.of()); + + @ProcessElement + public void myProcessElement( + ProcessContext context, + @StateId("my-state-id") ValueState<Integer> one, + @StateId("my-state-id") ValueState<Integer> two) {} + }.getClass()); + } + + @Test + public void testStateParameterWrongStateType() throws Exception { + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("WatermarkHoldState"); + thrown.expectMessage("but is a reference to"); + 
thrown.expectMessage("ValueState"); + thrown.expectMessage("my-state-id"); + thrown.expectMessage("myProcessElement"); + thrown.expectMessage("index 1"); + DoFnSignature sig = + DoFnSignatures.INSTANCE.getSignature( + new DoFn<KV<String, Integer>, Long>() { + @StateId("my-state-id") + private final StateSpec<Object, ValueState<Integer>> myfield = + StateSpecs.value(VarIntCoder.of()); + + @ProcessElement + public void myProcessElement( + ProcessContext context, @StateId("my-state-id") WatermarkHoldState watermark) {} + }.getClass()); + } + + @Test + public void testStateParameterWrongGenericType() throws Exception { + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("ValueState<java.lang.String>"); + thrown.expectMessage("but is a reference to"); + thrown.expectMessage("ValueState<java.lang.Integer>"); + thrown.expectMessage("my-state-id"); + thrown.expectMessage("myProcessElement"); + thrown.expectMessage("index 1"); + DoFnSignature sig = + DoFnSignatures.INSTANCE.getSignature( + new DoFn<KV<String, Integer>, Long>() { + @StateId("my-state-id") + private final StateSpec<Object, ValueState<Integer>> myfield = + StateSpecs.value(VarIntCoder.of()); + + @ProcessElement + public void myProcessElement( + ProcessContext context, @StateId("my-state-id") ValueState<String> stringState) {} + }.getClass()); + } + + @Test public void testSimpleStateIdAnonymousDoFn() throws Exception { DoFnSignature sig = DoFnSignatures.INSTANCE.getSignature( @@ -401,6 +502,97 @@ public class DoFnSignaturesTest { } @Test + public void testUsageOfStateDeclaredInSuperclass() throws Exception { + DoFnDeclaringState fn = + new DoFnDeclaringState() { + @ProcessElement + public void process( + ProcessContext context, + @StateId(DoFnDeclaringState.STATE_ID) ValueState<Integer> state) {} + }; + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("process"); + thrown.expectMessage("declared in a different class"); + 
thrown.expectMessage(DoFnDeclaringState.STATE_ID); + thrown.expectMessage(fn.getClass().getSimpleName()); + DoFnSignature sig = DoFnSignatures.INSTANCE.getSignature(fn.getClass()); + } + + @Test + public void testDeclOfStateUsedInSuperclass() throws Exception { + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("process"); + thrown.expectMessage("declared in a different class"); + thrown.expectMessage(DoFnUsingState.STATE_ID); + DoFnSignature sig = + DoFnSignatures.INSTANCE.getSignature( + new DoFnUsingState() { + @StateId(DoFnUsingState.STATE_ID) + private final StateSpec<Object, ValueState<Integer>> spec = + StateSpecs.value(VarIntCoder.of()); + }.getClass()); + } + + @Test + public void testDeclAndUsageOfStateInSuperclass() throws Exception { + DoFnSignature sig = + DoFnSignatures.INSTANCE.getSignature( + new DoFnOverridingAbstractStateUse().getClass()); + + assertThat(sig.stateDeclarations().size(), equalTo(1)); + assertThat(sig.processElement().extraParameters().size(), equalTo(1)); + + DoFnSignature.StateDeclaration decl = + sig.stateDeclarations().get(DoFnOverridingAbstractStateUse.STATE_ID); + StateParameter stateParam = + (StateParameter) sig.processElement().extraParameters().get(0); + + assertThat( + decl.field(), + equalTo(DoFnDeclaringStateAndAbstractUse.class.getDeclaredField("myStateSpec"))); + + // The method we pull out is the superclass method; this is what allows validation to remain + // simple. The later invokeDynamic instruction causes it to invoke the actual implementation. + assertThat(stateParam.referent(), equalTo(decl)); + } + + /** + * Assuming the proper parsing of declarations, testing elsewhere, this test ensures that + * a simple reference to such a declaration is correctly resolved. 
+ */ + @Test + public void testSimpleStateIdRefAnonymousDoFn() throws Exception { + DoFnSignature sig = + DoFnSignatures.INSTANCE.getSignature( + new DoFn<KV<String, Integer>, Long>() { + @StateId("foo") + private final StateSpec<Object, ValueState<Integer>> bizzleDecl = + StateSpecs.value(VarIntCoder.of()); + + @ProcessElement + public void foo(ProcessContext context, @StateId("foo") ValueState<Integer> bizzle) {} + }.getClass()); + + assertThat(sig.processElement().extraParameters().size(), equalTo(1)); + + final DoFnSignature.StateDeclaration decl = sig.stateDeclarations().get("foo"); + sig.processElement().extraParameters().get(0).match(new Parameter.Cases.WithDefault<Void>() { + @Override + protected Void dispatchDefault(Parameter p) { + fail(String.format("Expected a state parameter but got %s", p)); + return null; + } + + @Override + public Void dispatch(StateParameter stateParam) { + assertThat(stateParam.referent(), equalTo(decl)); + return null; + } + }); + } + + @Test public void testSimpleStateIdNamedDoFn() throws Exception { // Test classes at the bottom of the file DoFnSignature sig = @@ -445,6 +637,40 @@ public class DoFnSignaturesTest { public void foo(ProcessContext context) {} } + private abstract static class DoFnDeclaringState extends DoFn<KV<String, Integer>, Long> { + + public static final String STATE_ID = "my-state-id"; + + @StateId(STATE_ID) + private final StateSpec<Object, ValueState<Integer>> bizzle = + StateSpecs.value(VarIntCoder.of()); + } + + private abstract static class DoFnUsingState extends DoFn<KV<String, Integer>, Long> { + public static final String STATE_ID = "my-state-id"; + @ProcessElement + public void process(ProcessContext context, @StateId(STATE_ID) ValueState<Integer> state) {} + } + + private abstract static class DoFnDeclaringStateAndAbstractUse + extends DoFn<KV<String, Integer>, Long> { + public static final String STATE_ID = "my-state-id"; + @StateId(STATE_ID) + private final StateSpec<Object, ValueState<String>> 
myStateSpec = + StateSpecs.value(StringUtf8Coder.of()); + + @ProcessElement + public abstract void processWithState( + ProcessContext context, @StateId(STATE_ID) ValueState<String> state); + } + + private static class DoFnOverridingAbstractStateUse extends + DoFnDeclaringStateAndAbstractUse { + + @Override + public void processWithState(ProcessContext c, ValueState<String> state) {} + } + private static class DoFnForTestGenericStatefulDoFn<T> extends DoFn<KV<String, T>, Long> { // Note that in order to have a coder for T it will require initialization in the constructor, // but that isn't important for this test http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/9e6246e0/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTestUtils.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTestUtils.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTestUtils.java index 88dc423..c276692 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTestUtils.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTestUtils.java @@ -19,6 +19,7 @@ package org.apache.beam.sdk.transforms.reflect; import com.google.common.reflect.TypeToken; import java.lang.reflect.Method; +import java.util.Collections; import java.util.NoSuchElementException; import org.apache.beam.sdk.transforms.DoFn; @@ -59,6 +60,7 @@ class DoFnSignaturesTestUtils { TypeToken.of(FakeDoFn.class), method.getMethod(), TypeToken.of(Integer.class), - TypeToken.of(String.class)); + TypeToken.of(String.class), + Collections.EMPTY_MAP); } }
