Repository: beam
Updated Branches:
  refs/heads/master 03dce6dcc -> e31ca8b0d


[BEAM-1337] Infer state coders


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/42e690e8
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/42e690e8
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/42e690e8

Branch: refs/heads/master
Commit: 42e690e84a9f05d508f2528b1444b26ce031e080
Parents: 03dce6d
Author: Aviem Zur <aviem...@gmail.com>
Authored: Wed Mar 1 07:27:57 2017 +0200
Committer: Aviem Zur <aviem...@gmail.com>
Committed: Sat Apr 1 10:27:14 2017 +0300

----------------------------------------------------------------------
 .../org/apache/beam/sdk/transforms/ParDo.java   |  62 ++
 .../apache/beam/sdk/util/state/StateSpec.java   |  15 +
 .../apache/beam/sdk/util/state/StateSpecs.java  | 264 ++++++++-
 .../apache/beam/sdk/transforms/ParDoTest.java   | 578 +++++++++++++++++++
 4 files changed, 902 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/42e690e8/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
----------------------------------------------------------------------
diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
index 664fbc3..3de845b 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
@@ -22,6 +22,8 @@ import static 
com.google.common.base.Preconditions.checkArgument;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import java.io.Serializable;
+import java.lang.reflect.ParameterizedType;
+import java.lang.reflect.Type;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
@@ -29,6 +31,7 @@ import java.util.Map;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.CannotProvideCoderException;
 import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderRegistry;
 import org.apache.beam.sdk.runners.PipelineRunner;
 import org.apache.beam.sdk.transforms.display.DisplayData;
 import org.apache.beam.sdk.transforms.display.DisplayData.Builder;
@@ -41,6 +44,7 @@ import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.WindowFn;
 import org.apache.beam.sdk.util.NameUtils;
 import org.apache.beam.sdk.util.SerializableUtils;
+import org.apache.beam.sdk.util.state.StateSpec;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionTuple;
 import org.apache.beam.sdk.values.PCollectionView;
@@ -434,6 +438,59 @@ public class ParDo {
     return DisplayData.item("fn", fn.getClass()).withLabel("Transform 
Function");
   }
 
+  /**
+   * Offers inferred coders to every {@link StateSpec} declared in the signature of {@code fn},
+   * then finalizes each spec.
+   *
+   * <p>The spec instances are read reflectively from the state-declaration fields of
+   * {@code fn}; an {@link IllegalAccessException} from that read is rethrown wrapped in a
+   * {@link RuntimeException}.
+   */
+  private static void finishSpecifyingStateSpecs(
+      DoFn<?, ?> fn,
+      CoderRegistry coderRegistry,
+      Coder<?> inputCoder) {
+    DoFnSignature signature = DoFnSignatures.getSignature(fn.getClass());
+    Map<String, DoFnSignature.StateDeclaration> stateDeclarations = 
signature.stateDeclarations();
+    for (DoFnSignature.StateDeclaration stateDeclaration : 
stateDeclarations.values()) {
+      try {
+        StateSpec<?, ?> stateSpec = (StateSpec<?, ?>) 
stateDeclaration.field().get(fn);
+        // Offer before finishing: finishSpecifying() is where a spec validates itself.
+        stateSpec.offerCoders(codersForStateSpecTypes(stateDeclaration, 
coderRegistry, inputCoder));
+        stateSpec.finishSpecifying();
+      } catch (IllegalAccessException e) {
+        throw new RuntimeException(e);
+      }
+    }
+  }
+
+  /**
+   * Try to provide coders for as many of the type arguments of given
+   * {@link DoFnSignature.StateDeclaration} as possible.
+   *
+   * @return an array with one slot per type argument of the declared state type, in declaration
+   *     order; a slot is {@code null} when no coder could be provided, and the array is empty
+   *     when the state type is not a {@link ParameterizedType}
+   */
+  private static <InputT> Coder[] codersForStateSpecTypes(
+      DoFnSignature.StateDeclaration stateDeclaration,
+      CoderRegistry coderRegistry,
+      Coder<InputT> inputCoder) {
+    Type stateType = stateDeclaration.stateType().getType();
+
+    if (!(stateType instanceof ParameterizedType)) {
+      // No type arguments means no coders to infer.
+      return new Coder[0];
+    }
+
+    Type[] typeArguments = ((ParameterizedType) 
stateType).getActualTypeArguments();
+    Coder[] coders = new Coder[typeArguments.length];
+
+    for (int i = 0; i < typeArguments.length; i++) {
+      Type typeArgument = typeArguments[i];
+      TypeDescriptor<?> typeDescriptor = TypeDescriptor.of(typeArgument);
+      try {
+        // First attempt: look the type argument up in the registry on its own.
+        coders[i] = coderRegistry.getDefaultCoder(typeDescriptor);
+      } catch (CannotProvideCoderException e) {
+        try {
+          // Second attempt: repeat the lookup, additionally passing the input coder and the
+          // type it encodes.
+          coders[i] = coderRegistry.getDefaultCoder(
+              typeDescriptor, inputCoder.getEncodedTypeDescriptor(), 
inputCoder);
+        } catch (CannotProvideCoderException ignored) {
+          // Since not all type arguments will have a registered coder we 
ignore this exception.
+        }
+      }
+    }
+
+    return coders;
+  }
+
   /**
    * Perform common validations of the {@link DoFn} against the input {@link 
PCollection}, for
    * example ensuring that the window type expected by the {@link DoFn} 
matches the window type of
@@ -557,6 +614,7 @@ public class ParDo {
 
     @Override
     public PCollection<OutputT> expand(PCollection<? extends InputT> input) {
+      finishSpecifyingStateSpecs(fn, input.getPipeline().getCoderRegistry(), 
input.getCoder());
       TupleTag<OutputT> mainOutput = new TupleTag<>();
       return input.apply(withOutputTags(mainOutput, 
TupleTagList.empty())).get(mainOutput);
     }
@@ -681,6 +739,10 @@ public class ParDo {
     public PCollectionTuple expand(PCollection<? extends InputT> input) {
       // SplittableDoFn should be forbidden on the runner-side.
       validateWindowType(input, fn);
+
+      // Use coder registry to determine coders for all StateSpec defined in 
the fn signature.
+      finishSpecifyingStateSpecs(fn, input.getPipeline().getCoderRegistry(), 
input.getCoder());
+
       PCollectionTuple outputs = PCollectionTuple.ofPrimitiveOutputsInternal(
           input.getPipeline(),
           TupleTagList.of(mainOutputTag).and(sideOutputTags.getAll()),

http://git-wip-us.apache.org/repos/asf/beam/blob/42e690e8/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateSpec.java
----------------------------------------------------------------------
diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateSpec.java 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateSpec.java
index 4fdeefb..6b94c40 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateSpec.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateSpec.java
@@ -20,6 +20,7 @@ package org.apache.beam.sdk.util.state;
 import java.io.Serializable;
 import org.apache.beam.sdk.annotations.Experimental;
 import org.apache.beam.sdk.annotations.Experimental.Kind;
+import org.apache.beam.sdk.coders.Coder;
 
 /**
  * A specification of a persistent state cell. This includes information 
necessary to encode the
@@ -36,4 +37,18 @@ public interface StateSpec<K, StateT extends State> extends 
Serializable {
    * Use the {@code binder} to create an instance of {@code StateT} 
appropriate for this address.
    */
   StateT bind(String id, StateBinder<? extends K> binder);
+
+  /**
+   * Given {@code coders} are inferred from type arguments defined for this class.
+   * Coders which are already set should take precedence over offered coders.
+   * @param coders Array of coders indexed by the type arguments order. 
Entries might be null if
+   *               the coder could not be inferred.
+   */
+  void offerCoders(Coder[] coders);
+
+  /**
+   * Validates that this {@link StateSpec} has been specified correctly and 
finalizes it.
+   * Automatically invoked when the pipeline is built.
+   */
+  void finishSpecifying();
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/42e690e8/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateSpecs.java
----------------------------------------------------------------------
diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateSpecs.java 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateSpecs.java
index 8912993..6a8c80b 100644
--- 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateSpecs.java
+++ 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateSpecs.java
@@ -17,7 +17,10 @@
  */
 package org.apache.beam.sdk.util.state;
 
+import static com.google.common.base.Preconditions.checkArgument;
+
 import java.util.Objects;
+import javax.annotation.Nullable;
 import org.apache.beam.sdk.annotations.Experimental;
 import org.apache.beam.sdk.annotations.Experimental.Kind;
 import org.apache.beam.sdk.coders.CannotProvideCoderException;
@@ -44,7 +47,13 @@ public class StateSpecs {
   private StateSpecs() {}
 
   /** Create a simple state spec for values of type {@code T}. */
+  public static <T> StateSpec<Object, ValueState<T>> value() {
+    return new ValueStateSpec<>(null);
+  }
+
+  /** Create a simple state spec for values of type {@code T}. */
   public static <T> StateSpec<Object, ValueState<T>> value(Coder<T> 
valueCoder) {
+    checkArgument(valueCoder != null, "valueCoder should not be null. Consider 
value() instead");
     return new ValueStateSpec<>(valueCoder);
   }
 
@@ -53,8 +62,21 @@ public class StateSpecs {
    * {@code InputT}s into a single {@code OutputT}.
    */
   public static <InputT, AccumT, OutputT>
+  StateSpec<Object, AccumulatorCombiningState<InputT, AccumT, OutputT>> 
combiningValue(
+      CombineFn<InputT, AccumT, OutputT> combineFn) {
+    return new CombiningValueStateSpec<InputT, AccumT, OutputT>(null, 
combineFn);
+  }
+
+  /**
+   * Create a state spec for values that use a {@link CombineFn} to 
automatically merge multiple
+   * {@code InputT}s into a single {@code OutputT}.
+   */
+  public static <InputT, AccumT, OutputT>
       StateSpec<Object, AccumulatorCombiningState<InputT, AccumT, OutputT>> 
combiningValue(
           Coder<AccumT> accumCoder, CombineFn<InputT, AccumT, OutputT> 
combineFn) {
+    checkArgument(accumCoder != null,
+        "accumCoder should not be null. "
+            + "Consider using combiningValue(CombineFn<> combineFn) instead.");
     return combiningValueInternal(accumCoder, combineFn);
   }
 
@@ -63,8 +85,21 @@ public class StateSpecs {
    * multiple {@code InputT}s into a single {@code OutputT}.
    */
   public static <K, InputT, AccumT, OutputT>
+  StateSpec<K, AccumulatorCombiningState<InputT, AccumT, OutputT>> 
keyedCombiningValue(
+      KeyedCombineFn<K, InputT, AccumT, OutputT> combineFn) {
+    return new KeyedCombiningValueStateSpec<K, InputT, AccumT, OutputT>(null, 
combineFn);
+  }
+
+  /**
+   * Create a state spec for values that use a {@link KeyedCombineFn} to 
automatically merge
+   * multiple {@code InputT}s into a single {@code OutputT}.
+   */
+  public static <K, InputT, AccumT, OutputT>
       StateSpec<K, AccumulatorCombiningState<InputT, AccumT, OutputT>> 
keyedCombiningValue(
           Coder<AccumT> accumCoder, KeyedCombineFn<K, InputT, AccumT, OutputT> 
combineFn) {
+    checkArgument(accumCoder != null,
+        "accumCoder should not be null. "
+            + "Consider using keyedCombiningValue(KeyedCombineFn<> combineFn) 
instead.");
     return keyedCombiningValueInternal(accumCoder, combineFn);
   }
 
@@ -73,10 +108,23 @@ public class StateSpecs {
    * merge multiple {@code InputT}s into a single {@code OutputT}.
    */
   public static <K, InputT, AccumT, OutputT>
+  StateSpec<K, AccumulatorCombiningState<InputT, AccumT, OutputT>>
+  keyedCombiningValueWithContext(KeyedCombineFnWithContext<K, InputT, AccumT, 
OutputT> combineFn) {
+    return new KeyedCombiningValueWithContextStateSpec<K, InputT, AccumT, 
OutputT>(null, combineFn);
+  }
+
+  /**
+   * Create a state spec for values that use a {@link 
KeyedCombineFnWithContext} to automatically
+   * merge multiple {@code InputT}s into a single {@code OutputT}.
+   */
+  public static <K, InputT, AccumT, OutputT>
       StateSpec<K, AccumulatorCombiningState<InputT, AccumT, OutputT>>
           keyedCombiningValueWithContext(
               Coder<AccumT> accumCoder,
               KeyedCombineFnWithContext<K, InputT, AccumT, OutputT> combineFn) 
{
+    checkArgument(accumCoder != null,
+        "accumCoder should not be null. Consider using "
+            + "keyedCombiningValueWithContext(KeyedCombineFnWithContext<> 
combineFn) instead.");
     return new KeyedCombiningValueWithContextStateSpec<K, InputT, AccumT, 
OutputT>(
         accumCoder, combineFn);
   }
@@ -121,8 +169,23 @@ public class StateSpecs {
    * Create a state spec that is optimized for adding values frequently, and 
occasionally retrieving
    * all the values that have been added.
    */
+  public static <T> StateSpec<Object, BagState<T>> bag() {
+    return bag(null);
+  }
+
+  /**
+   * Create a state spec that is optimized for adding values frequently, and 
occasionally retrieving
+   * all the values that have been added.
+   */
   public static <T> StateSpec<Object, BagState<T>> bag(Coder<T> elemCoder) {
-    return new BagStateSpec<T>(elemCoder);
+    return new BagStateSpec<>(elemCoder);
+  }
+
+  /**
+   * Create a state spec that supports {@link java.util.Set}-like access patterns, without
+   * specifying an element coder: {@code null} is passed on as the coder.
+   */
+  public static <T> StateSpec<Object, SetState<T>> set() {
+    return set(null);
   }
 
   /**
@@ -135,6 +198,13 @@ public class StateSpecs {
   /**
    * Create a state spec that supporting for {@link java.util.Map} like access 
patterns.
    */
+  public static <K, V> StateSpec<Object, MapState<K, V>> map() {
+    return new MapStateSpec<>(null, null);
+  }
+
+  /**
+   * Create a state spec that supports {@link java.util.Map}-like access patterns.
+   */
   public static <K, V> StateSpec<Object, MapState<K, V>> map(Coder<K> keyCoder,
                                                              Coder<V> 
valueCoder) {
     return new MapStateSpec<>(keyCoder, valueCoder);
@@ -174,9 +244,10 @@ public class StateSpecs {
    */
   private static class ValueStateSpec<T> implements StateSpec<Object, 
ValueState<T>> {
 
-    private final Coder<T> coder;
+    @Nullable
+    private Coder<T> coder;
 
-    private ValueStateSpec(Coder<T> coder) {
+    private ValueStateSpec(@Nullable Coder<T> coder) {
       this.coder = coder;
     }
 
@@ -185,6 +256,25 @@ public class StateSpecs {
       return visitor.bindValue(id, this, coder);
     }
 
+    /**
+     * Adopts the first offered coder as the value coder, unless a coder is already set.
+     *
+     * @param coders coders indexed by the type arguments of the declared state type; the array
+     *     may be empty (state type without type arguments) and entries may be {@code null}
+     */
+    @SuppressWarnings("unchecked")
+    @Override
+    public void offerCoders(Coder[] coders) {
+      // The length check lets an empty offer fall through to the descriptive failure in
+      // finishSpecifying() instead of an ArrayIndexOutOfBoundsException here.
+      if (this.coder == null && coders.length > 0 && coders[0] != null) {
+        this.coder = (Coder<T>) coders[0];
+      }
+    }
+
+    /**
+     * Verifies that a value coder is present, whether given at construction or accepted in
+     * {@code offerCoders}.
+     *
+     * @throws IllegalStateException if no coder has been set
+     */
+    @Override public void finishSpecifying() {
+      if (coder == null) {
+        throw new IllegalStateException("Unable to infer a coder for 
ValueState and no Coder"
+            + " was specified. Please set a coder by either invoking"
+            + " StateSpecs.value(Coder<T> valueCoder) or by registering the 
coder in the"
+            + " Pipeline's CoderRegistry.");
+      }
+    }
+
     @Override
     public boolean equals(Object obj) {
       if (obj == this) {
@@ -214,15 +304,32 @@ public class StateSpecs {
       extends KeyedCombiningValueStateSpec<Object, InputT, AccumT, OutputT>
       implements StateSpec<Object, AccumulatorCombiningState<InputT, AccumT, 
OutputT>> {
 
-    private final Coder<AccumT> accumCoder;
+    @Nullable
+    private Coder<AccumT> accumCoder;
     private final CombineFn<InputT, AccumT, OutputT> combineFn;
 
     private CombiningValueStateSpec(
-        Coder<AccumT> accumCoder, CombineFn<InputT, AccumT, OutputT> 
combineFn) {
+        @Nullable Coder<AccumT> accumCoder,
+        CombineFn<InputT, AccumT, OutputT> combineFn) {
       super(accumCoder, combineFn.asKeyedFn());
       this.combineFn = combineFn;
       this.accumCoder = accumCoder;
     }
+
+    @Override
+    protected Coder<AccumT> getAccumCoder() {
+      return accumCoder;
+    }
+
+    /**
+     * Adopts the offered accumulator coder, unless one is already set.
+     *
+     * <p>The offered array follows the type arguments of
+     * {@code AccumulatorCombiningState<InputT, AccumT, OutputT>}, so the accumulator coder is
+     * at index 1.
+     *
+     * @param coders coders indexed by the type arguments of the declared state type; the array
+     *     may be shorter than expected and entries may be {@code null}
+     */
+    @SuppressWarnings("unchecked")
+    @Override
+    public void offerCoders(Coder[] coders) {
+      // The length check avoids an ArrayIndexOutOfBoundsException on a short or empty offer;
+      // a missing coder is then reported by finishSpecifying().
+      if (this.accumCoder == null && coders.length > 1 && coders[1] != null) {
+        this.accumCoder = (Coder<AccumT>) coders[1];
+      }
+    }
   }
 
   /**
@@ -234,11 +341,13 @@ public class StateSpecs {
   private static class KeyedCombiningValueWithContextStateSpec<K, InputT, 
AccumT, OutputT>
       implements StateSpec<K, AccumulatorCombiningState<InputT, AccumT, 
OutputT>> {
 
-    private final Coder<AccumT> accumCoder;
+    @Nullable
+    private Coder<AccumT> accumCoder;
     private final KeyedCombineFnWithContext<K, InputT, AccumT, OutputT> 
combineFn;
 
     protected KeyedCombiningValueWithContextStateSpec(
-        Coder<AccumT> accumCoder, KeyedCombineFnWithContext<K, InputT, AccumT, 
OutputT> combineFn) {
+        @Nullable Coder<AccumT> accumCoder,
+        KeyedCombineFnWithContext<K, InputT, AccumT, OutputT> combineFn) {
       this.combineFn = combineFn;
       this.accumCoder = accumCoder;
     }
@@ -249,6 +358,27 @@ public class StateSpecs {
       return visitor.bindKeyedCombiningValueWithContext(id, this, accumCoder, 
combineFn);
     }
 
+    /**
+     * Adopts the offered accumulator coder, unless one is already set.
+     *
+     * <p>The offered array follows the type arguments of the declared state type,
+     * {@code AccumulatorCombiningState<InputT, AccumT, OutputT>}, not the type parameters of
+     * this class. The accumulator coder is therefore at index 1; index 2 holds the
+     * {@code OutputT} coder and must not be used as the accumulator coder.
+     *
+     * @param coders coders indexed by the type arguments of the declared state type; the array
+     *     may be shorter than expected and entries may be {@code null}
+     */
+    @SuppressWarnings("unchecked")
+    @Override
+    public void offerCoders(Coder[] coders) {
+      // The length check avoids an ArrayIndexOutOfBoundsException on a short or empty offer;
+      // a missing coder is then reported by finishSpecifying().
+      if (this.accumCoder == null && coders.length > 1 && coders[1] != null) {
+        this.accumCoder = (Coder<AccumT>) coders[1];
+      }
+    }
+
+    /**
+     * Verifies that an accumulator coder is present, whether given at construction or accepted
+     * in {@code offerCoders}.
+     *
+     * @throws IllegalStateException if no accumulator coder has been set
+     */
+    @Override public void finishSpecifying() {
+      if (accumCoder == null) {
+        // The message points at the factory that actually builds this spec.
+        throw new IllegalStateException("Unable to infer a coder for"
+            + " KeyedCombiningValueWithContextState and no Coder was specified."
+            + " Please set a coder by either invoking"
+            + " StateSpecs.keyedCombiningValueWithContext(Coder<AccumT> accumCoder,"
+            + " KeyedCombineFnWithContext<K, InputT, AccumT, OutputT> combineFn)"
+            + " or by registering the coder in the Pipeline's CoderRegistry.");
+      }
+    }
+
     @Override
     public boolean equals(Object obj) {
       if (obj == this) {
@@ -282,19 +412,45 @@ public class StateSpecs {
   private static class KeyedCombiningValueStateSpec<K, InputT, AccumT, OutputT>
       implements StateSpec<K, AccumulatorCombiningState<InputT, AccumT, 
OutputT>> {
 
-    private final Coder<AccumT> accumCoder;
+    @Nullable
+    private Coder<AccumT> accumCoder;
     private final KeyedCombineFn<K, InputT, AccumT, OutputT> keyedCombineFn;
 
     protected KeyedCombiningValueStateSpec(
-        Coder<AccumT> accumCoder, KeyedCombineFn<K, InputT, AccumT, OutputT> 
keyedCombineFn) {
+        @Nullable Coder<AccumT> accumCoder,
+        KeyedCombineFn<K, InputT, AccumT, OutputT> keyedCombineFn) {
       this.keyedCombineFn = keyedCombineFn;
       this.accumCoder = accumCoder;
     }
 
+    protected Coder<AccumT> getAccumCoder() {
+      return accumCoder;
+    }
+
     @Override
     public AccumulatorCombiningState<InputT, AccumT, OutputT> bind(
         String id, StateBinder<? extends K> visitor) {
-      return visitor.bindKeyedCombiningValue(id, this, accumCoder, 
keyedCombineFn);
+      return visitor.bindKeyedCombiningValue(id, this, getAccumCoder(), 
keyedCombineFn);
+    }
+
+    /**
+     * Adopts the offered accumulator coder, unless one is already set.
+     *
+     * <p>The offered array follows the type arguments of the declared state type,
+     * {@code AccumulatorCombiningState<InputT, AccumT, OutputT>}, not the type parameters of
+     * this class. The accumulator coder is therefore at index 1; index 2 holds the
+     * {@code OutputT} coder and must not be used as the accumulator coder.
+     *
+     * @param coders coders indexed by the type arguments of the declared state type; the array
+     *     may be shorter than expected and entries may be {@code null}
+     */
+    @SuppressWarnings("unchecked")
+    @Override
+    public void offerCoders(Coder[] coders) {
+      // The length check avoids an ArrayIndexOutOfBoundsException on a short or empty offer;
+      // a missing coder is then reported by finishSpecifying().
+      if (this.accumCoder == null && coders.length > 1 && coders[1] != null) {
+        this.accumCoder = (Coder<AccumT>) coders[1];
+      }
+    }
+
+    /**
+     * Verifies that an accumulator coder is present.
+     *
+     * @throws IllegalStateException if {@code getAccumCoder()} returns {@code null}
+     */
+    @Override public void finishSpecifying() {
+      // Goes through getAccumCoder() rather than the field so that a subclass overriding the
+      // getter is checked against its own coder.
+      if (getAccumCoder() == null) {
+        throw new IllegalStateException("Unable to infer a coder for 
CombiningState and no"
+            + " Coder was specified. Please set a coder by either invoking"
+            + " StateSpecs.combiningValue(Coder<AccumT> accumCoder,"
+            + " CombineFn<InputT, AccumT, OutputT> combineFn)"
+            + " or by registering the coder in the Pipeline's CoderRegistry.");
+      }
     }
 
     @Override
@@ -330,9 +486,10 @@ public class StateSpecs {
    */
   private static class BagStateSpec<T> implements StateSpec<Object, 
BagState<T>> {
 
-    private final Coder<T> elemCoder;
+    @Nullable
+    private Coder<T> elemCoder;
 
-    private BagStateSpec(Coder<T> elemCoder) {
+    private BagStateSpec(@Nullable Coder<T> elemCoder) {
       this.elemCoder = elemCoder;
     }
 
@@ -341,6 +498,25 @@ public class StateSpecs {
       return visitor.bindBag(id, this, elemCoder);
     }
 
+    /**
+     * Adopts the first offered coder as the element coder, unless a coder is already set.
+     *
+     * @param coders coders indexed by the type arguments of the declared state type; the array
+     *     may be empty (state type without type arguments) and entries may be {@code null}
+     */
+    @SuppressWarnings("unchecked")
+    @Override
+    public void offerCoders(Coder[] coders) {
+      // The length check lets an empty offer fall through to the descriptive failure in
+      // finishSpecifying() instead of an ArrayIndexOutOfBoundsException here.
+      if (this.elemCoder == null && coders.length > 0 && coders[0] != null) {
+        this.elemCoder = (Coder<T>) coders[0];
+      }
+    }
+
+    /**
+     * Verifies that an element coder is present, whether given at construction or accepted in
+     * {@code offerCoders}.
+     *
+     * @throws IllegalStateException if no element coder has been set
+     */
+    @Override public void finishSpecifying() {
+      if (elemCoder == null) {
+        throw new IllegalStateException("Unable to infer a coder for BagState 
and no Coder"
+            + " was specified. Please set a coder by either invoking"
+            + " StateSpecs.bag(Coder<T> elemCoder) or by registering the coder 
in the"
+            + " Pipeline's CoderRegistry.");
+      }
+    }
+
     @Override
     public boolean equals(Object obj) {
       if (obj == this) {
@@ -363,10 +539,12 @@ public class StateSpecs {
 
   private static class MapStateSpec<K, V> implements StateSpec<Object, 
MapState<K, V>> {
 
-    private final Coder<K> keyCoder;
-    private final Coder<V> valueCoder;
+    @Nullable
+    private Coder<K> keyCoder;
+    @Nullable
+    private Coder<V> valueCoder;
 
-    private MapStateSpec(Coder<K> keyCoder, Coder<V> valueCoder) {
+    private MapStateSpec(@Nullable Coder<K> keyCoder, @Nullable Coder<V> 
valueCoder) {
       this.keyCoder = keyCoder;
       this.valueCoder = valueCoder;
     }
@@ -376,6 +554,30 @@ public class StateSpecs {
       return visitor.bindMap(id, this, keyCoder, valueCoder);
     }
 
+    /**
+     * Adopts the offered key coder (index 0) and value coder (index 1), each only if the
+     * corresponding coder is not already set.
+     *
+     * @param coders coders indexed by the type arguments of the declared state type; the array
+     *     may be shorter than expected and entries may be {@code null}
+     */
+    @SuppressWarnings("unchecked")
+    @Override
+    public void offerCoders(Coder[] coders) {
+      // The length checks avoid an ArrayIndexOutOfBoundsException on a short or empty offer;
+      // a missing coder is then reported by finishSpecifying().
+      if (this.keyCoder == null && coders.length > 0 && coders[0] != null) {
+        this.keyCoder = (Coder<K>) coders[0];
+      }
+      if (this.valueCoder == null && coders.length > 1 && coders[1] != null) {
+        this.valueCoder = (Coder<V>) coders[1];
+      }
+    }
+
+    /**
+     * Verifies that both the key coder and the value coder are present.
+     *
+     * @throws IllegalStateException if either coder has not been set
+     */
+    @Override public void finishSpecifying() {
+      if (keyCoder == null || valueCoder == null) {
+        throw new IllegalStateException("Unable to infer a coder for MapState 
and no Coder"
+            + " was specified. Please set a coder by either invoking"
+            + " StateSpecs.map(Coder<K> keyCoder, Coder<V> valueCoder) or by 
registering the"
+            + " coder in the Pipeline's CoderRegistry.");
+      }
+    }
+
     @Override
     public boolean equals(Object obj) {
       if (obj == this) {
@@ -404,9 +606,10 @@ public class StateSpecs {
    */
   private static class SetStateSpec<T> implements StateSpec<Object, 
SetState<T>> {
 
-    private final Coder<T> elemCoder;
+    @Nullable
+    private Coder<T> elemCoder;
 
-    private SetStateSpec(Coder<T> elemCoder) {
+    private SetStateSpec(@Nullable Coder<T> elemCoder) {
       this.elemCoder = elemCoder;
     }
 
@@ -415,6 +618,25 @@ public class StateSpecs {
       return visitor.bindSet(id, this, elemCoder);
     }
 
+    /**
+     * Adopts the first offered coder as the element coder, unless a coder is already set.
+     *
+     * @param coders coders indexed by the type arguments of the declared state type; the array
+     *     may be empty (state type without type arguments) and entries may be {@code null}
+     */
+    @SuppressWarnings("unchecked")
+    @Override
+    public void offerCoders(Coder[] coders) {
+      // The length check lets an empty offer fall through to the descriptive failure in
+      // finishSpecifying() instead of an ArrayIndexOutOfBoundsException here.
+      if (this.elemCoder == null && coders.length > 0 && coders[0] != null) {
+        this.elemCoder = (Coder<T>) coders[0];
+      }
+    }
+
+    /**
+     * Verifies that an element coder is present, whether given at construction or accepted in
+     * {@code offerCoders}.
+     *
+     * @throws IllegalStateException if no element coder has been set
+     */
+    @Override public void finishSpecifying() {
+      if (elemCoder == null) {
+        throw new IllegalStateException("Unable to infer a coder for SetState 
and no Coder"
+            + " was specified. Please set a coder by either invoking"
+            + " StateSpecs.set(Coder<T> elemCoder) or by registering the coder 
in the"
+            + " Pipeline's CoderRegistry.");
+      }
+    }
+
     @Override
     public boolean equals(Object obj) {
       if (obj == this) {
@@ -461,6 +683,14 @@ public class StateSpecs {
     }
 
     @Override
+    public void offerCoders(Coder[] coders) {
+      // Intentionally empty: offered coders are ignored by this spec.
+    }
+
+    @Override public void finishSpecifying() {
+      // Currently an empty implementation as there are no coders to validate.
+    }
+
+    @Override
     public boolean equals(Object obj) {
       if (obj == this) {
         return true;

http://git-wip-us.apache.org/repos/asf/beam/blob/42e690e8/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
----------------------------------------------------------------------
diff --git 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
index cbbbe5f..4249a77 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
@@ -39,6 +39,7 @@ import static org.junit.Assert.assertThat;
 
 import com.fasterxml.jackson.annotation.JsonCreator;
 import com.google.common.base.MoreObjects;
+import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Sets;
@@ -55,8 +56,12 @@ import java.util.Map;
 import java.util.Set;
 import org.apache.beam.sdk.Pipeline.PipelineExecutionException;
 import org.apache.beam.sdk.coders.AtomicCoder;
+import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.CoderException;
+import org.apache.beam.sdk.coders.CustomCoder;
 import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.ListCoder;
+import org.apache.beam.sdk.coders.SetCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.coders.VarIntCoder;
 import org.apache.beam.sdk.io.CountingInput;
@@ -1036,6 +1041,71 @@ public class ParDoTest implements Serializable {
      }
   }
 
+  /**
+   * Test value type wrapping a single {@code int}. Equality, hash code and ordering are all
+   * based on the wrapped value.
+   */
+  private static class MyInteger implements Comparable<MyInteger> {
+    private final int value;
+
+    MyInteger(int value) {
+      this.value = value;
+    }
+
+    public int getValue() {
+      return value;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) {
+        return true;
+      }
+
+      if (!(o instanceof MyInteger)) {
+        return false;
+      }
+
+      MyInteger myInteger = (MyInteger) o;
+
+      return value == myInteger.value;
+
+    }
+
+    @Override
+    public int hashCode() {
+      return value;
+    }
+
+    @Override
+    public int compareTo(MyInteger o) {
+      return Integer.compare(this.getValue(), o.getValue());
+    }
+
+    @Override
+    public String toString() {
+      return "MyInteger{" + "value=" + value + '}';
+    }
+  }
+
+  /**
+   * Coder for {@link MyInteger} that encodes and decodes the wrapped {@code int} by delegating
+   * to a {@link VarIntCoder}. A single shared instance is handed out by {@link #of()}.
+   */
+  private static class MyIntegerCoder extends CustomCoder<MyInteger> {
+    private static final MyIntegerCoder INSTANCE = new MyIntegerCoder();
+
+    private final VarIntCoder delegate = VarIntCoder.of();
+
+    public static MyIntegerCoder of() {
+      return INSTANCE;
+    }
+
+    @Override
+    public void encode(MyInteger value, OutputStream outStream, Context 
context)
+        throws CoderException, IOException {
+      delegate.encode(value.getValue(), outStream, context);
+    }
+
+    @Override
+    public MyInteger decode(InputStream inStream, Context context) throws 
CoderException,
+        IOException {
+      return new MyInteger(delegate.decode(inStream, context));
+    }
+  }
+
   /** PAssert "matcher" for expected output. */
   static class HasExpectedOutput
       implements SerializableFunction<Iterable<String>, Void>, Serializable {
@@ -1619,6 +1689,132 @@ public class ParDoTest implements Serializable {
 
   @Test
   @Category({ValidatesRunner.class, UsesStatefulParDo.class})
+  public void testValueStateCoderInference() {
+    // The MyInteger coder is registered on the pipeline and the state spec is declared
+    // without a coder, so the spec has to pick the coder up from the registry.
+    final String stateId = "foo";
+    MyIntegerCoder myIntegerCoder = MyIntegerCoder.of();
+    pipeline.getCoderRegistry().registerCoder(MyInteger.class, myIntegerCoder);
+
+    DoFn<KV<String, Integer>, MyInteger> fn =
+        new DoFn<KV<String, Integer>, MyInteger>() {
+
+          @StateId(stateId)
+          private final StateSpec<Object, ValueState<MyInteger>> intState =
+              StateSpecs.value();
+
+          @ProcessElement
+          public void processElement(
+              ProcessContext c, @StateId(stateId) ValueState<MyInteger> state) 
{
+            // Emits the stored counter (0 when nothing is stored yet), then stores counter + 1.
+            MyInteger currentValue = MoreObjects.firstNonNull(state.read(), 
new MyInteger(0));
+            c.output(currentValue);
+            state.write(new MyInteger(currentValue.getValue() + 1));
+          }
+        };
+
+    PCollection<MyInteger> output =
+        pipeline.apply(Create.of(KV.of("hello", 42), KV.of("hello", 97), 
KV.of("hello", 84)))
+            .apply(ParDo.of(fn)).setCoder(myIntegerCoder);
+
+    PAssert.that(output).containsInAnyOrder(new MyInteger(0), new 
MyInteger(1), new MyInteger(2));
+    pipeline.run();
+  }
+
+  @Test
+  @Category({ValidatesRunner.class, UsesStatefulParDo.class})
+  public void testValueStateCoderInferenceFailure() throws Exception {
+    // Unlike testValueStateCoderInference, no MyInteger coder is registered here, and the
+    // input elements (KV<String, Integer>) do not carry one either.
+    final String stateId = "foo";
+    MyIntegerCoder myIntegerCoder = MyIntegerCoder.of();
+
+    DoFn<KV<String, Integer>, MyInteger> fn =
+        new DoFn<KV<String, Integer>, MyInteger>() {
+
+          @StateId(stateId)
+          private final StateSpec<Object, ValueState<MyInteger>> intState =
+              StateSpecs.value();
+
+          @ProcessElement
+          public void processElement(
+              ProcessContext c, @StateId(stateId) ValueState<MyInteger> state) 
{
+            MyInteger currentValue = MoreObjects.firstNonNull(state.read(), 
new MyInteger(0));
+            c.output(currentValue);
+            state.write(new MyInteger(currentValue.getValue() + 1));
+          }
+        };
+
+    // The expectation is armed before the transform is applied, so a failure raised while
+    // applying ParDo (not only while running) satisfies it.
+    thrown.expect(RuntimeException.class);
+    thrown.expectMessage("Unable to infer a coder for ValueState and no Coder 
was specified.");
+
+    pipeline.apply(Create.of(KV.of("hello", 42), KV.of("hello", 97), 
KV.of("hello", 84)))
+        .apply(ParDo.of(fn)).setCoder(myIntegerCoder);
+
+    pipeline.run();
+  }
+
+  @Test
+  @Category({ValidatesRunner.class, UsesStatefulParDo.class})
+  public void testValueStateCoderInferenceFromInputCoder() {
+    // No MyInteger coder is registered; the only one in play is the value component of the
+    // KvCoder set on the input.
+    // NOTE(review): there is no PAssert, so this only checks that the pipeline builds and runs.
+    final String stateId = "foo";
+    MyIntegerCoder myIntegerCoder = MyIntegerCoder.of();
+
+    DoFn<KV<String, MyInteger>, MyInteger> fn =
+        new DoFn<KV<String, MyInteger>, MyInteger>() {
+
+          @StateId(stateId)
+          private final StateSpec<Object, ValueState<MyInteger>> intState =
+              StateSpecs.value();
+
+          @ProcessElement
+          public void processElement(
+              ProcessContext c, @StateId(stateId) ValueState<MyInteger> state) 
{
+            MyInteger currentValue = MoreObjects.firstNonNull(state.read(), 
new MyInteger(0));
+            c.output(currentValue);
+            state.write(new MyInteger(currentValue.getValue() + 1));
+          }
+        };
+
+        pipeline
+            .apply(Create.of(KV.of("hello", new MyInteger(42)),
+                KV.of("hello", new MyInteger(97)), KV.of("hello", new 
MyInteger(84)))
+                .withCoder(KvCoder.of(StringUtf8Coder.of(), myIntegerCoder)))
+            .apply(ParDo.of(fn)).setCoder(myIntegerCoder);
+
+    pipeline.run();
+  }
+
+  @Test
+  @Category({ValidatesRunner.class, UsesStatefulParDo.class})
+  public void testCoderInferenceOfList() {
+    // The state holds a List<MyInteger>, while only the element type MyInteger has a coder
+    // registered.
+    // NOTE(review): there is no PAssert, so this only checks that the pipeline builds and runs.
+    final String stateId = "foo";
+    MyIntegerCoder myIntegerCoder = MyIntegerCoder.of();
+    pipeline.getCoderRegistry().registerCoder(MyInteger.class, myIntegerCoder);
+
+    DoFn<KV<String, Integer>, List<MyInteger>> fn =
+        new DoFn<KV<String, Integer>, List<MyInteger>>() {
+
+          @StateId(stateId)
+          private final StateSpec<Object, ValueState<List<MyInteger>>> 
intState =
+              StateSpecs.value();
+
+          @ProcessElement
+          public void processElement(
+              ProcessContext c, @StateId(stateId) ValueState<List<MyInteger>> 
state) {
+            // Appends the current element to the stored list (or starts a one-element list),
+            // emits the result and writes it back.
+            MyInteger myInteger = new MyInteger(c.element().getValue());
+            List<MyInteger> currentValue = state.read();
+            List<MyInteger> newValue = currentValue != null
+                ? 
ImmutableList.<MyInteger>builder().addAll(currentValue).add(myInteger).build()
+                : Collections.singletonList(myInteger);
+            c.output(newValue);
+            state.write(newValue);
+          }
+        };
+
+    pipeline.apply(Create.of(KV.of("hello", 42), KV.of("hello", 97), 
KV.of("hello", 84)))
+        .apply(ParDo.of(fn)).setCoder(ListCoder.of(myIntegerCoder));
+
+    pipeline.run();
+  }
+
+  @Test
+  @Category({ValidatesRunner.class, UsesStatefulParDo.class})
   public void testValueStateFixedWindows() {
     final String stateId = "foo";
 
@@ -1801,6 +1997,82 @@ public class ParDoTest implements Serializable {
   }
 
   @Test
+  @Category({ValidatesRunner.class, UsesStatefulParDo.class})
+  public void testBagStateCoderInference() {
+    // The BagState<MyInteger> spec below is declared with no explicit coder; a coder for
+    // MyInteger is registered on the pipeline instead.
+    final String stateId = "foo";
+    Coder<MyInteger> myIntegerCoder = MyIntegerCoder.of();
+    pipeline.getCoderRegistry().registerCoder(MyInteger.class, myIntegerCoder);
+
+    DoFn<KV<String, Integer>, List<MyInteger>> fn =
+        new DoFn<KV<String, Integer>, List<MyInteger>>() {
+
+          @StateId(stateId)
+          private final StateSpec<Object, BagState<MyInteger>> bufferState =
+              StateSpecs.bag();
+
+          @ProcessElement
+          public void processElement(
+              ProcessContext c, @StateId(stateId) BagState<MyInteger> state) {
+            state.add(new MyInteger(c.element().getValue()));
+            // Read the bag after the add instead of reusing an Iterable obtained before it, so
+            // the result does not depend on an earlier read() reflecting later additions.
+            Iterable<MyInteger> contents = state.read();
+            if (Iterables.size(contents) >= 4) {
+              List<MyInteger> sorted = Lists.newArrayList(contents);
+              Collections.sort(sorted);
+              c.output(sorted);
+            }
+          }
+        };
+
+    PCollection<List<MyInteger>> output =
+        pipeline.apply(
+            Create.of(
+                KV.of("hello", 97), KV.of("hello", 42), KV.of("hello", 84), KV.of("hello", 12)))
+            .apply(ParDo.of(fn)).setCoder(ListCoder.of(myIntegerCoder));
+
+    // Exactly one output is expected: the four inputs, sorted, emitted once the bag holds four.
+    PAssert.that(output).containsInAnyOrder(Lists.newArrayList(new MyInteger(12), new MyInteger(42),
+        new MyInteger(84), new MyInteger(97)));
+
+    pipeline.run();
+  }
+
+  /**
+   * Expects a {@code RuntimeException} with the "Unable to infer a coder for BagState" message
+   * when a {@code BagState<MyInteger>} is declared without a coder and, unlike
+   * {@code testBagStateCoderInference}, no coder for {@code MyInteger} is registered.
+   */
+  @Test
+  @Category({ValidatesRunner.class, UsesStatefulParDo.class})
+  public void testBagStateCoderInferenceFailure() throws Exception {
+    final String stateId = "foo";
+    Coder<MyInteger> myIntegerCoder = MyIntegerCoder.of();
+
+    DoFn<KV<String, Integer>, List<MyInteger>> fn =
+        new DoFn<KV<String, Integer>, List<MyInteger>>() {
+
+          @StateId(stateId)
+          private final StateSpec<Object, BagState<MyInteger>> bufferState =
+              StateSpecs.bag();
+
+          @ProcessElement
+          public void processElement(
+              ProcessContext c, @StateId(stateId) BagState<MyInteger> state) {
+            state.add(new MyInteger(c.element().getValue()));
+            // Read after the add so the emitted list never relies on an Iterable obtained earlier.
+            Iterable<MyInteger> contents = state.read();
+            if (Iterables.size(contents) >= 4) {
+              List<MyInteger> sorted = Lists.newArrayList(contents);
+              Collections.sort(sorted);
+              c.output(sorted);
+            }
+          }
+        };
+
+    // Must be set before the transform is applied; the failure is expected from there onwards.
+    thrown.expect(RuntimeException.class);
+    thrown.expectMessage("Unable to infer a coder for BagState and no Coder was specified.");
+
+    pipeline.apply(
+        Create.of(
+            KV.of("hello", 97), KV.of("hello", 42), KV.of("hello", 84), KV.of("hello", 12)))
+        .apply(ParDo.of(fn)).setCoder(ListCoder.of(myIntegerCoder));
+
+    pipeline.run();
+  }
+
+  @Test
   @Category({ValidatesRunner.class, UsesStatefulParDo.class, 
UsesSetState.class})
   public void testSetState() {
     final String stateId = "foo";
@@ -1843,6 +2115,93 @@ public class ParDoTest implements Serializable {
   }
 
   @Test
+  @Category({ValidatesRunner.class, UsesStatefulParDo.class, UsesSetState.class})
+  public void testSetStateCoderInference() {
+    // The SetState<MyInteger> spec below takes no coder argument; a MyInteger coder is registered
+    // on the pipeline instead. A separate counter state decides when to emit the set.
+    final String stateId = "foo";
+    final String countStateId = "count";
+    Coder<MyInteger> myIntegerCoder = MyIntegerCoder.of();
+    pipeline.getCoderRegistry().registerCoder(MyInteger.class, myIntegerCoder);
+
+    DoFn<KV<String, Integer>, Set<MyInteger>> dedupFn =
+        new DoFn<KV<String, Integer>, Set<MyInteger>>() {
+
+          @StateId(stateId)
+          private final StateSpec<Object, SetState<MyInteger>> setState = StateSpecs.set();
+
+          @StateId(countStateId)
+          private final StateSpec<Object, AccumulatorCombiningState<Integer, int[], Integer>>
+              countState =
+                  StateSpecs.combiningValueFromInputInternal(VarIntCoder.of(), Sum.ofIntegers());
+
+          @ProcessElement
+          public void processElement(
+              ProcessContext context,
+              @StateId(stateId) SetState<MyInteger> state,
+              @StateId(countStateId) AccumulatorCombiningState<Integer, int[], Integer> count) {
+            state.add(new MyInteger(context.element().getValue()));
+            count.add(1);
+
+            // Emit only once at least four elements have been counted.
+            if (count.read() < 4) {
+              return;
+            }
+            Set<MyInteger> seen = Sets.newHashSet(state.read());
+            context.output(seen);
+          }
+        };
+
+    PCollection<KV<String, Integer>> input =
+        pipeline.apply(
+            Create.of(
+                KV.of("hello", 97), KV.of("hello", 42), KV.of("hello", 42), KV.of("hello", 12)));
+    PCollection<Set<MyInteger>> output =
+        input.apply(ParDo.of(dedupFn)).setCoder(SetCoder.of(myIntegerCoder));
+
+    // 42 is supplied twice, so the single emitted set holds three distinct values.
+    PAssert.that(output).containsInAnyOrder(
+        Sets.newHashSet(new MyInteger(97), new MyInteger(42), new MyInteger(12)));
+    pipeline.run();
+  }
+
+  /**
+   * Expects a {@code RuntimeException} with the "Unable to infer a coder for SetState" message
+   * when a {@code SetState<MyInteger>} is declared without a coder and, unlike
+   * {@code testSetStateCoderInference}, no coder for {@code MyInteger} is registered.
+   */
+  @Test
+  @Category({ValidatesRunner.class, UsesStatefulParDo.class, UsesSetState.class})
+  public void testSetStateCoderInferenceFailure() throws Exception {
+    final String stateId = "foo";
+    final String countStateId = "count";
+    Coder<MyInteger> myIntegerCoder = MyIntegerCoder.of();
+
+    DoFn<KV<String, Integer>, Set<MyInteger>> fn =
+        new DoFn<KV<String, Integer>, Set<MyInteger>>() {
+
+          @StateId(stateId)
+          private final StateSpec<Object, SetState<MyInteger>> setState = StateSpecs.set();
+
+          @StateId(countStateId)
+          private final StateSpec<Object, AccumulatorCombiningState<Integer, int[], Integer>>
+              countState = StateSpecs.combiningValueFromInputInternal(VarIntCoder.of(),
+              Sum.ofIntegers());
+
+          @ProcessElement
+          public void processElement(
+              ProcessContext c,
+              @StateId(stateId) SetState<MyInteger> state,
+              @StateId(countStateId) AccumulatorCombiningState<Integer, int[], Integer> count) {
+            state.add(new MyInteger(c.element().getValue()));
+            count.add(1);
+            if (count.read() >= 4) {
+              Set<MyInteger> set = Sets.newHashSet(state.read());
+              c.output(set);
+            }
+          }
+        };
+
+    // Must be set before the transform is applied; the failure is expected from there onwards.
+    thrown.expect(RuntimeException.class);
+    thrown.expectMessage("Unable to infer a coder for SetState and no Coder was specified.");
+
+    pipeline.apply(
+        Create.of(
+            KV.of("hello", 97), KV.of("hello", 42), KV.of("hello", 42), KV.of("hello", 12)))
+        .apply(ParDo.of(fn)).setCoder(SetCoder.of(myIntegerCoder));
+
+    pipeline.run();
+  }
+
+  @Test
   @Category({ValidatesRunner.class, UsesStatefulParDo.class, 
UsesMapState.class})
   public void testMapState() {
     final String stateId = "foo";
@@ -1888,6 +2247,99 @@ public class ParDoTest implements Serializable {
   }
 
   @Test
+  @Category({ValidatesRunner.class, UsesStatefulParDo.class, UsesMapState.class})
+  public void testMapStateCoderInference() {
+    // The MapState<String, MyInteger> spec below takes no coder arguments; a MyInteger coder is
+    // registered on the pipeline instead. A separate counter state decides when to emit entries.
+    final String stateId = "foo";
+    final String countStateId = "count";
+    Coder<MyInteger> myIntegerCoder = MyIntegerCoder.of();
+    pipeline.getCoderRegistry().registerCoder(MyInteger.class, myIntegerCoder);
+
+    DoFn<KV<String, KV<String, Integer>>, KV<String, MyInteger>> latestPerKeyFn =
+        new DoFn<KV<String, KV<String, Integer>>, KV<String, MyInteger>>() {
+
+          @StateId(stateId)
+          private final StateSpec<Object, MapState<String, MyInteger>> mapState = StateSpecs.map();
+
+          @StateId(countStateId)
+          private final StateSpec<Object, AccumulatorCombiningState<Integer, int[], Integer>>
+              countState =
+                  StateSpecs.combiningValueFromInputInternal(VarIntCoder.of(), Sum.ofIntegers());
+
+          @ProcessElement
+          public void processElement(
+              ProcessContext context,
+              @StateId(stateId) MapState<String, MyInteger> state,
+              @StateId(countStateId) AccumulatorCombiningState<Integer, int[], Integer> count) {
+            KV<String, Integer> pair = context.element().getValue();
+            state.put(pair.getKey(), new MyInteger(pair.getValue()));
+            count.add(1);
+
+            // Emit the map contents only once at least four elements have been counted.
+            if (count.read() < 4) {
+              return;
+            }
+            for (Map.Entry<String, MyInteger> entry : state.iterate()) {
+              context.output(KV.of(entry.getKey(), entry.getValue()));
+            }
+          }
+        };
+
+    PCollection<KV<String, KV<String, Integer>>> input =
+        pipeline.apply(
+            Create.of(
+                KV.of("hello", KV.of("a", 97)), KV.of("hello", KV.of("b", 42)),
+                KV.of("hello", KV.of("b", 42)), KV.of("hello", KV.of("c", 12))));
+    PCollection<KV<String, MyInteger>> output =
+        input.apply(ParDo.of(latestPerKeyFn))
+            .setCoder(KvCoder.of(StringUtf8Coder.of(), myIntegerCoder));
+
+    // Map key "b" is supplied twice with the same value, so three entries are expected.
+    PAssert.that(output).containsInAnyOrder(KV.of("a", new MyInteger(97)),
+        KV.of("b", new MyInteger(42)), KV.of("c", new MyInteger(12)));
+    pipeline.run();
+  }
+
+  /**
+   * Expects a {@code RuntimeException} with the "Unable to infer a coder for MapState" message
+   * when a {@code MapState<String, MyInteger>} is declared without coders and, unlike
+   * {@code testMapStateCoderInference}, no coder for {@code MyInteger} is registered.
+   */
+  @Test
+  @Category({ValidatesRunner.class, UsesStatefulParDo.class, UsesMapState.class})
+  public void testMapStateCoderInferenceFailure() throws Exception {
+    final String stateId = "foo";
+    final String countStateId = "count";
+    Coder<MyInteger> myIntegerCoder = MyIntegerCoder.of();
+
+    DoFn<KV<String, KV<String, Integer>>, KV<String, MyInteger>> fn =
+        new DoFn<KV<String, KV<String, Integer>>, KV<String, MyInteger>>() {
+
+          @StateId(stateId)
+          private final StateSpec<Object, MapState<String, MyInteger>> mapState = StateSpecs.map();
+          @StateId(countStateId)
+          private final StateSpec<Object, AccumulatorCombiningState<Integer, int[], Integer>>
+              countState = StateSpecs.combiningValueFromInputInternal(VarIntCoder.of(),
+              Sum.ofIntegers());
+
+          @ProcessElement
+          public void processElement(
+              ProcessContext c, @StateId(stateId) MapState<String, MyInteger> state,
+              @StateId(countStateId) AccumulatorCombiningState<Integer, int[], Integer>
+                  count) {
+            KV<String, Integer> value = c.element().getValue();
+            state.put(value.getKey(), new MyInteger(value.getValue()));
+            count.add(1);
+            if (count.read() >= 4) {
+              Iterable<Map.Entry<String, MyInteger>> iterate = state.iterate();
+              for (Map.Entry<String, MyInteger> entry : iterate) {
+                c.output(KV.of(entry.getKey(), entry.getValue()));
+              }
+            }
+          }
+        };
+
+    // Must be set before the transform is applied; the failure is expected from there onwards.
+    thrown.expect(RuntimeException.class);
+    thrown.expectMessage("Unable to infer a coder for MapState and no Coder was specified.");
+
+    pipeline.apply(
+        Create.of(
+            KV.of("hello", KV.of("a", 97)), KV.of("hello", KV.of("b", 42)),
+            KV.of("hello", KV.of("b", 42)), KV.of("hello", KV.of("c", 12))))
+        .apply(ParDo.of(fn)).setCoder(KvCoder.of(StringUtf8Coder.of(), myIntegerCoder));
+
+    pipeline.run();
+  }
+
+  @Test
   @Category({ValidatesRunner.class, UsesStatefulParDo.class})
   public void testCombiningState() {
     final String stateId = "foo";
@@ -1928,6 +2380,132 @@ public class ParDoTest implements Serializable {
 
   @Test
   @Category({ValidatesRunner.class, UsesStatefulParDo.class})
+  public void testCombiningStateCoderInference() {
+    // The CombineFn below accumulates into MyInteger and no accumulator coder is handed to
+    // StateSpecs.combiningValue; a MyInteger coder is registered on the pipeline instead.
+    pipeline.getCoderRegistry().registerCoder(MyInteger.class, MyIntegerCoder.of());
+
+    final String stateId = "foo";
+
+    DoFn<KV<String, Integer>, String> fn =
+        new DoFn<KV<String, Integer>, String>() {
+          private static final int EXPECTED_SUM = 16;
+
+          @StateId(stateId)
+          private final StateSpec<
+              Object, AccumulatorCombiningState<Integer, MyInteger, Integer>>
+              combiningState =
+              StateSpecs.combiningValue(new Combine.CombineFn<Integer, MyInteger, Integer>() {
+                @Override
+                public MyInteger createAccumulator() {
+                  return new MyInteger(0);
+                }
+
+                @Override
+                public MyInteger addInput(MyInteger accumulator, Integer input) {
+                  return new MyInteger(accumulator.getValue() + input);
+                }
+
+                @Override
+                public MyInteger mergeAccumulators(Iterable<MyInteger> accumulators) {
+                  int newValue = 0;
+                  for (MyInteger myInteger : accumulators) {
+                    newValue += myInteger.getValue();
+                  }
+                  return new MyInteger(newValue);
+                }
+
+                @Override
+                public Integer extractOutput(MyInteger accumulator) {
+                  return accumulator.getValue();
+                }
+              });
+
+          @ProcessElement
+          public void processElement(
+              ProcessContext c,
+              @StateId(stateId)
+                  AccumulatorCombiningState<Integer, MyInteger, Integer> state) {
+            state.add(c.element().getValue());
+            Integer currentValue = state.read();
+            if (currentValue == EXPECTED_SUM) {
+              c.output("right on");
+            }
+          }
+        };
+
+    PCollection<String> output =
+        pipeline
+            .apply(Create.of(KV.of("hello", 3), KV.of("hello", 6), KV.of("hello", 7)))
+            .apply(ParDo.of(fn));
+
+    // The inputs are all positive and total 16, so the running sum is exactly 16 only once:
+    // after the last element has been added.
+    PAssert.that(output).containsInAnyOrder("right on");
+    pipeline.run();
+  }
+
+  /**
+   * Expects a {@code RuntimeException} with the "Unable to infer a coder for CombiningState"
+   * message when the combining state's {@code MyInteger} accumulator has no explicit coder and,
+   * unlike {@code testCombiningStateCoderInference}, no coder for {@code MyInteger} is registered.
+   */
+  @Test
+  @Category({ValidatesRunner.class, UsesStatefulParDo.class})
+  public void testCombiningStateCoderInferenceFailure() throws Exception {
+    final String stateId = "foo";
+
+    DoFn<KV<String, Integer>, String> fn =
+        new DoFn<KV<String, Integer>, String>() {
+          private static final int EXPECTED_SUM = 16;
+
+          @StateId(stateId)
+          private final StateSpec<
+              Object, AccumulatorCombiningState<Integer, MyInteger, Integer>>
+              combiningState =
+              StateSpecs.combiningValue(new Combine.CombineFn<Integer, MyInteger, Integer>() {
+                @Override
+                public MyInteger createAccumulator() {
+                  return new MyInteger(0);
+                }
+
+                @Override
+                public MyInteger addInput(MyInteger accumulator, Integer input) {
+                  return new MyInteger(accumulator.getValue() + input);
+                }
+
+                @Override
+                public MyInteger mergeAccumulators(Iterable<MyInteger> accumulators) {
+                  int newValue = 0;
+                  for (MyInteger myInteger : accumulators) {
+                    newValue += myInteger.getValue();
+                  }
+                  return new MyInteger(newValue);
+                }
+
+                @Override
+                public Integer extractOutput(MyInteger accumulator) {
+                  return accumulator.getValue();
+                }
+              });
+
+          @ProcessElement
+          public void processElement(
+              ProcessContext c,
+              @StateId(stateId)
+                  AccumulatorCombiningState<Integer, MyInteger, Integer> state) {
+            state.add(c.element().getValue());
+            Integer currentValue = state.read();
+            if (currentValue == EXPECTED_SUM) {
+              c.output("right on");
+            }
+          }
+        };
+
+    // Must be set before the transform is applied; the failure is expected from there onwards.
+    thrown.expect(RuntimeException.class);
+    thrown.expectMessage("Unable to infer a coder for CombiningState and no Coder was specified.");
+
+    pipeline
+        .apply(Create.of(KV.of("hello", 3), KV.of("hello", 6), KV.of("hello", 7)))
+        .apply(ParDo.of(fn));
+
+    pipeline.run();
+  }
+
+  @Test
+  @Category({ValidatesRunner.class, UsesStatefulParDo.class})
   public void testBagStateSideInput() {
 
     final PCollectionView<List<Integer>> listView =

Reply via email to