Repository: beam
Updated Branches:
  refs/heads/master f03f6ac19 -> 137fee95b


http://git-wip-us.apache.org/repos/asf/beam/blob/7d32b93e/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java
----------------------------------------------------------------------
diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java
new file mode 100644
index 0000000..9033ba7
--- /dev/null
+++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java
@@ -0,0 +1,1053 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.translation.wrappers.streaming.state;
+
+import com.google.common.collect.Lists;
+import java.nio.ByteBuffer;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.beam.runners.core.StateInternals;
+import org.apache.beam.runners.core.StateNamespace;
+import org.apache.beam.runners.core.StateTag;
+import org.apache.beam.runners.flink.translation.types.CoderTypeInformation;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderException;
+import org.apache.beam.sdk.coders.InstantCoder;
+import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.transforms.CombineWithContext;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.OutputTimeFn;
+import org.apache.beam.sdk.util.CoderUtils;
+import org.apache.beam.sdk.util.CombineContextFactory;
+import org.apache.beam.sdk.util.state.AccumulatorCombiningState;
+import org.apache.beam.sdk.util.state.BagState;
+import org.apache.beam.sdk.util.state.MapState;
+import org.apache.beam.sdk.util.state.ReadableState;
+import org.apache.beam.sdk.util.state.SetState;
+import org.apache.beam.sdk.util.state.State;
+import org.apache.beam.sdk.util.state.StateContext;
+import org.apache.beam.sdk.util.state.StateContexts;
+import org.apache.beam.sdk.util.state.ValueState;
+import org.apache.beam.sdk.util.state.WatermarkHoldState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.common.typeutils.base.StringSerializer;
+import org.apache.flink.runtime.state.KeyedStateBackend;
+import org.joda.time.Instant;
+
+/**
+ * {@link StateInternals} that uses a Flink {@link KeyedStateBackend} to manage state.
+ *
+ * <p>Note: In the Flink streaming runner the key is always encoded
+ * using a {@link Coder} and stored in a {@link ByteBuffer}.
+ */
+public class FlinkStateInternals<K> implements StateInternals<K> {
+
+  private final KeyedStateBackend<ByteBuffer> flinkStateBackend;
+  private Coder<K> keyCoder;
+
+  // on recovery, these will not be properly set because we don't
+  // know which watermark holds exist in the Flink state backend
+  private final Map<String, Instant> watermarkHolds = new HashMap<>();
+
+  public FlinkStateInternals(KeyedStateBackend<ByteBuffer> flinkStateBackend, Coder<K> keyCoder) {
+    this.flinkStateBackend = flinkStateBackend;
+    this.keyCoder = keyCoder;
+  }
+
+  /**
+   * Returns the minimum over all watermark holds.
+   */
+  public Instant watermarkHold() {
+    long min = Long.MAX_VALUE;
+    for (Instant hold: watermarkHolds.values()) {
+      min = Math.min(min, hold.getMillis());
+    }
+    return new Instant(min);
+  }
+
+  @Override
+  public K getKey() {
+    ByteBuffer keyBytes = flinkStateBackend.getCurrentKey();
+    try {
+      return CoderUtils.decodeFromByteArray(keyCoder, keyBytes.array());
+    } catch (CoderException e) {
+      throw new RuntimeException("Error decoding key.", e);
+    }
+  }
+
+  @Override
+  public <T extends State> T state(
+      final StateNamespace namespace,
+      StateTag<? super K, T> address) {
+
+    return state(namespace, address, StateContexts.nullContext());
+  }
+
+  @Override
+  public <T extends State> T state(
+      final StateNamespace namespace,
+      StateTag<? super K, T> address,
+      final StateContext<?> context) {
+
+    return address.bind(new StateTag.StateBinder<K>() {
+
+      @Override
+      public <T> ValueState<T> bindValue(
+          StateTag<? super K, ValueState<T>> address,
+          Coder<T> coder) {
+
+        return new FlinkValueState<>(flinkStateBackend, address, namespace, coder);
+      }
+
+      @Override
+      public <T> BagState<T> bindBag(
+          StateTag<? super K, BagState<T>> address,
+          Coder<T> elemCoder) {
+
+        return new FlinkBagState<>(flinkStateBackend, address, namespace, elemCoder);
+      }
+
+      @Override
+      public <T> SetState<T> bindSet(
+          StateTag<? super K, SetState<T>> address,
+          Coder<T> elemCoder) {
+        throw new UnsupportedOperationException(
+            String.format("%s is not supported", SetState.class.getSimpleName()));
+      }
+
+      @Override
+      public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(
+          StateTag<? super K, MapState<KeyT, ValueT>> spec,
+          Coder<KeyT> mapKeyCoder, Coder<ValueT> mapValueCoder) {
+        throw new UnsupportedOperationException(
+            String.format("%s is not supported", MapState.class.getSimpleName()));
+      }
+
+      @Override
+      public <InputT, AccumT, OutputT>
+          AccumulatorCombiningState<InputT, AccumT, OutputT>
+      bindCombiningValue(
+          StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address,
+          Coder<AccumT> accumCoder,
+          Combine.CombineFn<InputT, AccumT, OutputT> combineFn) {
+
+        return new FlinkAccumulatorCombiningState<>(
+            flinkStateBackend, address, combineFn, namespace, accumCoder);
+      }
+
+      @Override
+      public <InputT, AccumT, OutputT>
+          AccumulatorCombiningState<InputT, AccumT, OutputT> bindKeyedCombiningValue(
+          StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address,
+          Coder<AccumT> accumCoder,
+          final Combine.KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn) {
+        return new FlinkKeyedAccumulatorCombiningState<>(
+            flinkStateBackend,
+            address,
+            combineFn,
+            namespace,
+            accumCoder,
+            FlinkStateInternals.this);
+      }
+
+      @Override
+      public <InputT, AccumT, OutputT>
+          AccumulatorCombiningState<InputT, AccumT, OutputT> bindKeyedCombiningValueWithContext(
+          StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address,
+          Coder<AccumT> accumCoder,
+          CombineWithContext.KeyedCombineFnWithContext<
+              ? super K, InputT, AccumT, OutputT> combineFn) {
+        return new FlinkAccumulatorCombiningStateWithContext<>(
+            flinkStateBackend,
+            address,
+            combineFn,
+            namespace,
+            accumCoder,
+            FlinkStateInternals.this,
+            CombineContextFactory.createFromStateContext(context));
+      }
+
+      @Override
+      public <W extends BoundedWindow> WatermarkHoldState<W> bindWatermark(
+          StateTag<? super K, WatermarkHoldState<W>> address,
+          OutputTimeFn<? super W> outputTimeFn) {
+
+        return new FlinkWatermarkHoldState<>(
+            flinkStateBackend, FlinkStateInternals.this, address, namespace, outputTimeFn);
+      }
+    });
+  }
+
+  private static class FlinkValueState<K, T> implements ValueState<T> {
+
+    private final StateNamespace namespace;
+    private final StateTag<? super K, ValueState<T>> address;
+    private final ValueStateDescriptor<T> flinkStateDescriptor;
+    private final KeyedStateBackend<ByteBuffer> flinkStateBackend;
+
+    FlinkValueState(
+        KeyedStateBackend<ByteBuffer> flinkStateBackend,
+        StateTag<? super K, ValueState<T>> address,
+        StateNamespace namespace,
+        Coder<T> coder) {
+
+      this.namespace = namespace;
+      this.address = address;
+      this.flinkStateBackend = flinkStateBackend;
+
+      CoderTypeInformation<T> typeInfo = new CoderTypeInformation<>(coder);
+
+      flinkStateDescriptor = new ValueStateDescriptor<>(address.getId(), typeInfo, null);
+    }
+
+    @Override
+    public void write(T input) {
+      try {
+        flinkStateBackend.getPartitionedState(
+            namespace.stringKey(),
+            StringSerializer.INSTANCE,
+            flinkStateDescriptor).update(input);
+      } catch (Exception e) {
+        throw new RuntimeException("Error updating state.", e);
+      }
+    }
+
+    @Override
+    public ValueState<T> readLater() {
+      return this;
+    }
+
+    @Override
+    public T read() {
+      try {
+        return flinkStateBackend.getPartitionedState(
+            namespace.stringKey(),
+            StringSerializer.INSTANCE,
+            flinkStateDescriptor).value();
+      } catch (Exception e) {
+        throw new RuntimeException("Error reading state.", e);
+      }
+    }
+
+    @Override
+    public void clear() {
+      try {
+        flinkStateBackend.getPartitionedState(
+            namespace.stringKey(),
+            StringSerializer.INSTANCE,
+            flinkStateDescriptor).clear();
+      } catch (Exception e) {
+        throw new RuntimeException("Error clearing state.", e);
+      }
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (o == null || getClass() != o.getClass()) {
+        return false;
+      }
+
+      FlinkValueState<?, ?> that = (FlinkValueState<?, ?>) o;
+
+      return namespace.equals(that.namespace) && address.equals(that.address);
+
+    }
+
+    @Override
+    public int hashCode() {
+      int result = namespace.hashCode();
+      result = 31 * result + address.hashCode();
+      return result;
+    }
+  }
+
+  private static class FlinkBagState<K, T> implements BagState<T> {
+
+    private final StateNamespace namespace;
+    private final StateTag<? super K, BagState<T>> address;
+    private final ListStateDescriptor<T> flinkStateDescriptor;
+    private final KeyedStateBackend<ByteBuffer> flinkStateBackend;
+
+    FlinkBagState(
+        KeyedStateBackend<ByteBuffer> flinkStateBackend,
+        StateTag<? super K, BagState<T>> address,
+        StateNamespace namespace,
+        Coder<T> coder) {
+
+      this.namespace = namespace;
+      this.address = address;
+      this.flinkStateBackend = flinkStateBackend;
+
+      CoderTypeInformation<T> typeInfo = new CoderTypeInformation<>(coder);
+
+      flinkStateDescriptor = new ListStateDescriptor<>(address.getId(), typeInfo);
+    }
+
+    @Override
+    public void add(T input) {
+      try {
+        flinkStateBackend.getPartitionedState(
+            namespace.stringKey(),
+            StringSerializer.INSTANCE,
+            flinkStateDescriptor).add(input);
+      } catch (Exception e) {
+        throw new RuntimeException("Error adding to bag state.", e);
+      }
+    }
+
+    @Override
+    public BagState<T> readLater() {
+      return this;
+    }
+
+    @Override
+    public Iterable<T> read() {
+      try {
+        Iterable<T> result = flinkStateBackend.getPartitionedState(
+            namespace.stringKey(),
+            StringSerializer.INSTANCE,
+            flinkStateDescriptor).get();
+
+        return result != null ? result : Collections.<T>emptyList();
+      } catch (Exception e) {
+        throw new RuntimeException("Error reading state.", e);
+      }
+    }
+
+    @Override
+    public ReadableState<Boolean> isEmpty() {
+      return new ReadableState<Boolean>() {
+        @Override
+        public Boolean read() {
+          try {
+            Iterable<T> result = flinkStateBackend.getPartitionedState(
+                namespace.stringKey(),
+                StringSerializer.INSTANCE,
+                flinkStateDescriptor).get();
+            return result == null;
+          } catch (Exception e) {
+            throw new RuntimeException("Error reading state.", e);
+          }
+
+        }
+
+        @Override
+        public ReadableState<Boolean> readLater() {
+          return this;
+        }
+      };
+    }
+
+    @Override
+    public void clear() {
+      try {
+        flinkStateBackend.getPartitionedState(
+            namespace.stringKey(),
+            StringSerializer.INSTANCE,
+            flinkStateDescriptor).clear();
+      } catch (Exception e) {
+        throw new RuntimeException("Error clearing state.", e);
+      }
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (o == null || getClass() != o.getClass()) {
+        return false;
+      }
+
+      FlinkBagState<?, ?> that = (FlinkBagState<?, ?>) o;
+
+      return namespace.equals(that.namespace) && address.equals(that.address);
+
+    }
+
+    @Override
+    public int hashCode() {
+      int result = namespace.hashCode();
+      result = 31 * result + address.hashCode();
+      return result;
+    }
+  }
+
+  private static class FlinkAccumulatorCombiningState<K, InputT, AccumT, OutputT>
+      implements AccumulatorCombiningState<InputT, AccumT, OutputT> {
+
+    private final StateNamespace namespace;
+    private final StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address;
+    private final Combine.CombineFn<InputT, AccumT, OutputT> combineFn;
+    private final ValueStateDescriptor<AccumT> flinkStateDescriptor;
+    private final KeyedStateBackend<ByteBuffer> flinkStateBackend;
+
+    FlinkAccumulatorCombiningState(
+        KeyedStateBackend<ByteBuffer> flinkStateBackend,
+        StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address,
+        Combine.CombineFn<InputT, AccumT, OutputT> combineFn,
+        StateNamespace namespace,
+        Coder<AccumT> accumCoder) {
+
+      this.namespace = namespace;
+      this.address = address;
+      this.combineFn = combineFn;
+      this.flinkStateBackend = flinkStateBackend;
+
+      CoderTypeInformation<AccumT> typeInfo = new CoderTypeInformation<>(accumCoder);
+
+      flinkStateDescriptor = new ValueStateDescriptor<>(address.getId(), typeInfo, null);
+    }
+
+    @Override
+    public AccumulatorCombiningState<InputT, AccumT, OutputT> readLater() {
+      return this;
+    }
+
+    @Override
+    public void add(InputT value) {
+      try {
+        org.apache.flink.api.common.state.ValueState<AccumT> state =
+            flinkStateBackend.getPartitionedState(
+                namespace.stringKey(),
+                StringSerializer.INSTANCE,
+                flinkStateDescriptor);
+
+        AccumT current = state.value();
+        if (current == null) {
+          current = combineFn.createAccumulator();
+        }
+        current = combineFn.addInput(current, value);
+        state.update(current);
+      } catch (Exception e) {
+        throw new RuntimeException("Error adding to state." , e);
+      }
+    }
+
+    @Override
+    public void addAccum(AccumT accum) {
+      try {
+        org.apache.flink.api.common.state.ValueState<AccumT> state =
+            flinkStateBackend.getPartitionedState(
+              namespace.stringKey(),
+              StringSerializer.INSTANCE,
+              flinkStateDescriptor);
+
+        AccumT current = state.value();
+        if (current == null) {
+          state.update(accum);
+        } else {
+          current = combineFn.mergeAccumulators(Lists.newArrayList(current, accum));
+          state.update(current);
+        }
+      } catch (Exception e) {
+        throw new RuntimeException("Error adding to state.", e);
+      }
+    }
+
+    @Override
+    public AccumT getAccum() {
+      try {
+        return flinkStateBackend.getPartitionedState(
+            namespace.stringKey(),
+            StringSerializer.INSTANCE,
+            flinkStateDescriptor).value();
+      } catch (Exception e) {
+        throw new RuntimeException("Error reading state.", e);
+      }
+    }
+
+    @Override
+    public AccumT mergeAccumulators(Iterable<AccumT> accumulators) {
+      return combineFn.mergeAccumulators(accumulators);
+    }
+
+    @Override
+    public OutputT read() {
+      try {
+        org.apache.flink.api.common.state.ValueState<AccumT> state =
+            flinkStateBackend.getPartitionedState(
+                namespace.stringKey(),
+                StringSerializer.INSTANCE,
+                flinkStateDescriptor);
+
+        AccumT accum = state.value();
+        if (accum != null) {
+          return combineFn.extractOutput(accum);
+        } else {
+          return combineFn.extractOutput(combineFn.createAccumulator());
+        }
+      } catch (Exception e) {
+        throw new RuntimeException("Error reading state.", e);
+      }
+    }
+
+    @Override
+    public ReadableState<Boolean> isEmpty() {
+      return new ReadableState<Boolean>() {
+        @Override
+        public Boolean read() {
+          try {
+            return flinkStateBackend.getPartitionedState(
+                namespace.stringKey(),
+                StringSerializer.INSTANCE,
+                flinkStateDescriptor).value() == null;
+          } catch (Exception e) {
+            throw new RuntimeException("Error reading state.", e);
+          }
+
+        }
+
+        @Override
+        public ReadableState<Boolean> readLater() {
+          return this;
+        }
+      };
+    }
+
+    @Override
+    public void clear() {
+      try {
+        flinkStateBackend.getPartitionedState(
+            namespace.stringKey(),
+            StringSerializer.INSTANCE,
+            flinkStateDescriptor).clear();
+      } catch (Exception e) {
+        throw new RuntimeException("Error clearing state.", e);
+      }
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (o == null || getClass() != o.getClass()) {
+        return false;
+      }
+
+      FlinkAccumulatorCombiningState<?, ?, ?, ?> that =
+          (FlinkAccumulatorCombiningState<?, ?, ?, ?>) o;
+
+      return namespace.equals(that.namespace) && address.equals(that.address);
+
+    }
+
+    @Override
+    public int hashCode() {
+      int result = namespace.hashCode();
+      result = 31 * result + address.hashCode();
+      return result;
+    }
+  }
+
+  private static class FlinkKeyedAccumulatorCombiningState<K, InputT, AccumT, OutputT>
+      implements AccumulatorCombiningState<InputT, AccumT, OutputT> {
+
+    private final StateNamespace namespace;
+    private final StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address;
+    private final Combine.KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn;
+    private final ValueStateDescriptor<AccumT> flinkStateDescriptor;
+    private final KeyedStateBackend<ByteBuffer> flinkStateBackend;
+    private final FlinkStateInternals<K> flinkStateInternals;
+
+    FlinkKeyedAccumulatorCombiningState(
+        KeyedStateBackend<ByteBuffer> flinkStateBackend,
+        StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address,
+        Combine.KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn,
+        StateNamespace namespace,
+        Coder<AccumT> accumCoder,
+        FlinkStateInternals<K> flinkStateInternals) {
+
+      this.namespace = namespace;
+      this.address = address;
+      this.combineFn = combineFn;
+      this.flinkStateBackend = flinkStateBackend;
+      this.flinkStateInternals = flinkStateInternals;
+
+      CoderTypeInformation<AccumT> typeInfo = new CoderTypeInformation<>(accumCoder);
+
+      flinkStateDescriptor = new ValueStateDescriptor<>(address.getId(), typeInfo, null);
+    }
+
+    @Override
+    public AccumulatorCombiningState<InputT, AccumT, OutputT> readLater() {
+      return this;
+    }
+
+    @Override
+    public void add(InputT value) {
+      try {
+        org.apache.flink.api.common.state.ValueState<AccumT> state =
+            flinkStateBackend.getPartitionedState(
+                namespace.stringKey(),
+                StringSerializer.INSTANCE,
+                flinkStateDescriptor);
+
+        AccumT current = state.value();
+        if (current == null) {
+          current = combineFn.createAccumulator(flinkStateInternals.getKey());
+        }
+        current = combineFn.addInput(flinkStateInternals.getKey(), current, value);
+        state.update(current);
+      } catch (Exception e) {
+        throw new RuntimeException("Error adding to state." , e);
+      }
+    }
+
+    @Override
+    public void addAccum(AccumT accum) {
+      try {
+        org.apache.flink.api.common.state.ValueState<AccumT> state =
+            flinkStateBackend.getPartitionedState(
+                namespace.stringKey(),
+                StringSerializer.INSTANCE,
+                flinkStateDescriptor);
+
+        AccumT current = state.value();
+        if (current == null) {
+          state.update(accum);
+        } else {
+          current = combineFn.mergeAccumulators(
+              flinkStateInternals.getKey(),
+              Lists.newArrayList(current, accum));
+          state.update(current);
+        }
+      } catch (Exception e) {
+        throw new RuntimeException("Error adding to state.", e);
+      }
+    }
+
+    @Override
+    public AccumT getAccum() {
+      try {
+        return flinkStateBackend.getPartitionedState(
+            namespace.stringKey(),
+            StringSerializer.INSTANCE,
+            flinkStateDescriptor).value();
+      } catch (Exception e) {
+        throw new RuntimeException("Error reading state.", e);
+      }
+    }
+
+    @Override
+    public AccumT mergeAccumulators(Iterable<AccumT> accumulators) {
+      return combineFn.mergeAccumulators(flinkStateInternals.getKey(), accumulators);
+    }
+
+    @Override
+    public OutputT read() {
+      try {
+        org.apache.flink.api.common.state.ValueState<AccumT> state =
+            flinkStateBackend.getPartitionedState(
+                namespace.stringKey(),
+                StringSerializer.INSTANCE,
+                flinkStateDescriptor);
+
+        AccumT accum = state.value();
+        if (accum != null) {
+          return combineFn.extractOutput(flinkStateInternals.getKey(), accum);
+        } else {
+          return combineFn.extractOutput(
+              flinkStateInternals.getKey(),
+              combineFn.createAccumulator(flinkStateInternals.getKey()));
+        }
+      } catch (Exception e) {
+        throw new RuntimeException("Error reading state.", e);
+      }
+    }
+
+    @Override
+    public ReadableState<Boolean> isEmpty() {
+      return new ReadableState<Boolean>() {
+        @Override
+        public Boolean read() {
+          try {
+            return flinkStateBackend.getPartitionedState(
+                namespace.stringKey(),
+                StringSerializer.INSTANCE,
+                flinkStateDescriptor).value() == null;
+          } catch (Exception e) {
+            throw new RuntimeException("Error reading state.", e);
+          }
+
+        }
+
+        @Override
+        public ReadableState<Boolean> readLater() {
+          return this;
+        }
+      };
+    }
+
+    @Override
+    public void clear() {
+      try {
+        flinkStateBackend.getPartitionedState(
+            namespace.stringKey(),
+            StringSerializer.INSTANCE,
+            flinkStateDescriptor).clear();
+      } catch (Exception e) {
+        throw new RuntimeException("Error clearing state.", e);
+      }
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (o == null || getClass() != o.getClass()) {
+        return false;
+      }
+
+      FlinkKeyedAccumulatorCombiningState<?, ?, ?, ?> that =
+          (FlinkKeyedAccumulatorCombiningState<?, ?, ?, ?>) o;
+
+      return namespace.equals(that.namespace) && address.equals(that.address);
+
+    }
+
+    @Override
+    public int hashCode() {
+      int result = namespace.hashCode();
+      result = 31 * result + address.hashCode();
+      return result;
+    }
+  }
+
+  private static class FlinkAccumulatorCombiningStateWithContext<K, InputT, AccumT, OutputT>
+      implements AccumulatorCombiningState<InputT, AccumT, OutputT> {
+
+    private final StateNamespace namespace;
+    private final StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address;
+    private final CombineWithContext.KeyedCombineFnWithContext<
+        ? super K, InputT, AccumT, OutputT> combineFn;
+    private final ValueStateDescriptor<AccumT> flinkStateDescriptor;
+    private final KeyedStateBackend<ByteBuffer> flinkStateBackend;
+    private final FlinkStateInternals<K> flinkStateInternals;
+    private final CombineWithContext.Context context;
+
+    FlinkAccumulatorCombiningStateWithContext(
+        KeyedStateBackend<ByteBuffer> flinkStateBackend,
+        StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address,
+        CombineWithContext.KeyedCombineFnWithContext<
+            ? super K, InputT, AccumT, OutputT> combineFn,
+        StateNamespace namespace,
+        Coder<AccumT> accumCoder,
+        FlinkStateInternals<K> flinkStateInternals,
+        CombineWithContext.Context context) {
+
+      this.namespace = namespace;
+      this.address = address;
+      this.combineFn = combineFn;
+      this.flinkStateBackend = flinkStateBackend;
+      this.flinkStateInternals = flinkStateInternals;
+      this.context = context;
+
+      CoderTypeInformation<AccumT> typeInfo = new CoderTypeInformation<>(accumCoder);
+
+      flinkStateDescriptor = new ValueStateDescriptor<>(address.getId(), typeInfo, null);
+    }
+
+    @Override
+    public AccumulatorCombiningState<InputT, AccumT, OutputT> readLater() {
+      return this;
+    }
+
+    @Override
+    public void add(InputT value) {
+      try {
+        org.apache.flink.api.common.state.ValueState<AccumT> state =
+            flinkStateBackend.getPartitionedState(
+                namespace.stringKey(),
+                StringSerializer.INSTANCE,
+                flinkStateDescriptor);
+
+        AccumT current = state.value();
+        if (current == null) {
+          current = combineFn.createAccumulator(flinkStateInternals.getKey(), context);
+        }
+        current = combineFn.addInput(flinkStateInternals.getKey(), current, value, context);
+        state.update(current);
+      } catch (Exception e) {
+        throw new RuntimeException("Error adding to state." , e);
+      }
+    }
+
+    @Override
+    public void addAccum(AccumT accum) {
+      try {
+        org.apache.flink.api.common.state.ValueState<AccumT> state =
+            flinkStateBackend.getPartitionedState(
+                namespace.stringKey(),
+                StringSerializer.INSTANCE,
+                flinkStateDescriptor);
+
+        AccumT current = state.value();
+        if (current == null) {
+          state.update(accum);
+        } else {
+          current = combineFn.mergeAccumulators(
+              flinkStateInternals.getKey(),
+              Lists.newArrayList(current, accum),
+              context);
+          state.update(current);
+        }
+      } catch (Exception e) {
+        throw new RuntimeException("Error adding to state.", e);
+      }
+    }
+
+    @Override
+    public AccumT getAccum() {
+      try {
+        return flinkStateBackend.getPartitionedState(
+            namespace.stringKey(),
+            StringSerializer.INSTANCE,
+            flinkStateDescriptor).value();
+      } catch (Exception e) {
+        throw new RuntimeException("Error reading state.", e);
+      }
+    }
+
+    @Override
+    public AccumT mergeAccumulators(Iterable<AccumT> accumulators) {
+      return combineFn.mergeAccumulators(flinkStateInternals.getKey(), accumulators, context);
+    }
+
+    @Override
+    public OutputT read() {
+      try {
+        org.apache.flink.api.common.state.ValueState<AccumT> state =
+            flinkStateBackend.getPartitionedState(
+                namespace.stringKey(),
+                StringSerializer.INSTANCE,
+                flinkStateDescriptor);
+
+        AccumT accum = state.value();
+        if (accum == null) {
+          // stay consistent with the other combining states when nothing was added yet
+          accum = combineFn.createAccumulator(flinkStateInternals.getKey(), context);
+        }
+        return combineFn.extractOutput(flinkStateInternals.getKey(), accum, context);
+      } catch (Exception e) {
+        throw new RuntimeException("Error reading state.", e);
+      }
+    }
+
+    @Override
+    public ReadableState<Boolean> isEmpty() {
+      return new ReadableState<Boolean>() {
+        @Override
+        public Boolean read() {
+          try {
+            return flinkStateBackend.getPartitionedState(
+                namespace.stringKey(),
+                StringSerializer.INSTANCE,
+                flinkStateDescriptor).value() == null;
+          } catch (Exception e) {
+            throw new RuntimeException("Error reading state.", e);
+          }
+
+        }
+
+        @Override
+        public ReadableState<Boolean> readLater() {
+          return this;
+        }
+      };
+    }
+
+    @Override
+    public void clear() {
+      try {
+        flinkStateBackend.getPartitionedState(
+            namespace.stringKey(),
+            StringSerializer.INSTANCE,
+            flinkStateDescriptor).clear();
+      } catch (Exception e) {
+        throw new RuntimeException("Error clearing state.", e);
+      }
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (o == null || getClass() != o.getClass()) {
+        return false;
+      }
+
+      FlinkAccumulatorCombiningStateWithContext<?, ?, ?, ?> that =
+          (FlinkAccumulatorCombiningStateWithContext<?, ?, ?, ?>) o;
+
+      return namespace.equals(that.namespace) && address.equals(that.address);
+
+    }
+
+    @Override
+    public int hashCode() {
+      int result = namespace.hashCode();
+      result = 31 * result + address.hashCode();
+      return result;
+    }
+  }
+
+  private static class FlinkWatermarkHoldState<K, W extends BoundedWindow>
+      implements WatermarkHoldState<W> {
+    private final StateTag<? super K, WatermarkHoldState<W>> address;
+    private final OutputTimeFn<? super W> outputTimeFn;
+    private final StateNamespace namespace;
+    private final KeyedStateBackend<ByteBuffer> flinkStateBackend;
+    private final FlinkStateInternals<K> flinkStateInternals;
+    private final ValueStateDescriptor<Instant> flinkStateDescriptor;
+
+    public FlinkWatermarkHoldState(
+        KeyedStateBackend<ByteBuffer> flinkStateBackend,
+        FlinkStateInternals<K> flinkStateInternals,
+        StateTag<? super K, WatermarkHoldState<W>> address,
+        StateNamespace namespace,
+        OutputTimeFn<? super W> outputTimeFn) {
+      this.address = address;
+      this.outputTimeFn = outputTimeFn;
+      this.namespace = namespace;
+      this.flinkStateBackend = flinkStateBackend;
+      this.flinkStateInternals = flinkStateInternals;
+
+      CoderTypeInformation<Instant> typeInfo = new CoderTypeInformation<>(InstantCoder.of());
+      flinkStateDescriptor = new ValueStateDescriptor<>(address.getId(), typeInfo, null);
+    }
+
+    @Override
+    public OutputTimeFn<? super W> getOutputTimeFn() {
+      return outputTimeFn;
+    }
+
+    @Override
+    public WatermarkHoldState<W> readLater() {
+      return this;
+    }
+
+    @Override
+    public ReadableState<Boolean> isEmpty() {
+      return new ReadableState<Boolean>() {
+        @Override
+        public Boolean read() {
+          try {
+            return flinkStateBackend.getPartitionedState(
+                namespace.stringKey(),
+                StringSerializer.INSTANCE,
+                flinkStateDescriptor).value() == null;
+          } catch (Exception e) {
+            throw new RuntimeException("Error reading state.", e);
+          }
+        }
+
+        @Override
+        public ReadableState<Boolean> readLater() {
+          return this;
+        }
+      };
+
+    }
+
+    @Override
+    public void add(Instant value) {
+      try {
+        org.apache.flink.api.common.state.ValueState<Instant> state =
+            flinkStateBackend.getPartitionedState(
+              namespace.stringKey(),
+              StringSerializer.INSTANCE,
+              flinkStateDescriptor);
+
+        Instant current = state.value();
+        if (current == null) {
+          state.update(value);
+          flinkStateInternals.watermarkHolds.put(namespace.stringKey(), value);
+        } else {
+          Instant combined = outputTimeFn.combine(current, value);
+          state.update(combined);
+          flinkStateInternals.watermarkHolds.put(namespace.stringKey(), combined);
+        }
+      } catch (Exception e) {
+        throw new RuntimeException("Error updating state.", e);
+      }
+    }
+
+    @Override
+    public Instant read() {
+      try {
+        org.apache.flink.api.common.state.ValueState<Instant> state =
+            flinkStateBackend.getPartitionedState(
+                namespace.stringKey(),
+                StringSerializer.INSTANCE,
+                flinkStateDescriptor);
+        return state.value();
+      } catch (Exception e) {
+        throw new RuntimeException("Error reading state.", e);
+      }
+    }
+
+    @Override
+    public void clear() {
+      flinkStateInternals.watermarkHolds.remove(namespace.stringKey());
+      try {
+        org.apache.flink.api.common.state.ValueState<Instant> state =
+            flinkStateBackend.getPartitionedState(
+                namespace.stringKey(),
+                StringSerializer.INSTANCE,
+                flinkStateDescriptor);
+        state.clear();
+      } catch (Exception e) {
+        throw new RuntimeException("Error reading state.", e);
+      }
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (o == null || getClass() != o.getClass()) {
+        return false;
+      }
+
+      FlinkWatermarkHoldState<?, ?> that = (FlinkWatermarkHoldState<?, ?>) o;
+
+      if (!address.equals(that.address)) {
+        return false;
+      }
+      if (!outputTimeFn.equals(that.outputTimeFn)) {
+        return false;
+      }
+      return namespace.equals(that.namespace);
+
+    }
+
+    @Override
+    public int hashCode() {
+      int result = address.hashCode();
+      result = 31 * result + outputTimeFn.hashCode();
+      result = 31 * result + namespace.hashCode();
+      return result;
+    }
+  }
+}
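
For readers less familiar with the runners-core state API, here is a minimal usage sketch of the class above. It is illustrative only and not part of this patch; the "backend" variable and the tag name are assumptions.

    // Assumes a configured Flink KeyedStateBackend<ByteBuffer> named "backend",
    // whose current key has been set to the coder-encoded form of the Beam key.
    FlinkStateInternals<String> stateInternals =
        new FlinkStateInternals<>(backend, StringUtf8Coder.of());

    // State cells are addressed by a (namespace, tag) pair; the key is implicit
    // in the backend's current key.
    StateTag<Object, ValueState<Integer>> counterTag =
        StateTags.value("counter", VarIntCoder.of());
    ValueState<Integer> counter =
        stateInternals.state(StateNamespaces.global(), counterTag);

    counter.write(42);
    Integer current = counter.read();  // served by the underlying Flink ValueState

Note that watermark holds are additionally tracked in the in-memory watermarkHolds map so that watermarkHold() can return the minimum hold without scanning the backend; as the comment in the class notes, that map is not repopulated on restore.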

http://git-wip-us.apache.org/repos/asf/beam/blob/7d32b93e/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/KeyGroupCheckpointedOperator.java
----------------------------------------------------------------------
diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/KeyGroupCheckpointedOperator.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/KeyGroupCheckpointedOperator.java
new file mode 100644
index 0000000..b38a520
--- /dev/null
+++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/KeyGroupCheckpointedOperator.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.translation.wrappers.streaming.state;
+
+import java.io.DataOutputStream;
+
+/**
+ * This interface is used to checkpoint state on a per-key-group basis.
+ */
+public interface KeyGroupCheckpointedOperator extends KeyGroupRestoringOperator {
+  /**
+   * Snapshots the state for a given {@code keyGroupIndex}.
+   *
+   * <p>{@code AbstractStreamOperator} calls this hook in
+   * {@code AbstractStreamOperator.snapshotState()} while iterating over the key groups.
+   *
+   * @param keyGroupIndex the index of the key group to be put in the snapshot.
+   * @param out the stream to write to.
+   */
+  void snapshotKeyGroupState(int keyGroupIndex, DataOutputStream out) throws Exception;
+}
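
To make the contract concrete, here is a sketch of a possible snapshot implementation. It is illustrative only and not part of this patch: it assumes the implementing operator keeps a hypothetical Map<Integer, List<byte[]>> pushedBackPerKeyGroup from key-group index to coder-encoded elements. The matching restore side is sketched after the next file.

    @Override
    public void snapshotKeyGroupState(int keyGroupIndex, DataOutputStream out)
        throws Exception {
      // pushedBackPerKeyGroup is a hypothetical field of the operator
      List<byte[]> elements = pushedBackPerKeyGroup.get(keyGroupIndex);
      int size = (elements == null) ? 0 : elements.size();
      out.writeInt(size);
      for (int i = 0; i < size; i++) {
        byte[] element = elements.get(i);
        out.writeInt(element.length);  // length-prefix each encoded element
        out.write(element);
      }
    }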

http://git-wip-us.apache.org/repos/asf/beam/blob/7d32b93e/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/KeyGroupRestoringOperator.java
----------------------------------------------------------------------
diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/KeyGroupRestoringOperator.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/KeyGroupRestoringOperator.java
new file mode 100644
index 0000000..2bdfc6e
--- /dev/null
+++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/KeyGroupRestoringOperator.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.translation.wrappers.streaming.state;
+
+import java.io.DataInputStream;
+
+/**
+ * This interface is used to restore state on a per-key-group basis.
+ */
+public interface KeyGroupRestoringOperator {
+  /**
+   * Restores the state for a given {@code keyGroupIndex}.
+   *
+   * @param keyGroupIndex the index of the key group whose state is to be restored.
+   * @param in the stream to read from.
+   */
+  void restoreKeyGroupState(int keyGroupIndex, DataInputStream in) throws Exception;
+}
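
And the matching restore hook, mirroring the snapshot sketch above (again illustrative only, using the same hypothetical pushedBackPerKeyGroup field):

    @Override
    public void restoreKeyGroupState(int keyGroupIndex, DataInputStream in)
        throws Exception {
      int size = in.readInt();
      List<byte[]> elements = new ArrayList<>(size);
      for (int i = 0; i < size; i++) {
        byte[] element = new byte[in.readInt()];
        in.readFully(element);  // read exactly the length-prefixed bytes
        elements.add(element);
      }
      pushedBackPerKeyGroup.put(keyGroupIndex, elements);
    }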

http://git-wip-us.apache.org/repos/asf/beam/blob/7d32b93e/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/package-info.java
----------------------------------------------------------------------
diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/package-info.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/package-info.java
new file mode 100644
index 0000000..0004e9e
--- /dev/null
+++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Internal state implementation of the Beam runner for Apache Flink.
+ */
+package org.apache.beam.runners.flink.translation.wrappers.streaming.state;

http://git-wip-us.apache.org/repos/asf/beam/blob/7d32b93e/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/PipelineOptionsTest.java
----------------------------------------------------------------------
diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/PipelineOptionsTest.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/PipelineOptionsTest.java
index d07861c..2cb3dd3 100644
--- a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/PipelineOptionsTest.java
+++ b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/PipelineOptionsTest.java
@@ -26,6 +26,7 @@ import java.util.Collections;
 import java.util.HashMap;
 import org.apache.beam.runners.flink.translation.utils.SerializedPipelineOptions;
 import org.apache.beam.runners.flink.translation.wrappers.streaming.DoFnOperator;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.options.Default;
 import org.apache.beam.sdk.options.Description;
 import org.apache.beam.sdk.options.PipelineOptions;
@@ -110,12 +111,12 @@ public class PipelineOptionsTest {
 
   @Test(expected = Exception.class)
   public void parDoBaseClassPipelineOptionsNullTest() {
-    DoFnOperator<Object, Object, Object> doFnOperator = new DoFnOperator<>(
+    DoFnOperator<String, String, String> doFnOperator = new DoFnOperator<>(
         new TestDoFn(),
-        TypeInformation.of(new TypeHint<WindowedValue<Object>>() {}),
-        new TupleTag<>("main-output"),
+        WindowedValue.getValueOnlyCoder(StringUtf8Coder.of()),
+        new TupleTag<String>("main-output"),
         Collections.<TupleTag<?>>emptyList(),
-        new DoFnOperator.DefaultOutputManagerFactory<>(),
+        new DoFnOperator.DefaultOutputManagerFactory<String>(),
         WindowingStrategy.globalDefault(),
         new HashMap<Integer, PCollectionView<?>>(),
         Collections.<PCollectionView<?>>emptyList(),
@@ -130,12 +131,12 @@ public class PipelineOptionsTest {
   @Test
   public void parDoBaseClassPipelineOptionsSerializationTest() throws Exception {
 
-    DoFnOperator<Object, Object, Object> doFnOperator = new DoFnOperator<>(
+    DoFnOperator<String, String, String> doFnOperator = new DoFnOperator<>(
         new TestDoFn(),
-        TypeInformation.of(new TypeHint<WindowedValue<Object>>() {}),
-        new TupleTag<>("main-output"),
+        WindowedValue.getValueOnlyCoder(StringUtf8Coder.of()),
+        new TupleTag<String>("main-output"),
         Collections.<TupleTag<?>>emptyList(),
-        new DoFnOperator.DefaultOutputManagerFactory<>(),
+        new DoFnOperator.DefaultOutputManagerFactory<String>(),
         WindowingStrategy.globalDefault(),
         new HashMap<Integer, PCollectionView<?>>(),
         Collections.<PCollectionView<?>>emptyList(),
@@ -148,8 +149,12 @@ public class PipelineOptionsTest {
     DoFnOperator<Object, Object, Object> deserialized =
         (DoFnOperator<Object, Object, Object>) SerializationUtils.deserialize(serialized);
 
+    TypeInformation<WindowedValue<Object>> typeInformation = TypeInformation.of(
+        new TypeHint<WindowedValue<Object>>() {});
+
     OneInputStreamOperatorTestHarness<WindowedValue<Object>, Object> testHarness =
-        new OneInputStreamOperatorTestHarness<>(deserialized, new ExecutionConfig());
+        new OneInputStreamOperatorTestHarness<>(deserialized,
+            typeInformation.createSerializer(new ExecutionConfig()));
 
     testHarness.open();
 
@@ -166,7 +171,7 @@ public class PipelineOptionsTest {
   }
 
 
-  private static class TestDoFn extends DoFn<Object, Object> {
+  private static class TestDoFn extends DoFn<String, String> {
 
     @ProcessElement
     public void processElement(ProcessContext c) throws Exception {
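
A note on the harness change above: OneInputStreamOperatorTestHarness now takes a TypeSerializer for the input type instead of a bare ExecutionConfig. In runner code such a serializer would typically be derived from a Beam coder; a sketch of that route (illustrative, not part of this patch):

    // Building a Flink TypeSerializer from a Beam coder via CoderTypeInformation,
    // as the runner-side wrappers do.
    CoderTypeInformation<WindowedValue<String>> typeInfo =
        new CoderTypeInformation<>(
            WindowedValue.getValueOnlyCoder(StringUtf8Coder.of()));
    TypeSerializer<WindowedValue<String>> serializer =
        typeInfo.createSerializer(new ExecutionConfig());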

http://git-wip-us.apache.org/repos/asf/beam/blob/7d32b93e/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/DoFnOperatorTest.java
----------------------------------------------------------------------
diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/DoFnOperatorTest.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/DoFnOperatorTest.java
index 3598d10..7d14a87 100644
--- a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/DoFnOperatorTest.java
+++ b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/DoFnOperatorTest.java
@@ -29,8 +29,8 @@ import java.util.Collections;
 import java.util.HashMap;
 import javax.annotation.Nullable;
 import org.apache.beam.runners.flink.FlinkPipelineOptions;
-import org.apache.beam.runners.flink.translation.types.CoderTypeInformation;
 import org.apache.beam.runners.flink.translation.wrappers.streaming.DoFnOperator;
+import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.testing.PCollectionViewTesting;
@@ -44,12 +44,15 @@ import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.util.WindowingStrategy;
 import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.TupleTag;
+
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.util.KeyedTwoInputStreamOperatorTestHarness;
 import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
 import org.apache.flink.streaming.util.TwoInputStreamOperatorTestHarness;
 import org.joda.time.Duration;
 import org.joda.time.Instant;
-import org.junit.Ignore;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -89,14 +92,11 @@ public class DoFnOperatorTest {
     WindowedValue.ValueOnlyWindowedValueCoder<String> windowedValueCoder =
         WindowedValue.getValueOnlyCoder(StringUtf8Coder.of());
 
-    CoderTypeInformation<WindowedValue<String>> coderTypeInfo =
-        new CoderTypeInformation<>(windowedValueCoder);
-
     TupleTag<String> outputTag = new TupleTag<>("main-output");
 
     DoFnOperator<String, String, String> doFnOperator = new DoFnOperator<>(
         new IdentityDoFn<String>(),
-        coderTypeInfo,
+        windowedValueCoder,
         outputTag,
         Collections.<TupleTag<?>>emptyList(),
         new DoFnOperator.DefaultOutputManagerFactory(),
@@ -127,9 +127,6 @@ public class DoFnOperatorTest {
     WindowedValue.ValueOnlyWindowedValueCoder<String> windowedValueCoder =
         WindowedValue.getValueOnlyCoder(StringUtf8Coder.of());
 
-    CoderTypeInformation<WindowedValue<String>> coderTypeInfo =
-        new CoderTypeInformation<>(windowedValueCoder);
-
     TupleTag<String> mainOutput = new TupleTag<>("main-output");
     TupleTag<String> sideOutput1 = new TupleTag<>("side-output-1");
     TupleTag<String> sideOutput2 = new TupleTag<>("side-output-2");
@@ -141,7 +138,7 @@ public class DoFnOperatorTest {
 
     DoFnOperator<String, String, RawUnionValue> doFnOperator = new DoFnOperator<>(
         new MultiOutputDoFn(sideOutput1, sideOutput2),
-        coderTypeInfo,
+        windowedValueCoder,
         mainOutput,
         ImmutableList.<TupleTag<?>>of(sideOutput1, sideOutput2),
         new DoFnOperator.MultiOutputOutputManagerFactory(outputMapping),
@@ -172,26 +169,11 @@ public class DoFnOperatorTest {
     testHarness.close();
   }
 
-  /**
-   * For now, this test doesn't work because {@link TwoInputStreamOperatorTestHarness} is not
-   * sufficiently well equipped to handle more complex operators that require a state backend.
-   * We have to revisit this once we update to a newer version of Flink and also add some more
-   * tests that verify pushback behaviour and watermark hold behaviour.
-   *
-   * <p>The behaviour that we would test here is also exercised by the
-   * {@link org.apache.beam.sdk.testing.RunnableOnService} tests, so the code is not untested.
-   */
-  @Test
-  @Ignore
-  @SuppressWarnings("unchecked")
-  public void testSideInputs() throws Exception {
+  public void testSideInputs(boolean keyed) throws Exception {
 
     WindowedValue.ValueOnlyWindowedValueCoder<String> windowedValueCoder =
         WindowedValue.getValueOnlyCoder(StringUtf8Coder.of());
 
-    CoderTypeInformation<WindowedValue<String>> coderTypeInfo =
-        new CoderTypeInformation<>(windowedValueCoder);
-
     TupleTag<String> outputTag = new TupleTag<>("main-output");
 
     ImmutableMap<Integer, PCollectionView<?>> sideInputMapping =
@@ -200,46 +182,92 @@ public class DoFnOperatorTest {
             .put(2, view2)
             .build();
 
+    Coder<String> keyCoder = null;
+    if (keyed) {
+      keyCoder = StringUtf8Coder.of();
+    }
+
     DoFnOperator<String, String, String> doFnOperator = new DoFnOperator<>(
         new IdentityDoFn<String>(),
-        coderTypeInfo,
+        windowedValueCoder,
         outputTag,
         Collections.<TupleTag<?>>emptyList(),
-        new DoFnOperator.DefaultOutputManagerFactory(),
+        new DoFnOperator.DefaultOutputManagerFactory<String>(),
         WindowingStrategy.globalDefault(),
         sideInputMapping, /* side-input mapping */
         ImmutableList.<PCollectionView<?>>of(view1, view2), /* side inputs */
         PipelineOptionsFactory.as(FlinkPipelineOptions.class),
-        null);
+        keyCoder);
 
     TwoInputStreamOperatorTestHarness<WindowedValue<String>, RawUnionValue, String> testHarness =
         new TwoInputStreamOperatorTestHarness<>(doFnOperator);
 
+    if (keyed) {
+      // we use a dummy key for the second input since it is considered to be broadcast
+      testHarness = new KeyedTwoInputStreamOperatorTestHarness<>(
+          doFnOperator,
+          new StringKeySelector(),
+          new DummyKeySelector(),
+          BasicTypeInfo.STRING_TYPE_INFO);
+    }
+
     testHarness.open();
 
     IntervalWindow firstWindow = new IntervalWindow(new Instant(0), new Instant(100));
+    IntervalWindow secondWindow = new IntervalWindow(new Instant(0), new Instant(500));
 
-    // push in some side-input elements
+    // test that side-input elements are retained
     testHarness.processElement2(
         new StreamRecord<>(
             new RawUnionValue(
                 1,
                 valuesInWindow(ImmutableList.of("hello", "ciao"), new 
Instant(0), firstWindow))));
-
     testHarness.processElement2(
         new StreamRecord<>(
             new RawUnionValue(
                 2,
-                valuesInWindow(ImmutableList.of("foo", "bar"), new Instant(0), firstWindow))));
+                valuesInWindow(ImmutableList.of("foo", "bar"), new Instant(0), secondWindow))));
 
     // push in some regular elements
-    testHarness.processElement1(new StreamRecord<>(WindowedValue.valueInGlobalWindow("Hello")));
+    WindowedValue<String> helloElement = valueInWindow("Hello", new Instant(0), firstWindow);
+    WindowedValue<String> worldElement = valueInWindow("World", new Instant(1000), firstWindow);
+    testHarness.processElement1(new StreamRecord<>(helloElement));
+    testHarness.processElement1(new StreamRecord<>(worldElement));
+
+    // test that pushed-back elements are retained
+    testHarness.processElement2(
+        new StreamRecord<>(
+            new RawUnionValue(
+                1,
+                valuesInWindow(ImmutableList.of("hello", "ciao"),
+                    new Instant(1000), firstWindow))));
+    testHarness.processElement2(
+        new StreamRecord<>(
+            new RawUnionValue(
+                2,
+                valuesInWindow(ImmutableList.of("foo", "bar"), new Instant(1000), secondWindow))));
 
     assertThat(
         this.<String>stripStreamRecordFromWindowedValue(testHarness.getOutput()),
-        contains(WindowedValue.valueInGlobalWindow("Hello")));
+        contains(helloElement, worldElement));
 
     testHarness.close();
+
+  }
+
+  /**
+   * {@link TwoInputStreamOperatorTestHarness} supports an OperatorStateBackend
+   * but does not support a KeyedStateBackend, so here we only test the side inputs
+   * of a normal (unkeyed) ParDo.
+   */
+  @Test
+  @SuppressWarnings("unchecked")
+  public void testNormalParDoSideInputs() throws Exception {
+    testSideInputs(false);
+  }
+
+  @Test
+  public void testKeyedSideInputs() throws Exception {
+    testSideInputs(true);
   }
 
   private <T> Iterable<WindowedValue<T>> stripStreamRecordFromWindowedValue(
@@ -325,4 +353,17 @@ public class DoFnOperatorTest {
   }
 
 
+  private static class DummyKeySelector implements KeySelector<RawUnionValue, String> {
+    @Override
+    public String getKey(RawUnionValue unionValue) throws Exception {
+      return "dummy_key";
+    }
+  }
+
+  private static class StringKeySelector implements KeySelector<WindowedValue<String>, String> {
+    @Override
+    public String getKey(WindowedValue<String> stringWindowedValue) throws Exception {
+      return stringWindowedValue.getValue();
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/7d32b93e/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkBroadcastStateInternalsTest.java
----------------------------------------------------------------------
diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkBroadcastStateInternalsTest.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkBroadcastStateInternalsTest.java
new file mode 100644
index 0000000..db02cb3
--- /dev/null
+++ b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkBroadcastStateInternalsTest.java
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.streaming;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertThat;
+
+import java.util.Arrays;
+import org.apache.beam.runners.core.StateMerging;
+import org.apache.beam.runners.core.StateNamespace;
+import org.apache.beam.runners.core.StateNamespaceForTest;
+import org.apache.beam.runners.core.StateTag;
+import org.apache.beam.runners.core.StateTags;
+import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkBroadcastStateInternals;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.transforms.Sum;
+import org.apache.beam.sdk.util.state.AccumulatorCombiningState;
+import org.apache.beam.sdk.util.state.BagState;
+import org.apache.beam.sdk.util.state.CombiningState;
+import org.apache.beam.sdk.util.state.ReadableState;
+import org.apache.beam.sdk.util.state.ValueState;
+import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
+import org.apache.flink.runtime.state.OperatorStateBackend;
+import org.apache.flink.runtime.state.memory.MemoryStateBackend;
+import org.hamcrest.Matchers;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests for {@link FlinkBroadcastStateInternals}. This is based on the tests for
+ * {@code InMemoryStateInternals}.
+ */
+@RunWith(JUnit4.class)
+public class FlinkBroadcastStateInternalsTest {
+  private static final StateNamespace NAMESPACE_1 = new StateNamespaceForTest("ns1");
+  private static final StateNamespace NAMESPACE_2 = new StateNamespaceForTest("ns2");
+  private static final StateNamespace NAMESPACE_3 = new StateNamespaceForTest("ns3");
+
+  private static final StateTag<Object, ValueState<String>> STRING_VALUE_ADDR =
+      StateTags.value("stringValue", StringUtf8Coder.of());
+  private static final StateTag<Object, AccumulatorCombiningState<Integer, int[], Integer>>
+      SUM_INTEGER_ADDR = StateTags.combiningValueFromInputInternal(
+          "sumInteger", VarIntCoder.of(), Sum.ofIntegers());
+  private static final StateTag<Object, BagState<String>> STRING_BAG_ADDR =
+      StateTags.bag("stringBag", StringUtf8Coder.of());
+
+  FlinkBroadcastStateInternals<String> underTest;
+
+  @Before
+  public void initStateInternals() {
+    MemoryStateBackend backend = new MemoryStateBackend();
+    try {
+      OperatorStateBackend operatorStateBackend =
+          backend.createOperatorStateBackend(new DummyEnvironment("test", 1, 0), "");
+      underTest = new FlinkBroadcastStateInternals<>(1, operatorStateBackend);
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  @Test
+  public void testValue() throws Exception {
+    ValueState<String> value = underTest.state(NAMESPACE_1, STRING_VALUE_ADDR);
+
+    assertEquals(underTest.state(NAMESPACE_1, STRING_VALUE_ADDR), value);
+    assertNotEquals(
+        underTest.state(NAMESPACE_2, STRING_VALUE_ADDR),
+        value);
+
+    assertThat(value.read(), Matchers.nullValue());
+    value.write("hello");
+    assertThat(value.read(), Matchers.equalTo("hello"));
+    value.write("world");
+    assertThat(value.read(), Matchers.equalTo("world"));
+
+    value.clear();
+    assertThat(value.read(), Matchers.nullValue());
+    assertEquals(underTest.state(NAMESPACE_1, STRING_VALUE_ADDR), value);
+  }
+
+  @Test
+  public void testBag() throws Exception {
+    BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+
+    assertEquals(value, underTest.state(NAMESPACE_1, STRING_BAG_ADDR));
+    assertFalse(value.equals(underTest.state(NAMESPACE_2, STRING_BAG_ADDR)));
+
+    assertThat(value.read(), Matchers.emptyIterable());
+    value.add("hello");
+    assertThat(value.read(), Matchers.containsInAnyOrder("hello"));
+
+    value.add("world");
+    assertThat(value.read(), Matchers.containsInAnyOrder("hello", "world"));
+
+    value.clear();
+    assertThat(value.read(), Matchers.emptyIterable());
+    assertEquals(underTest.state(NAMESPACE_1, STRING_BAG_ADDR), value);
+  }
+
+  @Test
+  public void testBagIsEmpty() throws Exception {
+    BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+
+    assertThat(value.isEmpty().read(), Matchers.is(true));
+    ReadableState<Boolean> readFuture = value.isEmpty();
+    value.add("hello");
+    assertThat(readFuture.read(), Matchers.is(false));
+
+    value.clear();
+    assertThat(readFuture.read(), Matchers.is(true));
+  }
+
+  @Test
+  public void testMergeBagIntoSource() throws Exception {
+    BagState<String> bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+    BagState<String> bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR);
+
+    bag1.add("Hello");
+    bag2.add("World");
+    bag1.add("!");
+
+    StateMerging.mergeBags(Arrays.asList(bag1, bag2), bag1);
+
+    // Reading the merged bag returns the contents of both bags.
+    assertThat(bag1.read(), Matchers.containsInAnyOrder("Hello", "World", "!"));
+    assertThat(bag2.read(), Matchers.emptyIterable());
+  }
+
+  @Test
+  public void testMergeBagIntoNewNamespace() throws Exception {
+    BagState<String> bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+    BagState<String> bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR);
+    BagState<String> bag3 = underTest.state(NAMESPACE_3, STRING_BAG_ADDR);
+
+    bag1.add("Hello");
+    bag2.add("World");
+    bag1.add("!");
+
+    StateMerging.mergeBags(Arrays.asList(bag1, bag2, bag3), bag3);
+
+    // Reading the merged bag returns the contents of both bags.
+    assertThat(bag3.read(), Matchers.containsInAnyOrder("Hello", "World", "!"));
+    assertThat(bag1.read(), Matchers.emptyIterable());
+    assertThat(bag2.read(), Matchers.emptyIterable());
+  }
+
+  @Test
+  public void testCombiningValue() throws Exception {
+    CombiningState<Integer, Integer> value = underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR);
+
+    // State instances are cached, but depend on the namespace.
+    assertEquals(value, underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR));
+    assertFalse(value.equals(underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR)));
+
+    assertThat(value.read(), Matchers.equalTo(0));
+    value.add(2);
+    assertThat(value.read(), Matchers.equalTo(2));
+
+    value.add(3);
+    assertThat(value.read(), Matchers.equalTo(5));
+
+    value.clear();
+    assertThat(value.read(), Matchers.equalTo(0));
+    assertEquals(underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR), value);
+  }
+
+  @Test
+  public void testCombiningIsEmpty() throws Exception {
+    CombiningState<Integer, Integer> value = underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR);
+
+    assertThat(value.isEmpty().read(), Matchers.is(true));
+    ReadableState<Boolean> readFuture = value.isEmpty();
+    value.add(5);
+    assertThat(readFuture.read(), Matchers.is(false));
+
+    value.clear();
+    assertThat(readFuture.read(), Matchers.is(true));
+  }
+
+  @Test
+  public void testMergeCombiningValueIntoSource() throws Exception {
+    AccumulatorCombiningState<Integer, int[], Integer> value1 =
+        underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR);
+    AccumulatorCombiningState<Integer, int[], Integer> value2 =
+        underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR);
+
+    value1.add(5);
+    value2.add(10);
+    value1.add(6);
+
+    assertThat(value1.read(), Matchers.equalTo(11));
+    assertThat(value2.read(), Matchers.equalTo(10));
+
+    // Merging clears the old values and updates the result value.
+    StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value1);
+
+    assertThat(value1.read(), Matchers.equalTo(21));
+    assertThat(value2.read(), Matchers.equalTo(0));
+  }
+
+  @Test
+  public void testMergeCombiningValueIntoNewNamespace() throws Exception {
+    AccumulatorCombiningState<Integer, int[], Integer> value1 =
+        underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR);
+    AccumulatorCombiningState<Integer, int[], Integer> value2 =
+        underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR);
+    AccumulatorCombiningState<Integer, int[], Integer> value3 =
+        underTest.state(NAMESPACE_3, SUM_INTEGER_ADDR);
+
+    value1.add(5);
+    value2.add(10);
+    value1.add(6);
+
+    StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value3);
+
+    // Merging clears the old values and updates the result value.
+    assertThat(value1.read(), Matchers.equalTo(0));
+    assertThat(value2.read(), Matchers.equalTo(0));
+    assertThat(value3.read(), Matchers.equalTo(21));
+  }
+
+}
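
For orientation, the broadcast-state pattern exercised by these tests
condenses to the following sketch. Every call below appears in the test
above; only the wrapping method is illustrative:

    // Minimal sketch: value state backed by Flink operator (broadcast) state,
    // assuming the Beam and Flink versions targeted by this patch.
    private static void broadcastStateSketch() throws Exception {
      MemoryStateBackend backend = new MemoryStateBackend();
      OperatorStateBackend operatorStateBackend =
          backend.createOperatorStateBackend(new DummyEnvironment("test", 1, 0), "");
      FlinkBroadcastStateInternals<String> internals =
          new FlinkBroadcastStateInternals<>(1, operatorStateBackend);
      ValueState<String> state = internals.state(
          new StateNamespaceForTest("ns"), StateTags.value("v", StringUtf8Coder.of()));
      state.write("hello");                 // round-trips through the operator state backend
      assert "hello".equals(state.read());
      state.clear();                        // read() now returns null, as testValue() asserts
    }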

http://git-wip-us.apache.org/repos/asf/beam/blob/7d32b93e/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkKeyGroupStateInternalsTest.java
----------------------------------------------------------------------
diff --git 
a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkKeyGroupStateInternalsTest.java
 
b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkKeyGroupStateInternalsTest.java
new file mode 100644
index 0000000..5433d07
--- /dev/null
+++ 
b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkKeyGroupStateInternalsTest.java
@@ -0,0 +1,262 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.streaming;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertThat;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import org.apache.beam.runners.core.StateMerging;
+import org.apache.beam.runners.core.StateNamespace;
+import org.apache.beam.runners.core.StateNamespaceForTest;
+import org.apache.beam.runners.core.StateTag;
+import org.apache.beam.runners.core.StateTags;
+import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkKeyGroupStateInternals;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.util.CoderUtils;
+import org.apache.beam.sdk.util.state.BagState;
+import org.apache.beam.sdk.util.state.ReadableState;
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.java.typeutils.GenericTypeInfo;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
+import org.apache.flink.runtime.query.KvStateRegistry;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyedStateBackend;
+import org.apache.flink.runtime.state.memory.MemoryStateBackend;
+import org.apache.flink.streaming.api.operators.KeyContext;
+import org.hamcrest.Matchers;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests for {@link FlinkKeyGroupStateInternals}. This is based on the tests for
+ * {@code InMemoryStateInternals}.
+ */
+@RunWith(JUnit4.class)
+public class FlinkKeyGroupStateInternalsTest {
+  private static final StateNamespace NAMESPACE_1 = new StateNamespaceForTest("ns1");
+  private static final StateNamespace NAMESPACE_2 = new StateNamespaceForTest("ns2");
+  private static final StateNamespace NAMESPACE_3 = new StateNamespaceForTest("ns3");
+
+  private static final StateTag<Object, BagState<String>> STRING_BAG_ADDR =
+      StateTags.bag("stringBag", StringUtf8Coder.of());
+
+  FlinkKeyGroupStateInternals<String> underTest;
+  private KeyedStateBackend keyedStateBackend;
+
+  @Before
+  public void initStateInternals() {
+    try {
+      keyedStateBackend = getKeyedStateBackend(2, new KeyGroupRange(0, 1));
+      underTest = new FlinkKeyGroupStateInternals<>(StringUtf8Coder.of(), keyedStateBackend);
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  private KeyedStateBackend getKeyedStateBackend(int numberOfKeyGroups,
+                                                 KeyGroupRange keyGroupRange) {
+    MemoryStateBackend backend = new MemoryStateBackend();
+    try {
+      AbstractKeyedStateBackend<ByteBuffer> keyedStateBackend = backend.createKeyedStateBackend(
+          new DummyEnvironment("test", 1, 0),
+          new JobID(),
+          "test_op",
+          new GenericTypeInfo<>(ByteBuffer.class).createSerializer(new ExecutionConfig()),
+          numberOfKeyGroups,
+          keyGroupRange,
+          new KvStateRegistry().createTaskRegistry(new JobID(), new JobVertexID()));
+      keyedStateBackend.setCurrentKey(ByteBuffer.wrap(
+          CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "1")));
+      return keyedStateBackend;
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  @Test
+  public void testBag() throws Exception {
+    BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+
+    assertEquals(value, underTest.state(NAMESPACE_1, STRING_BAG_ADDR));
+    assertFalse(value.equals(underTest.state(NAMESPACE_2, STRING_BAG_ADDR)));
+
+    assertThat(value.read(), Matchers.emptyIterable());
+    value.add("hello");
+    assertThat(value.read(), Matchers.containsInAnyOrder("hello"));
+
+    value.add("world");
+    assertThat(value.read(), Matchers.containsInAnyOrder("hello", "world"));
+
+    value.clear();
+    assertThat(value.read(), Matchers.emptyIterable());
+    assertEquals(underTest.state(NAMESPACE_1, STRING_BAG_ADDR), value);
+  }
+
+  @Test
+  public void testBagIsEmpty() throws Exception {
+    BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+
+    assertThat(value.isEmpty().read(), Matchers.is(true));
+    ReadableState<Boolean> readFuture = value.isEmpty();
+    value.add("hello");
+    assertThat(readFuture.read(), Matchers.is(false));
+
+    value.clear();
+    assertThat(readFuture.read(), Matchers.is(true));
+  }
+
+  @Test
+  public void testMergeBagIntoSource() throws Exception {
+    BagState<String> bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+    BagState<String> bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR);
+
+    bag1.add("Hello");
+    bag2.add("World");
+    bag1.add("!");
+
+    StateMerging.mergeBags(Arrays.asList(bag1, bag2), bag1);
+
+    // Reading the merged bag returns the contents of both bags.
+    assertThat(bag1.read(), Matchers.containsInAnyOrder("Hello", "World", "!"));
+    assertThat(bag2.read(), Matchers.emptyIterable());
+  }
+
+  @Test
+  public void testMergeBagIntoNewNamespace() throws Exception {
+    BagState<String> bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+    BagState<String> bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR);
+    BagState<String> bag3 = underTest.state(NAMESPACE_3, STRING_BAG_ADDR);
+
+    bag1.add("Hello");
+    bag2.add("World");
+    bag1.add("!");
+
+    StateMerging.mergeBags(Arrays.asList(bag1, bag2, bag3), bag3);
+
+    // Reading the merged bag returns the contents of both bags.
+    assertThat(bag3.read(), Matchers.containsInAnyOrder("Hello", "World", "!"));
+    assertThat(bag1.read(), Matchers.emptyIterable());
+    assertThat(bag2.read(), Matchers.emptyIterable());
+  }
+
+  @Test
+  public void testKeyGroupAndCheckpoint() throws Exception {
+    // assign to keyGroup 0
+    ByteBuffer key0 = ByteBuffer.wrap(
+        CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "11111111"));
+    // assign to keyGroup 1
+    ByteBuffer key1 = ByteBuffer.wrap(
+        CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "22222222"));
+    FlinkKeyGroupStateInternals<String> allState;
+    {
+      KeyedStateBackend keyedStateBackend = getKeyedStateBackend(2, new KeyGroupRange(0, 1));
+      allState = new FlinkKeyGroupStateInternals<>(
+          StringUtf8Coder.of(), keyedStateBackend);
+      BagState<String> valueForNamespace0 = allState.state(NAMESPACE_1, STRING_BAG_ADDR);
+      BagState<String> valueForNamespace1 = allState.state(NAMESPACE_2, STRING_BAG_ADDR);
+      keyedStateBackend.setCurrentKey(key0);
+      valueForNamespace0.add("0");
+      valueForNamespace1.add("2");
+      keyedStateBackend.setCurrentKey(key1);
+      valueForNamespace0.add("1");
+      valueForNamespace1.add("3");
+      assertThat(valueForNamespace0.read(), Matchers.containsInAnyOrder("0", "1"));
+      assertThat(valueForNamespace1.read(), Matchers.containsInAnyOrder("2", "3"));
+    }
+
+    ClassLoader classLoader = FlinkKeyGroupStateInternalsTest.class.getClassLoader();
+
+    // 1. scale up
+    ByteArrayOutputStream out0 = new ByteArrayOutputStream();
+    allState.snapshotKeyGroupState(0, new DataOutputStream(out0));
+    DataInputStream in0 = new DataInputStream(
+        new ByteArrayInputStream(out0.toByteArray()));
+    {
+      KeyedStateBackend keyedStateBackend = getKeyedStateBackend(2, new KeyGroupRange(0, 0));
+      FlinkKeyGroupStateInternals<String> state0 =
+          new FlinkKeyGroupStateInternals<>(
+              StringUtf8Coder.of(), keyedStateBackend);
+      state0.restoreKeyGroupState(0, in0, classLoader);
+      BagState<String> valueForNamespace0 = state0.state(NAMESPACE_1, STRING_BAG_ADDR);
+      BagState<String> valueForNamespace1 = state0.state(NAMESPACE_2, STRING_BAG_ADDR);
+      assertThat(valueForNamespace0.read(), Matchers.containsInAnyOrder("0"));
+      assertThat(valueForNamespace1.read(), Matchers.containsInAnyOrder("2"));
+    }
+
+    ByteArrayOutputStream out1 = new ByteArrayOutputStream();
+    allState.snapshotKeyGroupState(1, new DataOutputStream(out1));
+    DataInputStream in1 = new DataInputStream(
+        new ByteArrayInputStream(out1.toByteArray()));
+    {
+      KeyedStateBackend keyedStateBackend = getKeyedStateBackend(2, new KeyGroupRange(1, 1));
+      FlinkKeyGroupStateInternals<String> state1 =
+          new FlinkKeyGroupStateInternals<>(
+              StringUtf8Coder.of(), keyedStateBackend);
+      state1.restoreKeyGroupState(1, in1, classLoader);
+      BagState<String> valueForNamespace0 = state1.state(NAMESPACE_1, STRING_BAG_ADDR);
+      BagState<String> valueForNamespace1 = state1.state(NAMESPACE_2, STRING_BAG_ADDR);
+      assertThat(valueForNamespace0.read(), Matchers.containsInAnyOrder("1"));
+      assertThat(valueForNamespace1.read(), Matchers.containsInAnyOrder("3"));
+    }
+
+    // 2. scale down
+    {
+      KeyedStateBackend keyedStateBackend = getKeyedStateBackend(2, new KeyGroupRange(0, 1));
+      FlinkKeyGroupStateInternals<String> newAllState = new FlinkKeyGroupStateInternals<>(
+          StringUtf8Coder.of(), keyedStateBackend);
+      in0.reset();
+      in1.reset();
+      newAllState.restoreKeyGroupState(0, in0, classLoader);
+      newAllState.restoreKeyGroupState(1, in1, classLoader);
+      BagState<String> valueForNamespace0 = newAllState.state(NAMESPACE_1, STRING_BAG_ADDR);
+      BagState<String> valueForNamespace1 = newAllState.state(NAMESPACE_2, STRING_BAG_ADDR);
+      assertThat(valueForNamespace0.read(), Matchers.containsInAnyOrder("0", "1"));
+      assertThat(valueForNamespace1.read(), Matchers.containsInAnyOrder("2", "3"));
+    }
+  }
+
+  private static class TestKeyContext implements KeyContext {
+
+    private Object key;
+
+    @Override
+    public void setCurrentKey(Object key) {
+      this.key = key;
+    }
+
+    @Override
+    public Object getCurrentKey() {
+      return key;
+    }
+  }
+
+}
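
The testKeyGroupAndCheckpoint test above is the rescaling contract in
miniature: state is snapshotted per key group, and after a parallelism change
each instance restores only the key groups it now owns. Condensed, the cycle
looks like this (variable names are illustrative; the calls are exactly the
ones exercised above):

    // Snapshot key group 0 from the old instance...
    ByteArrayOutputStream out0 = new ByteArrayOutputStream();
    oldState.snapshotKeyGroupState(0, new DataOutputStream(out0));
    // ...then replay it into whichever new instance owns key group 0,
    // i.e. one whose backend's KeyGroupRange covers that group.
    DataInputStream in0 = new DataInputStream(new ByteArrayInputStream(out0.toByteArray()));
    newState.restoreKeyGroupState(0, in0, getClass().getClassLoader());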

http://git-wip-us.apache.org/repos/asf/beam/blob/7d32b93e/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkSplitStateInternalsTest.java
----------------------------------------------------------------------
diff --git 
a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkSplitStateInternalsTest.java
 
b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkSplitStateInternalsTest.java
new file mode 100644
index 0000000..08ae0c4
--- /dev/null
+++ 
b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkSplitStateInternalsTest.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.streaming;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertThat;
+
+import org.apache.beam.runners.core.StateNamespace;
+import org.apache.beam.runners.core.StateNamespaceForTest;
+import org.apache.beam.runners.core.StateTag;
+import org.apache.beam.runners.core.StateTags;
+import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkSplitStateInternals;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.util.state.BagState;
+import org.apache.beam.sdk.util.state.ReadableState;
+import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
+import org.apache.flink.runtime.state.OperatorStateBackend;
+import org.apache.flink.runtime.state.memory.MemoryStateBackend;
+import org.hamcrest.Matchers;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests for {@link FlinkSplitStateInternals}. This is based on the tests for
+ * {@code InMemoryStateInternals}.
+ */
+@RunWith(JUnit4.class)
+public class FlinkSplitStateInternalsTest {
+  private static final StateNamespace NAMESPACE_1 = new StateNamespaceForTest("ns1");
+  private static final StateNamespace NAMESPACE_2 = new StateNamespaceForTest("ns2");
+
+  private static final StateTag<Object, BagState<String>> STRING_BAG_ADDR =
+      StateTags.bag("stringBag", StringUtf8Coder.of());
+
+  FlinkSplitStateInternals<String> underTest;
+
+  @Before
+  public void initStateInternals() {
+    MemoryStateBackend backend = new MemoryStateBackend();
+    try {
+      OperatorStateBackend operatorStateBackend =
+          backend.createOperatorStateBackend(new DummyEnvironment("test", 1, 0), "");
+      underTest = new FlinkSplitStateInternals<>(operatorStateBackend);
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  @Test
+  public void testBag() throws Exception {
+    BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+
+    assertEquals(value, underTest.state(NAMESPACE_1, STRING_BAG_ADDR));
+    assertFalse(value.equals(underTest.state(NAMESPACE_2, STRING_BAG_ADDR)));
+
+    assertThat(value.read(), Matchers.emptyIterable());
+    value.add("hello");
+    assertThat(value.read(), Matchers.containsInAnyOrder("hello"));
+
+    value.add("world");
+    assertThat(value.read(), Matchers.containsInAnyOrder("hello", "world"));
+
+    value.clear();
+    assertThat(value.read(), Matchers.emptyIterable());
+    assertEquals(underTest.state(NAMESPACE_1, STRING_BAG_ADDR), value);
+  }
+
+  @Test
+  public void testBagIsEmpty() throws Exception {
+    BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+
+    assertThat(value.isEmpty().read(), Matchers.is(true));
+    ReadableState<Boolean> readFuture = value.isEmpty();
+    value.add("hello");
+    assertThat(readFuture.read(), Matchers.is(false));
+
+    value.clear();
+    assertThat(readFuture.read(), Matchers.is(true));
+  }
+
+}
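
FlinkSplitStateInternals keeps its state in Flink's redistributable operator
("split") state, which would explain why only BagState is covered here: bag
contents can be split and merged freely when partitions are redistributed. A
minimal usage sketch mirroring the @Before setup above (the wrapping method
is illustrative, not part of the patch):

    private static void splitStateSketch() throws Exception {
      MemoryStateBackend backend = new MemoryStateBackend();
      OperatorStateBackend operatorStateBackend =
          backend.createOperatorStateBackend(new DummyEnvironment("test", 1, 0), "");
      FlinkSplitStateInternals<String> internals =
          new FlinkSplitStateInternals<>(operatorStateBackend);
      BagState<String> bag = internals.state(
          new StateNamespaceForTest("ns"), StateTags.bag("bag", StringUtf8Coder.of()));
      bag.add("hello");  // appended to redistributable operator state
      bag.add("world");  // read() returns both elements, in no guaranteed order
    }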

http://git-wip-us.apache.org/repos/asf/beam/blob/7d32b93e/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java
----------------------------------------------------------------------
diff --git 
a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java
 
b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java
index 465dad3..7839cf3 100644
--- 
a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java
+++ 
b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java
@@ -29,8 +29,7 @@ import org.apache.beam.runners.core.StateNamespace;
 import org.apache.beam.runners.core.StateNamespaceForTest;
 import org.apache.beam.runners.core.StateTag;
 import org.apache.beam.runners.core.StateTags;
-import org.apache.beam.runners.flink.translation.wrappers.streaming.FlinkStateInternals;
-import org.apache.beam.sdk.coders.CoderException;
+import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkStateInternals;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.coders.VarIntCoder;
 import org.apache.beam.sdk.transforms.Sum;
@@ -45,8 +44,13 @@ import org.apache.beam.sdk.util.state.ReadableState;
 import org.apache.beam.sdk.util.state.ValueState;
 import org.apache.beam.sdk.util.state.WatermarkHoldState;
 import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.java.typeutils.GenericTypeInfo;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
+import org.apache.flink.runtime.query.KvStateRegistry;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
+import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.hamcrest.Matchers;
 import org.joda.time.Instant;
@@ -88,18 +92,19 @@ public class FlinkStateInternalsTest {
   public void initStateInternals() {
     MemoryStateBackend backend = new MemoryStateBackend();
     try {
-      backend.initializeForJob(
+      AbstractKeyedStateBackend<ByteBuffer> keyedStateBackend = backend.createKeyedStateBackend(
           new DummyEnvironment("test", 1, 0),
+          new JobID(),
           "test_op",
-          new GenericTypeInfo<>(ByteBuffer.class).createSerializer(new ExecutionConfig()));
-    } catch (Exception e) {
-      throw new RuntimeException(e);
-    }
-    underTest = new FlinkStateInternals<>(backend, StringUtf8Coder.of());
-    try {
-      backend.setCurrentKey(
+          new GenericTypeInfo<>(ByteBuffer.class).createSerializer(new ExecutionConfig()),
+          1,
+          new KeyGroupRange(0, 0),
+          new KvStateRegistry().createTaskRegistry(new JobID(), new JobVertexID()));
+      underTest = new FlinkStateInternals<>(keyedStateBackend, StringUtf8Coder.of());
+
+      keyedStateBackend.setCurrentKey(
          ByteBuffer.wrap(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "Hello")));
-    } catch (CoderException e) {
+    } catch (Exception e) {
       throw new RuntimeException(e);
     }
   }
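
The substance of this hunk is a migration from the old two-step backend
bootstrap to Flink's keyed-backend factory, which takes the key-group count
and range explicitly. Condensed from the hunk above (env and serializer are
shorthand for the arguments spelled out in the diff):

    DummyEnvironment env = new DummyEnvironment("test", 1, 0);
    TypeSerializer<ByteBuffer> serializer =
        new GenericTypeInfo<>(ByteBuffer.class).createSerializer(new ExecutionConfig());

    // Before: initialize the backend in place, then set the key manually.
    //   backend.initializeForJob(env, "test_op", serializer);
    //   underTest = new FlinkStateInternals<>(backend, StringUtf8Coder.of());

    // After: the factory returns an AbstractKeyedStateBackend configured with
    // one key group covering range [0, 0] and a KvState registry.
    AbstractKeyedStateBackend<ByteBuffer> keyedStateBackend = backend.createKeyedStateBackend(
        env, new JobID(), "test_op", serializer,
        1, new KeyGroupRange(0, 0),
        new KvStateRegistry().createTaskRegistry(new JobID(), new JobVertexID()));
    underTest = new FlinkStateInternals<>(keyedStateBackend, StringUtf8Coder.of());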

http://git-wip-us.apache.org/repos/asf/beam/blob/7d32b93e/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/UnboundedSourceWrapperTest.java
----------------------------------------------------------------------
diff --git 
a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/UnboundedSourceWrapperTest.java
 
b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/UnboundedSourceWrapperTest.java
index b0be98b..5b3d088 100644
--- 
a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/UnboundedSourceWrapperTest.java
+++ 
b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/UnboundedSourceWrapperTest.java
@@ -44,8 +44,10 @@ import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.Output;
 import org.apache.flink.streaming.api.operators.StreamSource;
 import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
+import org.apache.flink.streaming.runtime.tasks.TestProcessingTimeService;
 import org.apache.flink.util.InstantiationUtil;
 import org.junit.Test;
 import org.junit.experimental.runners.Enclosed;
@@ -123,6 +125,10 @@ public class UnboundedSourceWrapperTest {
               }
 
               @Override
+              public void emitLatencyMarker(LatencyMarker latencyMarker) {
+              }
+
+              @Override
               public void collect(
                   StreamRecord<WindowedValue<KV<Integer, Integer>>> windowedValueStreamRecord) {
 
@@ -191,6 +197,10 @@ public class UnboundedSourceWrapperTest {
               }
 
               @Override
+              public void emitLatencyMarker(LatencyMarker latencyMarker) {
+              }
+
+              @Override
               public void collect(
                   StreamRecord<WindowedValue<KV<Integer, Integer>>> windowedValueStreamRecord) {
 
@@ -256,6 +266,10 @@ public class UnboundedSourceWrapperTest {
               }
 
               @Override
+              public void emitLatencyMarker(LatencyMarker latencyMarker) {
+              }
+
+              @Override
               public void collect(
                   StreamRecord<WindowedValue<KV<Integer, Integer>>> windowedValueStreamRecord) {
                 emittedElements.add(windowedValueStreamRecord.getValue().getValue());
@@ -300,6 +314,8 @@ public class UnboundedSourceWrapperTest {
       when(mockTask.getExecutionConfig()).thenReturn(executionConfig);
       when(mockTask.getAccumulatorMap())
           .thenReturn(Collections.<String, Accumulator<?, ?>>emptyMap());
+      TestProcessingTimeService testProcessingTimeService = new TestProcessingTimeService();
+      when(mockTask.getProcessingTimeService()).thenReturn(testProcessingTimeService);
 
      operator.setup(mockTask, cfg, (Output<StreamRecord<T>>) mock(Output.class));
     }
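
Two mechanical updates run through this file: each anonymous Output gains an
empty emitLatencyMarker override (the Output interface grew that method), and
the mocked StreamTask must now hand out a ProcessingTimeService. The new
wiring, isolated (mockTask and the Mockito calls are the ones already used in
the test; the comment reflects the intended role of the stub):

    // Stub the task's time service so the source wrapper can schedule
    // processing-time work; TestProcessingTimeService lets the test control
    // time deterministically instead of depending on the wall clock.
    TestProcessingTimeService testProcessingTimeService = new TestProcessingTimeService();
    when(mockTask.getProcessingTimeService()).thenReturn(testProcessingTimeService);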
