Repository: incubator-beam Updated Branches: refs/heads/master 5bfeb958d -> a0f649eac
Add DoFn.StateId annotation and validation on fields Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/add34518 Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/add34518 Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/add34518 Branch: refs/heads/master Commit: add34518bbfb6668a421157e3c1cfaa119a6031b Parents: 7322616 Author: Kenneth Knowles <k...@google.com> Authored: Mon Oct 10 21:16:37 2016 -0700 Committer: Kenneth Knowles <k...@google.com> Committed: Thu Oct 13 19:31:20 2016 -0700 ---------------------------------------------------------------------- .../org/apache/beam/sdk/transforms/DoFn.java | 44 ++++++ .../org/apache/beam/sdk/transforms/ParDo.java | 32 ++++- .../sdk/transforms/reflect/DoFnSignature.java | 25 ++++ .../sdk/transforms/reflect/DoFnSignatures.java | 129 +++++++++++++++-- .../apache/beam/sdk/transforms/ParDoTest.java | 27 ++++ .../transforms/reflect/DoFnSignaturesTest.java | 143 +++++++++++++++++++ 6 files changed, 388 insertions(+), 12 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/add34518/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java index 62da28c..c86693b 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java @@ -44,6 +44,8 @@ import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.util.WindowingInternals; +import 
org.apache.beam.sdk.util.state.State; +import org.apache.beam.sdk.util.state.StateSpec; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TupleTag; @@ -394,6 +396,48 @@ public abstract class DoFn<InputT, OutputT> implements Serializable, HasDisplayD ///////////////////////////////////////////////////////////////////////////// + /** + * Annotation for declaring and dereferencing state cells. + * + * <p><i>Not currently supported by any runner</i>. + * + * <p>To declare a state cell, create a field of type {@link StateSpec} annotated with a {@link + * StateId}. To use the cell during processing, add a parameter of the appropriate {@link State} + * subclass to your {@link ProcessElement @ProcessElement} method, and annotate it with {@link + * StateId}. See the following code for an example: + * + * <pre>{@code + * new DoFn<KV<Key, Foo>, Baz>() { + * @StateId("my-state-id") + * private final StateSpec<Object, ValueState<MyState>> myStateSpec = + * StateSpecs.value(new MyStateCoder()); + * + * @ProcessElement + * public void processElement( + * ProcessContext c, + * @StateId("my-state-id") ValueState<MyState> myState) { + * myState.read(); + * myState.write(...); + * } + * } + * }</pre> + * + * <p>State is subject to the following validity conditions: + * + * <ul> + * <li>Each state ID must be declared at most once. + * <li>Any state referenced in a parameter must be declared with the same state type. + * <li>State declarations must be final. + * </ul> + */ + @Documented + @Retention(RetentionPolicy.RUNTIME) + @Target({ElementType.FIELD, ElementType.PARAMETER}) + @Experimental(Kind.STATE) + public @interface StateId { + /** The state ID. */ + String value(); + } /** * Annotation for the method to use to prepare an instance for processing bundles of elements.
The http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/add34518/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java index fdef908..c5a80c6 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java @@ -29,6 +29,7 @@ import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.runners.PipelineRunner; import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.transforms.display.DisplayData.Builder; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature; import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; import org.apache.beam.sdk.transforms.windowing.WindowFn; import org.apache.beam.sdk.util.SerializableUtils; @@ -519,6 +520,7 @@ public class ParDo { * properties can be set on it first. */ public static <InputT, OutputT> Bound<InputT, OutputT> of(DoFn<InputT, OutputT> fn) { + validate(fn); return of(adapt(fn), fn.getClass()); } @@ -544,8 +546,32 @@ public class ParDo { return new Unbound().of(fn, fnClass); } - private static <InputT, OutputT> OldDoFn<InputT, OutputT> - adapt(DoFn<InputT, OutputT> fn) { + /** + * Performs common validations of the {@link DoFn}, for example ensuring that state is used + * correctly and that its features can be supported.
+ */ + private static <InputT, OutputT> void validate(DoFn<InputT, OutputT> fn) { + DoFnSignature signature = DoFnSignatures.INSTANCE.signatureForDoFn(fn); + + // To be removed when the features are complete and runners have their own adequate + // rejection logic + if (!signature.stateDeclarations().isEmpty()) { + throw new UnsupportedOperationException( + String.format("Found %s annotations on %s, but %s cannot yet be used with state.", + DoFn.StateId.class.getSimpleName(), + fn.getClass().getName(), + DoFn.class.getSimpleName())); + } + + // State is semantically incompatible with splitting + if (signature.processElement().isSplittable() && !signature.stateDeclarations().isEmpty()) { + throw new UnsupportedOperationException( + String.format("%s is splittable and uses state, but these are not compatible", + fn.getClass().getName())); + } + } + + private static <InputT, OutputT> OldDoFn<InputT, OutputT> adapt(DoFn<InputT, OutputT> fn) { return DoFnAdapters.toOldDoFn(fn); } @@ -622,6 +648,7 @@ public class ParDo { * still be specified. */ public <InputT, OutputT> Bound<InputT, OutputT> of(DoFn<InputT, OutputT> fn) { + validate(fn); return of(adapt(fn), fn.getClass()); } @@ -838,6 +865,7 @@ public class ParDo { * more properties can still be specified.
*/ public <InputT> BoundMulti<InputT, OutputT> of(DoFn<InputT, OutputT> fn) { + validate(fn); return of(adapt(fn), fn.getClass()); } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/add34518/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java index 632f817..5e261a4 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java @@ -19,15 +19,20 @@ package org.apache.beam.sdk.transforms.reflect; import com.google.auto.value.AutoValue; import com.google.common.reflect.TypeToken; +import java.lang.reflect.Field; import java.lang.reflect.Method; import java.util.Collections; import java.util.List; +import java.util.Map; import javax.annotation.Nullable; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.DoFn.ProcessContinuation; import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; +import org.apache.beam.sdk.util.state.State; +import org.apache.beam.sdk.util.state.StateSpec; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TypeDescriptor; /** * Describes the signature of a {@link DoFn}, in particular, which features it uses, which extra @@ -46,6 +51,9 @@ public abstract class DoFnSignature { /** Details about this {@link DoFn}'s {@link DoFn.ProcessElement} method. */ public abstract ProcessElementMethod processElement(); + /** Details about the state cells that this {@link DoFn} declares. Immutable.
*/ + public abstract Map<String, StateDeclaration> stateDeclarations(); + /** Details about this {@link DoFn}'s {@link DoFn.StartBundle} method. */ @Nullable public abstract BundleMethod startBundle(); @@ -95,6 +103,7 @@ public abstract class DoFnSignature { abstract Builder setSplitRestriction(SplitRestrictionMethod splitRestriction); abstract Builder setGetRestrictionCoder(GetRestrictionCoderMethod getRestrictionCoder); abstract Builder setNewTracker(NewTrackerMethod newTracker); + abstract Builder setStateDeclarations(Map<String, StateDeclaration> stateDeclarations); abstract DoFnSignature build(); } @@ -163,6 +172,22 @@ public abstract class DoFnSignature { } } + /** + * Describes a state declaration; a field of type {@link StateSpec} annotated with + * {@link DoFn.StateId}. + */ + @AutoValue + public abstract static class StateDeclaration { + public abstract String id(); + public abstract Field field(); + public abstract TypeDescriptor<? extends State> stateType(); + + static StateDeclaration create( + String id, Field field, TypeDescriptor<? extends State> stateType) { + return new AutoValue_DoFnSignature_StateDeclaration(id, field, stateType); + } + } + /** Describes a {@link DoFn.Setup} or {@link DoFn.Teardown} method.
*/ @AutoValue public abstract static class LifecycleMethod implements DoFnMethod { http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/add34518/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java index 524ea24..b7f773d 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java @@ -20,9 +20,12 @@ package org.apache.beam.sdk.transforms.reflect; import static com.google.common.base.Preconditions.checkState; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableMap; import com.google.common.reflect.TypeParameter; import com.google.common.reflect.TypeToken; import java.lang.annotation.Annotation; +import java.lang.reflect.AnnotatedElement; +import java.lang.reflect.Field; import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.lang.reflect.ParameterizedType; @@ -30,18 +33,21 @@ import java.lang.reflect.Type; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; -import java.util.Collections; +import java.util.HashMap; import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import javax.annotation.Nullable; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.common.ReflectHelpers; +import org.apache.beam.sdk.util.state.State; +import org.apache.beam.sdk.util.state.StateSpec; import
org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TypeDescriptor; /** * Parses a {@link DoFn} and computes its {@link DoFnSignature}. See {@link #getOrParseSignature}. @@ -53,7 +59,12 @@ public class DoFnSignatures { private final Map<Class<?>, DoFnSignature> signatureCache = new LinkedHashMap<>(); - /** @return the {@link DoFnSignature} for the given {@link DoFn}. */ + /** @return the {@link DoFnSignature} for the given {@link DoFn} instance. */ + public <FnT extends DoFn<?, ?>> DoFnSignature signatureForDoFn(FnT fn) { + return getOrParseSignature(fn.getClass()); + } + + /** @return the {@link DoFnSignature} for the given {@link DoFn} subclass. */ public synchronized <FnT extends DoFn<?, ?>> DoFnSignature getOrParseSignature(Class<FnT> fn) { DoFnSignature signature = signatureCache.get(fn); if (signature == null) { @@ -172,6 +183,8 @@ public class DoFnSignatures { builder.setIsBoundedPerElement(inferBoundedness(fnToken, processElement, errors)); + builder.setStateDeclarations(analyzeStateDeclarations(errors, fnClass)); + DoFnSignature signature = builder.build(); // Additional validation for splittable DoFn's. @@ -592,33 +605,129 @@ public class DoFnSignatures { private static Collection<Method> declaredMethodsWithAnnotation( Class<? extends Annotation> anno, Class<?> startClass, Class<?> stopClass) { - Collection<Method> matches = new ArrayList<>(); + return declaredMembersWithAnnotation(anno, startClass, stopClass, GET_METHODS); + } + + private static Collection<Field> declaredFieldsWithAnnotation( + Class<?
extends Annotation> anno, Class<?> startClass, Class<?> stopClass) { + return declaredMembersWithAnnotation(anno, startClass, stopClass, GET_FIELDS); + } + + private static interface MemberGetter<MemberT> { + public MemberT[] getMembers(Class<?> clazz); + } + + // Class::getDeclaredMethods for Java 7 + private static final MemberGetter<Method> GET_METHODS = + new MemberGetter<Method>() { + @Override + public Method[] getMembers(Class<?> clazz) { + return clazz.getDeclaredMethods(); + } + }; + + // Class::getDeclaredFields for Java 7 + private static final MemberGetter<Field> GET_FIELDS = + new MemberGetter<Field>() { + @Override + public Field[] getMembers(Class<?> clazz) { + return clazz.getDeclaredFields(); + } + }; + + private static <MemberT extends AnnotatedElement> + Collection<MemberT> declaredMembersWithAnnotation( + Class<? extends Annotation> anno, + Class<?> startClass, + Class<?> stopClass, + MemberGetter<MemberT> getter) { + Collection<MemberT> matches = new ArrayList<>(); Class<?> clazz = startClass; LinkedHashSet<Class<?>> interfaces = new LinkedHashSet<>(); // First, find all declared methods on the startClass and parents (up to stopClass) while (clazz != null && !clazz.equals(stopClass)) { - for (Method method : clazz.getDeclaredMethods()) { - if (method.isAnnotationPresent(anno)) { - matches.add(method); + for (MemberT member : getter.getMembers(clazz)) { + if (member.isAnnotationPresent(anno)) { + matches.add(member); + } } - Collections.addAll(interfaces, clazz.getInterfaces()); + // Add all interfaces, including transitive + for (TypeDescriptor<?> iface : TypeDescriptor.of(clazz).getInterfaces()) { + interfaces.add(iface.getRawType()); + } clazz = clazz.getSuperclass(); } // Now, iterate over all the discovered interfaces - for (Method method : ReflectHelpers.getClosureOfMethodsOnInterfaces(interfaces)) { - if (method.isAnnotationPresent(anno)) { - matches.add(method); + for (Class<?> iface : interfaces) { + for (MemberT member :
getter.getMembers(iface)) { + if (member.isAnnotationPresent(anno)) { + matches.add(member); + } } } return matches; } + private static ImmutableMap<String, DoFnSignature.StateDeclaration> analyzeStateDeclarations( + ErrorReporter errors, + Class<?> fnClazz) { + + Map<String, DoFnSignature.StateDeclaration> declarations = new HashMap<>(); + + for (Field field : declaredFieldsWithAnnotation(DoFn.StateId.class, fnClazz, DoFn.class)) { + String id = field.getAnnotation(DoFn.StateId.class).value(); + + if (declarations.containsKey(id)) { + errors.throwIllegalArgument( + "Duplicate %s \"%s\", used on both of [%s] and [%s]", + DoFn.StateId.class.getSimpleName(), + id, + field.toString(), + declarations.get(id).field().toString()); + continue; + } + + Class<?> stateSpecRawType = field.getType(); + if (!(stateSpecRawType.equals(StateSpec.class))) { + errors.throwIllegalArgument( + "%s annotation on non-%s field [%s] that has class %s", + DoFn.StateId.class.getSimpleName(), + StateSpec.class.getSimpleName(), + field.toString(), + stateSpecRawType.getName()); + continue; + } + + if (!Modifier.isFinal(field.getModifiers())) { + errors.throwIllegalArgument( + "Non-final field %s annotated with %s. State declarations must be final.", + field.toString(), + DoFn.StateId.class.getSimpleName()); + continue; + } + + Type stateSpecType = field.getGenericType(); + + // By static typing this is already a well-formed State subclass + TypeDescriptor<? extends State> stateType = + (TypeDescriptor<? extends State>) + TypeDescriptor.of(fnClazz) + .resolveType( + TypeDescriptor.of( + ((ParameterizedType) stateSpecType).getActualTypeArguments()[1]) + .getType()); + + declarations.put(id, DoFnSignature.StateDeclaration.create(id, field, stateType)); + } + + return ImmutableMap.copyOf(declarations); + } + private static Method findAnnotatedMethod( ErrorReporter errors, Class<?
extends Annotation> anno, Class<?> fnClazz, boolean required) { Collection<Method> matches = declaredMethodsWithAnnotation(anno, fnClazz, DoFn.class); http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/add34518/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java index 9c7b991..bda696f 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java @@ -60,6 +60,10 @@ import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; import org.apache.beam.sdk.transforms.windowing.FixedWindows; import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.util.common.ElementByteSizeObserver; +import org.apache.beam.sdk.util.state.StateSpec; +import org.apache.beam.sdk.util.state.StateSpecs; +import org.apache.beam.sdk.util.state.ValueState; +import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.PCollectionView; @@ -1450,6 +1454,29 @@ public class ParDoTest implements Serializable { assertThat(displayData, hasDisplayItem("fn", fn.getClass())); } + /** + * Tests that {@link ParDo} currently rejects, with an {@link UnsupportedOperationException}, + * any {@link DoFn} implementation that declares state via {@link DoFn.StateId} annotations.
+ */ + @Test + public void testUnsupportedState() { + thrown.expect(UnsupportedOperationException.class); + thrown.expectMessage("cannot yet be used with state"); + + DoFn<KV<String, String>, KV<String, String>> fn = + new DoFn<KV<String, String>, KV<String, String>>() { + + @StateId("foo") + private final StateSpec<Object, ValueState<Integer>> intState = + StateSpecs.value(VarIntCoder.of()); + + @ProcessElement + public void processElement(ProcessContext c) { } + }; + + ParDo.of(fn); + } + @Test public void testWithOutputTagsDisplayData() { DoFn<String, String> fn = new DoFn<String, String>() { http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/add34518/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java index fc468c9..e040179 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java @@ -18,10 +18,20 @@ package org.apache.beam.sdk.transforms.reflect; import static org.apache.beam.sdk.transforms.reflect.DoFnSignaturesTestUtils.errors; +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertThat; import com.google.common.reflect.TypeToken; +import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.coders.VarLongCoder; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.reflect.DoFnSignaturesTestUtils.FakeDoFn; +import org.apache.beam.sdk.util.state.StateSpec; +import org.apache.beam.sdk.util.state.StateSpecs; +import org.apache.beam.sdk.util.state.ValueState; +import org.apache.beam.sdk.values.KV; +import
org.apache.beam.sdk.values.TypeDescriptor; +import org.hamcrest.Matchers; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; @@ -120,4 +130,137 @@ public class DoFnSignaturesTest { void finishBundle() {} }.getClass()); } + + @Test + public void testStateIdWithWrongType() throws Exception { + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("StateId"); + thrown.expectMessage("StateSpec"); + DoFnSignatures.INSTANCE.getOrParseSignature( + new DoFn<String, String>() { + @StateId("foo") + String bizzle = "bazzle"; + + @ProcessElement + public void foo(ProcessContext context) {} + }.getClass()); + } + + @Test + public void testStateIdDuplicate() throws Exception { + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Duplicate"); + thrown.expectMessage("StateId"); + thrown.expectMessage("my-state-id"); + thrown.expectMessage("myfield1"); + thrown.expectMessage("myfield2"); + DoFnSignatures.INSTANCE.getOrParseSignature( + new DoFn<KV<String, Integer>, Long>() { + @StateId("my-state-id") + private final StateSpec<Object, ValueState<Integer>> myfield1 = + StateSpecs.value(VarIntCoder.of()); + + @StateId("my-state-id") + private final StateSpec<Object, ValueState<Long>> myfield2 = + StateSpecs.value(VarLongCoder.of()); + + @ProcessElement + public void foo(ProcessContext context) {} + }.getClass()); + } + + @Test + public void testStateIdNonFinal() throws Exception { + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("State declarations must be final"); + thrown.expectMessage("Non-final field"); + thrown.expectMessage("myfield"); + DoFnSignature sig = + DoFnSignatures.INSTANCE.getOrParseSignature( + new DoFn<KV<String, Integer>, Long>() { + @StateId("my-state-id") + private StateSpec<Object, ValueState<Integer>> myfield = + StateSpecs.value(VarIntCoder.of()); + + @ProcessElement + public void foo(ProcessContext context) {} + }.getClass()); + } + + @Test + public void
testSimpleStateIdAnonymousDoFn() throws Exception { + DoFnSignature sig = + DoFnSignatures.INSTANCE.getOrParseSignature( + new DoFn<KV<String, Integer>, Long>() { + @StateId("foo") + private final StateSpec<Object, ValueState<Integer>> bizzle = + StateSpecs.value(VarIntCoder.of()); + + @ProcessElement + public void foo(ProcessContext context) {} + }.getClass()); + + assertThat(sig.stateDeclarations().size(), equalTo(1)); + DoFnSignature.StateDeclaration decl = sig.stateDeclarations().get("foo"); + + assertThat(decl.id(), equalTo("foo")); + assertThat(decl.field().getName(), equalTo("bizzle")); + assertThat( + decl.stateType(), + Matchers.<TypeDescriptor<?>>equalTo(new TypeDescriptor<ValueState<Integer>>() {})); + } + + @Test + public void testSimpleStateIdNamedDoFn() throws Exception { + // Test classes at the bottom of the file + DoFnSignature sig = + DoFnSignatures.INSTANCE.signatureForDoFn(new DoFnForTestSimpleStateIdNamedDoFn()); + + assertThat(sig.stateDeclarations().size(), equalTo(1)); + DoFnSignature.StateDeclaration decl = sig.stateDeclarations().get("foo"); + + assertThat(decl.id(), equalTo("foo")); + assertThat( + decl.field(), equalTo(DoFnForTestSimpleStateIdNamedDoFn.class.getDeclaredField("bizzle"))); + assertThat( + decl.stateType(), + Matchers.<TypeDescriptor<?>>equalTo(new TypeDescriptor<ValueState<Integer>>() {})); + } + + @Test + public void testGenericStatefulDoFn() throws Exception { + // Test classes at the bottom of the file + DoFn<KV<String, Integer>, Long> myDoFn = new DoFnForTestGenericStatefulDoFn<Integer>(){}; + + DoFnSignature sig = DoFnSignatures.INSTANCE.signatureForDoFn(myDoFn); + + assertThat(sig.stateDeclarations().size(), equalTo(1)); + DoFnSignature.StateDeclaration decl = sig.stateDeclarations().get("foo"); + + assertThat(decl.id(), equalTo("foo")); + assertThat( + decl.field(), equalTo(DoFnForTestGenericStatefulDoFn.class.getDeclaredField("bizzle"))); + assertThat( + decl.stateType(), + Matchers.<TypeDescriptor<?>>equalTo(new
TypeDescriptor<ValueState<Integer>>() {})); + } + + private static class DoFnForTestSimpleStateIdNamedDoFn extends DoFn<KV<String, Integer>, Long> { + @StateId("foo") + private final StateSpec<Object, ValueState<Integer>> bizzle = + StateSpecs.value(VarIntCoder.of()); + + @ProcessElement + public void foo(ProcessContext context) {} + } + + private static class DoFnForTestGenericStatefulDoFn<T> extends DoFn<KV<String, T>, Long> { + // Note that a real StateSpec here would need a coder for T, which would have to be supplied + // through the constructor; that isn't important for this test, so the spec is left null + @StateId("foo") + private final StateSpec<Object, ValueState<T>> bizzle = null; + + @ProcessElement + public void foo(ProcessContext context) {} + } }