Repository: incubator-beam
Updated Branches:
  refs/heads/master 5bfeb958d -> a0f649eac


Add DoFn.StateId annotation and validation on fields


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/add34518
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/add34518
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/add34518

Branch: refs/heads/master
Commit: add34518bbfb6668a421157e3c1cfaa119a6031b
Parents: 7322616
Author: Kenneth Knowles <k...@google.com>
Authored: Mon Oct 10 21:16:37 2016 -0700
Committer: Kenneth Knowles <k...@google.com>
Committed: Thu Oct 13 19:31:20 2016 -0700

----------------------------------------------------------------------
 .../org/apache/beam/sdk/transforms/DoFn.java    |  44 ++++++
 .../org/apache/beam/sdk/transforms/ParDo.java   |  32 ++++-
 .../sdk/transforms/reflect/DoFnSignature.java   |  25 ++++
 .../sdk/transforms/reflect/DoFnSignatures.java  | 129 +++++++++++++++--
 .../apache/beam/sdk/transforms/ParDoTest.java   |  27 ++++
 .../transforms/reflect/DoFnSignaturesTest.java  | 143 +++++++++++++++++++
 6 files changed, 388 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/add34518/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java
----------------------------------------------------------------------
diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java
index 62da28c..c86693b 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java
@@ -44,6 +44,8 @@ import 
org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.PaneInfo;
 import org.apache.beam.sdk.util.WindowingInternals;
+import org.apache.beam.sdk.util.state.State;
+import org.apache.beam.sdk.util.state.StateSpec;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.TupleTag;
@@ -394,6 +396,48 @@ public abstract class DoFn<InputT, OutputT> implements 
Serializable, HasDisplayD
 
   /////////////////////////////////////////////////////////////////////////////
 
+  /**
+   * Annotation for declaring and dereferencing state cells.
+   *
+   * <p><i>Not currently supported by any runner</i>.
+   *
+   * <p>To declare a state cell, create a field of type {@link StateSpec} annotated with a {@link
+   * StateId}. To use the cell during processing, add a parameter of the appropriate {@link State}
+   * subtype to your {@link ProcessElement @ProcessElement} method, and annotate it with {@link
+   * StateId}. See the following code for an example:
+   *
+   * <pre>{@code
+   * new DoFn<KV<Key, Foo>, Baz>() {
+   *   @StateId("my-state-id")
+   *   private final StateSpec<Key, ValueState<MyState>> myStateSpec =
+   *       StateSpecs.value(new MyStateCoder());
+   *
+   *   @ProcessElement
+   *   public void processElement(
+   *       ProcessContext c,
+   *       @StateId("my-state-id") ValueState<MyState> myState) {
+   *     myState.read();
+   *     myState.write(...);
+   *   }
+   * }
+   * }</pre>
+   *
+   * <p>State is subject to the following validity conditions:
+   *
+   * <ul>
+   * <li>Each state ID must be declared at most once.
+   * <li>Any state referenced in a parameter must be declared with the same state type.
+   * <li>State declarations must be {@code final} fields whose type is exactly {@link StateSpec}.
+   * </ul>
+   */
+  @Documented
+  @Retention(RetentionPolicy.RUNTIME)
+  @Target({ElementType.FIELD, ElementType.PARAMETER})
+  @Experimental(Kind.STATE)
+  public @interface StateId {
+    /** The state ID. */
+    String value();
+  }
 
   /**
    * Annotation for the method to use to prepare an instance for processing 
bundles of elements. The

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/add34518/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
----------------------------------------------------------------------
diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
index fdef908..c5a80c6 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
@@ -29,6 +29,7 @@ import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.runners.PipelineRunner;
 import org.apache.beam.sdk.transforms.display.DisplayData;
 import org.apache.beam.sdk.transforms.display.DisplayData.Builder;
+import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
 import org.apache.beam.sdk.transforms.windowing.WindowFn;
 import org.apache.beam.sdk.util.SerializableUtils;
@@ -519,6 +520,7 @@ public class ParDo {
    * properties can be set on it first.
    */
   public static <InputT, OutputT> Bound<InputT, OutputT> of(DoFn<InputT, 
OutputT> fn) {
+    validate(fn);  // Fail fast, at construction, on unsupported DoFn features such as state.
     return of(adapt(fn), fn.getClass());
   }
 
@@ -544,8 +546,32 @@ public class ParDo {
     return new Unbound().of(fn, fnClass);
   }
 
-  private static <InputT, OutputT> OldDoFn<InputT, OutputT>
-      adapt(DoFn<InputT, OutputT> fn) {
+  /**
+   * Performs common validations of {@code fn}, e.g. that state is used correctly and that its
+   * features can be supported, throwing {@link UnsupportedOperationException} otherwise.
+   */
+  private static <InputT, OutputT> void validate(DoFn<InputT, OutputT> fn) {
+    DoFnSignature signature = DoFnSignatures.INSTANCE.signatureForDoFn(fn);
+
+    // State is semantically incompatible with splitting. Checked first so it stays reachable.
+    if (!signature.stateDeclarations().isEmpty() && signature.processElement().isSplittable()) {
+      throw new UnsupportedOperationException(
+          String.format("%s is splittable and uses state, but these are not compatible",
+              fn.getClass().getName()));
+    }
+
+    // To be removed once the features are complete and runners have adequate rejection logic.
+    if (!signature.stateDeclarations().isEmpty()) {
+      throw new UnsupportedOperationException(
+          String.format("Found %s annotations on %s, but %s cannot yet be used with state.",
+              DoFn.StateId.class.getSimpleName(),
+              fn.getClass().getName(),
+              DoFn.class.getSimpleName()));
+    }
+  }
+
+  /** Converts {@code fn} into an {@link OldDoFn} using {@code DoFnAdapters.toOldDoFn}. */
+  private static <InputT, OutputT> OldDoFn<InputT, OutputT> adapt(DoFn<InputT, OutputT> fn) {
     return DoFnAdapters.toOldDoFn(fn);
   }
 
@@ -622,6 +648,7 @@ public class ParDo {
      * still be specified.
      */
     public <InputT, OutputT> Bound<InputT, OutputT> of(DoFn<InputT, OutputT> 
fn) {
+      validate(fn);  // Fail fast, at construction, on unsupported DoFn features such as state.
       return of(adapt(fn), fn.getClass());
     }
 
@@ -838,6 +865,7 @@ public class ParDo {
      * more properties can still be specified.
      */
     public <InputT> BoundMulti<InputT, OutputT> of(DoFn<InputT, OutputT> fn) {
+      validate(fn);  // Fail fast, at construction, on unsupported DoFn features such as state.
       return of(adapt(fn), fn.getClass());
     }
 

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/add34518/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java
----------------------------------------------------------------------
diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java
 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java
index 632f817..5e261a4 100644
--- 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java
+++ 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java
@@ -19,15 +19,20 @@ package org.apache.beam.sdk.transforms.reflect;
 
 import com.google.auto.value.AutoValue;
 import com.google.common.reflect.TypeToken;
+import java.lang.reflect.Field;
 import java.lang.reflect.Method;
 import java.util.Collections;
 import java.util.List;
+import java.util.Map;
 import javax.annotation.Nullable;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.DoFn.ProcessContinuation;
 import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
+import org.apache.beam.sdk.util.state.State;
+import org.apache.beam.sdk.util.state.StateSpec;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.TypeDescriptor;
 
 /**
  * Describes the signature of a {@link DoFn}, in particular, which features it uses, which extra
@@ -46,6 +51,9 @@ public abstract class DoFnSignature {
   /** Details about this {@link DoFn}'s {@link DoFn.ProcessElement} method. */
   public abstract ProcessElementMethod processElement();
 
+  /** Details about the state cells that this {@link DoFn} declares, keyed by ID. Immutable. */
+  public abstract Map<String, StateDeclaration> stateDeclarations();
+
   /** Details about this {@link DoFn}'s {@link DoFn.StartBundle} method. */
   @Nullable
   public abstract BundleMethod startBundle();
@@ -95,6 +103,7 @@ public abstract class DoFnSignature {
     abstract Builder setSplitRestriction(SplitRestrictionMethod splitRestriction);
     abstract Builder setGetRestrictionCoder(GetRestrictionCoderMethod getRestrictionCoder);
     abstract Builder setNewTracker(NewTrackerMethod newTracker);
+    abstract Builder setStateDeclarations(Map<String, StateDeclaration> stateDeclarations);
     abstract DoFnSignature build();
   }
 
@@ -163,6 +172,27 @@ public abstract class DoFnSignature {
     }
   }
 
+  /**
+   * Describes a state declaration; a field of type {@link StateSpec} annotated with
+   * {@link DoFn.StateId}.
+   */
+  @AutoValue
+  public abstract static class StateDeclaration {
+    /** The state ID, taken from the field's {@link DoFn.StateId} annotation. */
+    public abstract String id();
+
+    /** The {@link StateSpec} field that declares the state cell. */
+    public abstract Field field();
+
+    /** The declared {@link State} type, e.g. {@code ValueState<Integer>}. */
+    public abstract TypeDescriptor<? extends State> stateType();
+
+    static StateDeclaration create(
+        String id, Field field, TypeDescriptor<? extends State> stateType) {
+      return new AutoValue_DoFnSignature_StateDeclaration(id, field, stateType);
+    }
+  }
+
   /** Describes a {@link DoFn.Setup} or {@link DoFn.Teardown} method. */
   @AutoValue
   public abstract static class LifecycleMethod implements DoFnMethod {

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/add34518/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java
index 524ea24..b7f773d 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java
@@ -20,9 +20,12 @@ package org.apache.beam.sdk.transforms.reflect;
 import static com.google.common.base.Preconditions.checkState;
 
 import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.ImmutableMap;
 import com.google.common.reflect.TypeParameter;
 import com.google.common.reflect.TypeToken;
 import java.lang.annotation.Annotation;
+import java.lang.reflect.AnnotatedElement;
+import java.lang.reflect.Field;
 import java.lang.reflect.Method;
 import java.lang.reflect.Modifier;
 import java.lang.reflect.ParameterizedType;
@@ -30,18 +33,21 @@ import java.lang.reflect.Type;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
-import java.util.Collections;
+import java.util.HashMap;
 import java.util.LinkedHashMap;
 import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Map;
 import javax.annotation.Nullable;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.util.common.ReflectHelpers;
+import org.apache.beam.sdk.util.state.State;
+import org.apache.beam.sdk.util.state.StateSpec;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.TypeDescriptor;
 
 /**
  * Parses a {@link DoFn} and computes its {@link DoFnSignature}. See {@link #getOrParseSignature}.
@@ -53,7 +59,12 @@ public class DoFnSignatures {
 
   private final Map<Class<?>, DoFnSignature> signatureCache = new LinkedHashMap<>();
 
-  /** @return the {@link DoFnSignature} for the given {@link DoFn}. */
+  /** @return the {@link DoFnSignature} for the given {@link DoFn} instance. */
+  public <FnT extends DoFn<?, ?>> DoFnSignature signatureForDoFn(FnT fn) {
+    return getOrParseSignature(fn.getClass());
+  }
+
+  /** @return the {@link DoFnSignature} for the given {@link DoFn} subclass. */
   public synchronized <FnT extends DoFn<?, ?>> DoFnSignature getOrParseSignature(Class<FnT> fn) {
     DoFnSignature signature = signatureCache.get(fn);
     if (signature == null) {
@@ -172,6 +183,8 @@ public class DoFnSignatures {
 
     builder.setIsBoundedPerElement(inferBoundedness(fnToken, processElement, errors));
 
+    builder.setStateDeclarations(analyzeStateDeclarations(errors, fnClass));
+
     DoFnSignature signature = builder.build();
 
     // Additional validation for splittable DoFn's.
@@ -592,33 +605,150 @@ public class DoFnSignatures {
 
   private static Collection<Method> declaredMethodsWithAnnotation(
       Class<? extends Annotation> anno, Class<?> startClass, Class<?> stopClass) {
-    Collection<Method> matches = new ArrayList<>();
+    return declaredMembersWithAnnotation(anno, startClass, stopClass, GET_METHODS);
+  }
+
+  private static Collection<Field> declaredFieldsWithAnnotation(
+      Class<? extends Annotation> anno, Class<?> startClass, Class<?> stopClass) {
+    return declaredMembersWithAnnotation(anno, startClass, stopClass, GET_FIELDS);
+  }
+
+  /** Lists the members of one kind (e.g. methods or fields) declared directly on a class. */
+  private interface MemberGetter<MemberT> {
+    MemberT[] getMembers(Class<?> clazz);
+  }
+
+  // Class::getDeclaredMethods for Java 7
+  private static final MemberGetter<Method> GET_METHODS =
+      new MemberGetter<Method>() {
+        @Override
+        public Method[] getMembers(Class<?> clazz) {
+          return clazz.getDeclaredMethods();
+        }
+      };
+
+  // Class::getDeclaredFields for Java 7
+  private static final MemberGetter<Field> GET_FIELDS =
+      new MemberGetter<Field>() {
+        @Override
+        public Field[] getMembers(Class<?> clazz) {
+          return clazz.getDeclaredFields();
+        }
+      };
+
+  /**
+   * Collects members listed by {@code getter} and annotated with {@code anno}, searching
+   * {@code startClass}, its superclasses below {@code stopClass}, and their interfaces.
+   */
+  private static <MemberT extends AnnotatedElement>
+      Collection<MemberT> declaredMembersWithAnnotation(
+          Class<? extends Annotation> anno,
+          Class<?> startClass,
+          Class<?> stopClass,
+          MemberGetter<MemberT> getter) {
+    Collection<MemberT> matches = new ArrayList<>();
 
     Class<?> clazz = startClass;
     LinkedHashSet<Class<?>> interfaces = new LinkedHashSet<>();
 
     // First, find all declared methods on the startClass and parents (up to stopClass)
     while (clazz != null && !clazz.equals(stopClass)) {
-      for (Method method : clazz.getDeclaredMethods()) {
-        if (method.isAnnotationPresent(anno)) {
-          matches.add(method);
+      for (MemberT member : getter.getMembers(clazz)) {
+        if (member.isAnnotationPresent(anno)) {
+          matches.add(member);
         }
       }
 
-      Collections.addAll(interfaces, clazz.getInterfaces());
+      // Add all interfaces, including transitive
+      for (TypeDescriptor<?> iface : TypeDescriptor.of(clazz).getInterfaces()) {
+        interfaces.add(iface.getRawType());
+      }
 
       clazz = clazz.getSuperclass();
     }
 
     // Now, iterate over all the discovered interfaces
-    for (Method method : ReflectHelpers.getClosureOfMethodsOnInterfaces(interfaces)) {
-      if (method.isAnnotationPresent(anno)) {
-        matches.add(method);
+    for (Class<?> iface : interfaces) {
+      for (MemberT member : getter.getMembers(iface)) {
+        if (member.isAnnotationPresent(anno)) {
+          matches.add(member);
+        }
       }
     }
     return matches;
   }
 
+  /**
+   * Collects the fields of {@code fnClazz} (and its supertypes below {@link DoFn}) annotated with
+   * {@link DoFn.StateId}, validating that each is a {@code final} {@link StateSpec} with a unique
+   * state ID. Violations are reported via {@code errors}; the result is keyed by state ID.
+   */
+  private static ImmutableMap<String, DoFnSignature.StateDeclaration> analyzeStateDeclarations(
+      ErrorReporter errors,
+      Class<?> fnClazz) {
+
+    Map<String, DoFnSignature.StateDeclaration> declarations = new HashMap<>();
+
+    for (Field field : declaredFieldsWithAnnotation(DoFn.StateId.class, fnClazz, DoFn.class)) {
+      String id = field.getAnnotation(DoFn.StateId.class).value();
+
+      if (declarations.containsKey(id)) {
+        errors.throwIllegalArgument(
+            "Duplicate %s \"%s\", used on both of [%s] and [%s]",
+            DoFn.StateId.class.getSimpleName(),
+            id,
+            field.toString(),
+            declarations.get(id).field().toString());
+        continue;
+      }
+
+      Class<?> stateSpecRawType = field.getType();
+      if (!(stateSpecRawType.equals(StateSpec.class))) {
+        errors.throwIllegalArgument(
+            "%s annotation on non-%s field [%s] that has class %s",
+            DoFn.StateId.class.getSimpleName(),
+            StateSpec.class.getSimpleName(),
+            field.toString(),
+            stateSpecRawType.getName());
+        continue;
+      }
+
+      if (!Modifier.isFinal(field.getModifiers())) {
+        errors.throwIllegalArgument(
+            "Non-final field %s annotated with %s. State declarations must be final.",
+            field.toString(),
+            DoFn.StateId.class.getSimpleName());
+        continue;
+      }
+
+      Type stateSpecType = field.getGenericType();
+
+      // A raw StateSpec field carries no state type; report it instead of failing with a
+      // ClassCastException on the cast below.
+      if (!(stateSpecType instanceof ParameterizedType)) {
+        errors.throwIllegalArgument(
+            "%s field [%s] annotated with %s must declare its type arguments",
+            StateSpec.class.getSimpleName(),
+            field.toString(),
+            DoFn.StateId.class.getSimpleName());
+        continue;
+      }
+
+      // The second type argument of the StateSpec is the state type. It may mention type
+      // variables of the DoFn (e.g. ValueState<T>), so resolve it against fnClazz. By static
+      // typing this is already a well-formed State subtype, so the unchecked cast is safe.
+      @SuppressWarnings("unchecked")
+      TypeDescriptor<? extends State> stateType =
+          (TypeDescriptor<? extends State>)
+              TypeDescriptor.of(fnClazz)
+                  .resolveType(((ParameterizedType) stateSpecType).getActualTypeArguments()[1]);
+
+      declarations.put(id, DoFnSignature.StateDeclaration.create(id, field, stateType));
+    }
+
+    return ImmutableMap.copyOf(declarations);
+  }
+
   private static Method findAnnotatedMethod(
       ErrorReporter errors, Class<? extends Annotation> anno, Class<?> 
fnClazz, boolean required) {
     Collection<Method> matches = declaredMethodsWithAnnotation(anno, fnClazz, 
DoFn.class);

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/add34518/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
----------------------------------------------------------------------
diff --git 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
index 9c7b991..bda696f 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
@@ -60,6 +60,10 @@ import 
org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
 import org.apache.beam.sdk.transforms.windowing.FixedWindows;
 import org.apache.beam.sdk.transforms.windowing.Window;
 import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
+import org.apache.beam.sdk.util.state.StateSpec;
+import org.apache.beam.sdk.util.state.StateSpecs;
+import org.apache.beam.sdk.util.state.ValueState;
+import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionTuple;
 import org.apache.beam.sdk.values.PCollectionView;
@@ -1450,6 +1454,29 @@ public class ParDoTest implements Serializable {
     assertThat(displayData, hasDisplayItem("fn", fn.getClass()));
   }
 
+  /**
+   * Tests that {@link ParDo#of}, for now, rejects any {@link DoFn} that declares state through
+   * {@link DoFn.StateId} annotations.
+   */
+  @Test
+  public void testUnsupportedState() {
+    thrown.expect(UnsupportedOperationException.class);
+    thrown.expectMessage("cannot yet be used with state");
+
+    DoFn<KV<String, String>, KV<String, String>> fn =
+        new DoFn<KV<String, String>, KV<String, String>>() {
+          @StateId("foo")
+          private final StateSpec<Object, ValueState<Integer>> intState =
+              StateSpecs.value(VarIntCoder.of());
+
+          @ProcessElement
+          public void processElement(ProcessContext c) {}
+        };
+
+    // Constructing the transform is enough to trigger validation.
+    ParDo.of(fn);
+  }
+
   @Test
   public void testWithOutputTagsDisplayData() {
     DoFn<String, String> fn = new DoFn<String, String>() {

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/add34518/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java
----------------------------------------------------------------------
diff --git 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java
 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java
index fc468c9..e040179 100644
--- 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java
+++ 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java
@@ -18,10 +18,20 @@
 package org.apache.beam.sdk.transforms.reflect;
 
 import static 
org.apache.beam.sdk.transforms.reflect.DoFnSignaturesTestUtils.errors;
+import static org.hamcrest.Matchers.equalTo;
+import static org.junit.Assert.assertThat;
 
 import com.google.common.reflect.TypeToken;
+import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.coders.VarLongCoder;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignaturesTestUtils.FakeDoFn;
+import org.apache.beam.sdk.util.state.StateSpec;
+import org.apache.beam.sdk.util.state.StateSpecs;
+import org.apache.beam.sdk.util.state.ValueState;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.TypeDescriptor;
+import org.hamcrest.Matchers;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
@@ -120,4 +130,137 @@ public class DoFnSignaturesTest {
           void finishBundle() {}
         }.getClass());
   }
+
+  @Test
+  public void testStateIdWithWrongType() throws Exception {
+    thrown.expect(IllegalArgumentException.class);
+    thrown.expectMessage("StateId");
+    thrown.expectMessage("StateSpec");
+    DoFnSignatures.INSTANCE.getOrParseSignature(
+        new DoFn<String, String>() {
+          @StateId("foo")
+          String bizzle = "bazzle"; // Not a StateSpec, so this declaration is rejected.
+
+          @ProcessElement
+          public void foo(ProcessContext context) {}
+        }.getClass());
+  }
+
+  @Test
+  public void testStateIdDuplicate() throws Exception {
+    thrown.expect(IllegalArgumentException.class);
+    thrown.expectMessage("Duplicate");
+    thrown.expectMessage("StateId");
+    thrown.expectMessage("my-state-id");
+    thrown.expectMessage("myfield1");
+    thrown.expectMessage("myfield2");
+    DoFnSignatures.INSTANCE.getOrParseSignature(
+        new DoFn<KV<String, Integer>, Long>() {
+          // Both are valid final StateSpecs, so field reflection order cannot change the error.
+          @StateId("my-state-id")
+          private final StateSpec<Object, ValueState<Integer>> myfield1 =
+              StateSpecs.value(VarIntCoder.of());
+          @StateId("my-state-id")
+          private final StateSpec<Object, ValueState<Long>> myfield2 =
+              StateSpecs.value(VarLongCoder.of());
+
+          @ProcessElement
+          public void foo(ProcessContext context) {}
+        }.getClass());
+  }
+
+  @Test
+  public void testStateIdNonFinal() throws Exception {
+    thrown.expect(IllegalArgumentException.class);
+    thrown.expectMessage("State declarations must be final");
+    thrown.expectMessage("Non-final field");
+    thrown.expectMessage("myfield");
+    DoFnSignatures.INSTANCE.getOrParseSignature(
+        new DoFn<KV<String, Integer>, Long>() {
+          // Deliberately not final; otherwise a valid state declaration.
+          @StateId("my-state-id")
+          private StateSpec<Object, ValueState<Integer>> myfield =
+              StateSpecs.value(VarIntCoder.of());
+
+          @ProcessElement
+          public void foo(ProcessContext context) {}
+        }.getClass());
+  }
+
+  @Test
+  public void testSimpleStateIdAnonymousDoFn() throws Exception {
+    DoFnSignature sig =
+        DoFnSignatures.INSTANCE.getOrParseSignature(
+            new DoFn<KV<String, Integer>, Long>() {
+              @StateId("foo")
+              private final StateSpec<Object, ValueState<Integer>> bizzle =
+                  StateSpecs.value(VarIntCoder.of());
+
+              @ProcessElement
+              public void foo(ProcessContext context) {}
+            }.getClass());
+
+    assertThat(sig.stateDeclarations().size(), equalTo(1));
+    DoFnSignature.StateDeclaration decl = sig.stateDeclarations().get("foo");
+
+    assertThat(decl.id(), equalTo("foo"));
+    assertThat(decl.field().getName(), equalTo("bizzle"));
+    assertThat(
+        decl.stateType(),
+        Matchers.<TypeDescriptor<?>>equalTo(new TypeDescriptor<ValueState<Integer>>() {}));
+  }
+
+  @Test
+  public void testSimpleStateIdNamedDoFn() throws Exception {
+    // Test classes at the bottom of the file
+    DoFnSignature sig =
+        DoFnSignatures.INSTANCE.signatureForDoFn(new DoFnForTestSimpleStateIdNamedDoFn());
+
+    assertThat(sig.stateDeclarations().size(), equalTo(1));
+    DoFnSignature.StateDeclaration decl = sig.stateDeclarations().get("foo");
+
+    assertThat(decl.id(), equalTo("foo"));
+    assertThat(
+        decl.field(), equalTo(DoFnForTestSimpleStateIdNamedDoFn.class.getDeclaredField("bizzle")));
+    assertThat(
+        decl.stateType(),
+        Matchers.<TypeDescriptor<?>>equalTo(new TypeDescriptor<ValueState<Integer>>() {}));
+  }
+
+  @Test
+  public void testGenericStatefulDoFn() throws Exception {
+    // Test classes at the bottom of the file
+    DoFn<KV<String, Integer>, Long> myDoFn = new DoFnForTestGenericStatefulDoFn<Integer>(){};
+
+    DoFnSignature sig = DoFnSignatures.INSTANCE.signatureForDoFn(myDoFn);
+
+    assertThat(sig.stateDeclarations().size(), equalTo(1));
+    DoFnSignature.StateDeclaration decl = sig.stateDeclarations().get("foo");
+
+    assertThat(decl.id(), equalTo("foo"));
+    assertThat(
+        decl.field(), equalTo(DoFnForTestGenericStatefulDoFn.class.getDeclaredField("bizzle")));
+    assertThat(
+        decl.stateType(),
+        Matchers.<TypeDescriptor<?>>equalTo(new TypeDescriptor<ValueState<Integer>>() {}));
+  }
+
+  private static class DoFnForTestSimpleStateIdNamedDoFn extends DoFn<KV<String, Integer>, Long> {
+    @StateId("foo")
+    private final StateSpec<Object, ValueState<Integer>> bizzle =
+        StateSpecs.value(VarIntCoder.of());
+
+    @ProcessElement
+    public void foo(ProcessContext context) {}
+  }
+
+  private static class DoFnForTestGenericStatefulDoFn<T> extends DoFn<KV<String, T>, Long> {
+    // Note that in order to have a coder for T it will require initialization in the constructor,
+    // but that isn't important for this test
+    @StateId("foo")
+    private final StateSpec<Object, ValueState<T>> bizzle = null;
+
+    @ProcessElement
+    public void foo(ProcessContext context) {}
+  }
 }

Reply via email to