zhengbuqian commented on code in PR #23492:
URL: https://github.com/apache/beam/pull/23492#discussion_r1067508090


##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,522 @@ private Future<Iterable<Map.Entry<ByteString, V>>> 
getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, 
valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST and all 
values of this key is
+      // cached(both KeyState#values and localAdditions).

Review Comment:
   Done



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,522 @@ private Future<Iterable<Map.Entry<ByteString, V>>> 
getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, 
valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST and all 
values of this key is
+      // cached(both KeyState#values and localAdditions).
+      boolean valuesCached;
+      // represents the values in windmill. When new values are added, they 
are added to
+      // localAdditions but not KeyState#values. New values will be added to 
KeyState#values only
+      // after they are persisted into windmill and removed from 
localAdditions.
+      ConcatIterables<V> values;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = false;
+        values = new ConcatIterables<>();
+      }
+    }
+
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to 
provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    private boolean cleared = false;
+    // We use the structural value of the keys as the key in keyStateMap, so 
that different java
+    // Objects with the same content will be treated as the same Multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // If true, all keys are cached in keyStateMap with existence == 
KNOWN_EXIST.
+    private boolean allKeysKnown = false;
+
+    private boolean complete = false;
+    // All keys that have new values pending write to windmill.
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exist in both localRemovals 
and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            return v;
+          });
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> 
getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, 
stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> 
new KeyState(key));
+          if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) {
+            return Collections.emptyList();
+          }
+          if (allKeysKnown && keyState.existence == 
KeyExistence.UNKNOWN_EXISTENCE) {
+            keyStateMap.remove(structuralKey);
+            return Collections.emptyList();
+          }
+
+          if (localRemovals.contains(structuralKey)) {
+            // this key has been removed locally but the removal hasn't been 
sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care 
about local values.
+            return 
Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+          }
+          if (keyState.valuesCached || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(keyState.values, 
localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            final Iterable<V> persistedValues = persistedData.get();
+            if (Iterables.isEmpty(persistedValues)) {
+              Collection<V> local = localAdditions.get(structuralKey);
+              if (local.isEmpty()) {
+                // empty in both cache and windmill, remove key from cache.
+                keyStateMap.remove(structuralKey);
+                return Collections.emptyList();
+              }
+              return Iterables.unmodifiableIterable(local);
+            }
+            if (persistedValues instanceof Weighted) {
+              keyState.existence = KeyExistence.KNOWN_EXIST;
+              keyState.valuesCached = true;
+              ConcatIterables<V> it = new ConcatIterables<>();
+              it.extendWith(persistedValues);
+              keyState.values = it;
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, 
localAdditions.get(structuralKey)));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest 
persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && localAdditions.isEmpty() && localRemovals.isEmpty()) {
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = 
WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = null;
+      if (cleared) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setDeleteAll(true);
+        cleared = false;
+      }
+      Set<Object> keysWithUpdates = Sets.newHashSet();
+      keysWithUpdates.addAll(localRemovals);
+      keysWithUpdates.addAll(localAdditions.keySet());
+      if (!keysWithUpdates.isEmpty() && builder == null) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setTag(stateKey).setStateFamily(stateFamily);
+      }
+      for (Object structuralKey : keysWithUpdates) {
+        KeyState keyState = keyStateMap.get(structuralKey);
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();

Review Comment:
   Ah good to know, done.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,522 @@ private Future<Iterable<Map.Entry<ByteString, V>>> 
getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, 
valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST and all 
values of this key is
+      // cached(both KeyState#values and localAdditions).
+      boolean valuesCached;
+      // represents the values in windmill. When new values are added, they 
are added to
+      // localAdditions but not KeyState#values. New values will be added to 
KeyState#values only
+      // after they are persisted into windmill and removed from 
localAdditions.
+      ConcatIterables<V> values;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = false;
+        values = new ConcatIterables<>();
+      }
+    }
+
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to 
provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    private boolean cleared = false;
+    // We use the structural value of the keys as the key in keyStateMap, so 
that different java
+    // Objects with the same content will be treated as the same Multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // If true, all keys are cached in keyStateMap with existence == 
KNOWN_EXIST.
+    private boolean allKeysKnown = false;
+
+    private boolean complete = false;
+    // All keys that have new values pending write to windmill.
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exist in both localRemovals 
and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            return v;
+          });
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> 
getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, 
stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> 
new KeyState(key));
+          if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) {
+            return Collections.emptyList();
+          }
+          if (allKeysKnown && keyState.existence == 
KeyExistence.UNKNOWN_EXISTENCE) {
+            keyStateMap.remove(structuralKey);
+            return Collections.emptyList();
+          }
+
+          if (localRemovals.contains(structuralKey)) {
+            // this key has been removed locally but the removal hasn't been 
sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care 
about local values.
+            return 
Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+          }
+          if (keyState.valuesCached || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(keyState.values, 
localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            final Iterable<V> persistedValues = persistedData.get();
+            if (Iterables.isEmpty(persistedValues)) {
+              Collection<V> local = localAdditions.get(structuralKey);
+              if (local.isEmpty()) {
+                // empty in both cache and windmill, remove key from cache.
+                keyStateMap.remove(structuralKey);
+                return Collections.emptyList();
+              }
+              return Iterables.unmodifiableIterable(local);
+            }
+            if (persistedValues instanceof Weighted) {
+              keyState.existence = KeyExistence.KNOWN_EXIST;
+              keyState.valuesCached = true;
+              ConcatIterables<V> it = new ConcatIterables<>();
+              it.extendWith(persistedValues);
+              keyState.values = it;
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, 
localAdditions.get(structuralKey)));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest 
persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && localAdditions.isEmpty() && localRemovals.isEmpty()) {
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = 
WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = null;
+      if (cleared) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setDeleteAll(true);
+        cleared = false;
+      }
+      Set<Object> keysWithUpdates = Sets.newHashSet();
+      keysWithUpdates.addAll(localRemovals);
+      keysWithUpdates.addAll(localAdditions.keySet());
+      if (!keysWithUpdates.isEmpty() && builder == null) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setTag(stateKey).setStateFamily(stateFamily);
+      }
+      for (Object structuralKey : keysWithUpdates) {
+        KeyState keyState = keyStateMap.get(structuralKey);
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(keyState.originalKey, keyStream);
+        ByteString encodedKey = keyStream.toByteString();
+        Windmill.TagMultimapEntry.Builder entryBuilder = 
builder.addUpdatesBuilder();
+        entryBuilder.setEntryName(encodedKey);
+        entryBuilder.setDeleteAll(localRemovals.contains(structuralKey));
+        for (V value : localAdditions.get(structuralKey)) {
+          ByteStringOutputStream valueStream = new ByteStringOutputStream();
+          valueCoder.encode(value, valueStream);
+          ByteString encodedValue = valueStream.toByteString();
+          entryBuilder.addValues(encodedValue);
+        }
+        // Move newly added values from localAdditions to cachedEntries as 
those new values now are
+        // also persisted in Windmill. If a key now has no more values and is 
not KNOWN_EXIST,
+        // remove it from cache.
+        if (keyState.valuesCached) {
+          keyState.values.extendWith(localAdditions.get(structuralKey));
+        } else {
+          if (keyState.existence != KeyExistence.KNOWN_EXIST) 
keyStateMap.remove(structuralKey);
+        }
+      }
+
+      localRemovals = Sets.newHashSet();
+      localAdditions = ArrayListMultimap.create();
+
+      cache.put(namespace, address, this, 1);
+
+      return commitBuilder.buildPartial();
+    }
+
+    @Override
+    public void remove(K key) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new 
KeyState(key));
+      if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT
+          || (allKeysKnown && keyState.existence == 
KeyExistence.UNKNOWN_EXISTENCE)) {
+        return;
+      }
+      if (keyState.valuesCached || !complete) {
+        // there may be data in windmill that need to be removed.
+        localRemovals.add(structuralKey);
+        keyState.values = new ConcatIterables<>();
+        keyState.valuesCached = false;
+        keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+      } else {
+        // no data in windmill, deleting from local cache is sufficient.
+        keyStateMap.remove(structuralKey);
+      }
+      localAdditions.removeAll(structuralKey);
+    }
+
+    @Override
+    public void clear() {
+      keyStateMap = Maps.newHashMap();
+      localAdditions = ArrayListMultimap.create();
+      localRemovals = Sets.newHashSet();
+      cleared = true;
+      complete = true;
+      allKeysKnown = true;
+    }
+
+    @Override
+    public ReadableState<Iterable<K>> keys() {
+      return new ReadableState<Iterable<K>>() {
+        @Override
+        public Iterable<K> read() {
+          if (allKeysKnown) {
+            return Iterables.unmodifiableIterable(
+                Iterables.transform(
+                    Iterables.filter(
+                        keyStateMap.values(),
+                        keyState -> keyState.existence == 
KeyExistence.KNOWN_EXIST),
+                    keyState -> keyState.originalKey));
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = 
getFuture(true);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = 
persistedData.get();
+            Iterable<Entry<Object, K>> keys =
+                Iterables.transform(
+                    entries,
+                    entry -> {
+                      try {
+                        K originalKey = 
keyCoder.decode(entry.getKey().newInput());
+                        return new AbstractMap.SimpleEntry<>(
+                            keyCoder.structuralValue(originalKey), 
originalKey);
+                      } catch (IOException e) {
+                        throw new RuntimeException(e);
+                      }
+                    });
+            keys =
+                Iterables.filter(
+                    keys,
+                    entry -> {
+                      KeyState keyState = 
keyStateMap.getOrDefault(entry.getKey(), null);
+                      // this is a key that exists in windmill but is not 
cached.
+                      if (keyState == null) return true;
+                      // if the key is marked as deleted in cache, ignore it.
+                      return keyState.existence != 
KeyExistence.KNOWN_NONEXISTENT;
+                    });
+            if (entries instanceof Weighted) {
+              // This is a known amount of data, cache them all.
+              keys.forEach(
+                  entry -> {
+                    KeyState keyState =
+                        keyStateMap.computeIfAbsent(
+                            entry.getKey(), stk -> new 
KeyState(entry.getValue()));
+                    keyState.existence = KeyExistence.KNOWN_EXIST;
+                  });
+              allKeysKnown = true;
+              keyStateMap
+                  .values()
+                  .removeIf(keyState -> keyState.existence != 
KeyExistence.KNOWN_EXIST);
+              return Iterables.unmodifiableIterable(
+                  Iterables.transform(keyStateMap.values(), keyState -> 
keyState.originalKey));
+            } else {
+              return Iterables.unmodifiableIterable(
+                  Iterables.concat(
+                      // This is the part of keys that are cached.

Review Comment:
   We don't provide ordering guarantees on either the keys or the values.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,522 @@ private Future<Iterable<Map.Entry<ByteString, V>>> 
getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, 
valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST and all 
values of this key is
+      // cached(both KeyState#values and localAdditions).
+      boolean valuesCached;
+      // represents the values in windmill. When new values are added, they 
are added to
+      // localAdditions but not KeyState#values. New values will be added to 
KeyState#values only
+      // after they are persisted into windmill and removed from 
localAdditions.
+      ConcatIterables<V> values;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = false;
+        values = new ConcatIterables<>();
+      }
+    }
+
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to 
provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    private boolean cleared = false;
+    // We use the structural value of the keys as the key in keyStateMap, so 
that different java
+    // Objects with the same content will be treated as the same Multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // If true, all keys are cached in keyStateMap with existence == 
KNOWN_EXIST.
+    private boolean allKeysKnown = false;
+
+    private boolean complete = false;
+    // All keys that have new values pending write to windmill.
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exist in both localRemovals 
and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            return v;
+          });
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> 
getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, 
stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> 
new KeyState(key));
+          if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) {
+            return Collections.emptyList();
+          }
+          if (allKeysKnown && keyState.existence == 
KeyExistence.UNKNOWN_EXISTENCE) {
+            keyStateMap.remove(structuralKey);
+            return Collections.emptyList();
+          }
+
+          if (localRemovals.contains(structuralKey)) {
+            // this key has been removed locally but the removal hasn't been 
sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care 
about local values.
+            return 
Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+          }
+          if (keyState.valuesCached || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(keyState.values, 
localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            final Iterable<V> persistedValues = persistedData.get();
+            if (Iterables.isEmpty(persistedValues)) {
+              Collection<V> local = localAdditions.get(structuralKey);
+              if (local.isEmpty()) {
+                // empty in both cache and windmill, remove key from cache.
+                keyStateMap.remove(structuralKey);
+                return Collections.emptyList();
+              }
+              return Iterables.unmodifiableIterable(local);
+            }
+            if (persistedValues instanceof Weighted) {
+              keyState.existence = KeyExistence.KNOWN_EXIST;
+              keyState.valuesCached = true;
+              ConcatIterables<V> it = new ConcatIterables<>();
+              it.extendWith(persistedValues);
+              keyState.values = it;
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, 
localAdditions.get(structuralKey)));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest 
persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && localAdditions.isEmpty() && localRemovals.isEmpty()) {
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = 
WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = null;
+      if (cleared) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setDeleteAll(true);
+        cleared = false;
+      }
+      Set<Object> keysWithUpdates = Sets.newHashSet();
+      keysWithUpdates.addAll(localRemovals);
+      keysWithUpdates.addAll(localAdditions.keySet());
+      if (!keysWithUpdates.isEmpty() && builder == null) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setTag(stateKey).setStateFamily(stateFamily);
+      }
+      for (Object structuralKey : keysWithUpdates) {
+        KeyState keyState = keyStateMap.get(structuralKey);
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(keyState.originalKey, keyStream);
+        ByteString encodedKey = keyStream.toByteString();
+        Windmill.TagMultimapEntry.Builder entryBuilder = 
builder.addUpdatesBuilder();
+        entryBuilder.setEntryName(encodedKey);
+        entryBuilder.setDeleteAll(localRemovals.contains(structuralKey));
+        for (V value : localAdditions.get(structuralKey)) {
+          ByteStringOutputStream valueStream = new ByteStringOutputStream();
+          valueCoder.encode(value, valueStream);
+          ByteString encodedValue = valueStream.toByteString();
+          entryBuilder.addValues(encodedValue);
+        }
+        // Move newly added values from localAdditions to cachedEntries as 
those new values now are
+        // also persisted in Windmill. If a key now has no more values and is 
not KNOWN_EXIST,
+        // remove it from cache.
+        if (keyState.valuesCached) {
+          keyState.values.extendWith(localAdditions.get(structuralKey));
+        } else {
+          if (keyState.existence != KeyExistence.KNOWN_EXIST) 
keyStateMap.remove(structuralKey);
+        }
+      }
+
+      localRemovals = Sets.newHashSet();
+      localAdditions = ArrayListMultimap.create();
+
+      cache.put(namespace, address, this, 1);
+
+      return commitBuilder.buildPartial();
+    }
+
+    @Override
+    public void remove(K key) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new 
KeyState(key));
+      if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT
+          || (allKeysKnown && keyState.existence == 
KeyExistence.UNKNOWN_EXISTENCE)) {
+        return;
+      }
+      if (keyState.valuesCached || !complete) {
+        // there may be data in windmill that need to be removed.
+        localRemovals.add(structuralKey);
+        keyState.values = new ConcatIterables<>();
+        keyState.valuesCached = false;
+        keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+      } else {
+        // no data in windmill, deleting from local cache is sufficient.
+        keyStateMap.remove(structuralKey);
+      }
+      localAdditions.removeAll(structuralKey);
+    }
+
+    @Override
+    public void clear() {
+      keyStateMap = Maps.newHashMap();
+      localAdditions = ArrayListMultimap.create();
+      localRemovals = Sets.newHashSet();
+      cleared = true;
+      complete = true;
+      allKeysKnown = true;
+    }
+
+    @Override
+    public ReadableState<Iterable<K>> keys() {
+      return new ReadableState<Iterable<K>>() {
+        @Override
+        public Iterable<K> read() {
+          if (allKeysKnown) {
+            return Iterables.unmodifiableIterable(
+                Iterables.transform(
+                    Iterables.filter(
+                        keyStateMap.values(),
+                        keyState -> keyState.existence == 
KeyExistence.KNOWN_EXIST),
+                    keyState -> keyState.originalKey));
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = 
getFuture(true);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = 
persistedData.get();
+            Iterable<Entry<Object, K>> keys =
+                Iterables.transform(
+                    entries,
+                    entry -> {
+                      try {
+                        K originalKey = 
keyCoder.decode(entry.getKey().newInput());
+                        return new AbstractMap.SimpleEntry<>(
+                            keyCoder.structuralValue(originalKey), 
originalKey);
+                      } catch (IOException e) {
+                        throw new RuntimeException(e);
+                      }
+                    });
+            keys =
+                Iterables.filter(
+                    keys,
+                    entry -> {
+                      KeyState keyState = 
keyStateMap.getOrDefault(entry.getKey(), null);
+                      // this is a key that exists in windmill but is not 
cached.
+                      if (keyState == null) return true;
+                      // if the key is marked as deleted in cache, ignore it.
+                      return keyState.existence != 
KeyExistence.KNOWN_NONEXISTENT;
+                    });
+            if (entries instanceof Weighted) {
+              // This is a known amount of data, cache them all.
+              keys.forEach(
+                  entry -> {
+                    KeyState keyState =
+                        keyStateMap.computeIfAbsent(
+                            entry.getKey(), stk -> new 
KeyState(entry.getValue()));
+                    keyState.existence = KeyExistence.KNOWN_EXIST;
+                  });
+              allKeysKnown = true;
+              keyStateMap
+                  .values()
+                  .removeIf(keyState -> keyState.existence != 
KeyExistence.KNOWN_EXIST);
+              return Iterables.unmodifiableIterable(
+                  Iterables.transform(keyStateMap.values(), keyState -> 
keyState.originalKey));
+            } else {
+              return Iterables.unmodifiableIterable(
+                  Iterables.concat(
+                      // This is the part of keys that are cached.
+                      Iterables.transform(
+                          Iterables.filter(
+                              keyStateMap.values(),
+                              keyState -> keyState.existence == 
KeyExistence.KNOWN_EXIST),
+                          keyState -> keyState.originalKey),
+                      // This is the part of the keys returned from Windmill 
that are not cached.
+                      Iterables.transform(
+                          Iterables.filter(keys, e -> 
!keyStateMap.containsKey(e.getKey())),
+                          entry -> entry.getValue())));
+            }
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<K>> readLater() {
+          WindmillMultimap.this.getFuture(true);
+          return this;
+        }
+      };
+    }
+
+    private MultimapIterables<K, V> mergedCachedEntries() {
+      MultimapIterables<K, V> result = new MultimapIterables<>();
+      for (Entry<Object, Collection<V>> entry : 
localAdditions.asMap().entrySet()) {
+        K key = keyStateMap.get(entry.getKey()).originalKey;
+        result.extendWith(key, entry.getValue());
+      }
+      for (Entry<Object, KeyState> entry : keyStateMap.entrySet()) {
+        if (entry.getValue().valuesCached) {
+          result.extendWith(entry.getValue().originalKey, 
entry.getValue().values);
+        }
+      }
+      return result;
+    }
+
+    private static class MultimapIterables<K, V> implements Iterable<Entry<K, 
V>> {
+      Map<K, ConcatIterables<V>> map;
+
+      public MultimapIterables() {
+        this.map = new HashMap<>();
+      }
+
+      public void extendWith(K key, Iterable<V> iterable) {
+        map.compute(
+            key,
+            (k, v) -> {
+              if (v == null) v = new ConcatIterables<>();
+              v.extendWith(iterable);
+              return v;
+            });
+      }
+
+      @Override
+      public Iterator<Entry<K, V>> iterator() {
+        return Iterators.concat(
+            Iterables.transform(
+                    map.entrySet(),
+                    entry ->
+                        Iterables.transform(
+                                entry.getValue(),
+                                v -> new 
AbstractMap.SimpleEntry<>(entry.getKey(), v))
+                            .iterator())
+                .iterator());
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<Entry<K, V>>> entries() {
+      return new ReadableState<Iterable<Entry<K, V>>>() {
+        @Override
+        public Iterable<Entry<K, V>> read() {
+          if (complete) {
+            return Iterables.unmodifiableIterable(mergedCachedEntries());
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = 
getFuture(false);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = 
persistedData.get();
+            Map<Object, ConcatIterables<V>> entryMap = Maps.newHashMap();
+            entries.forEach(
+                entry -> {
+                  try {
+                    final K key = keyCoder.decode(entry.getKey().newInput());
+                    final Object structuralKey = keyCoder.structuralValue(key);
+                    KeyState keyState =
+                        keyStateMap.computeIfAbsent(structuralKey, k -> new 
KeyState(key));
+                    if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) 
return;
+                    entryMap.compute(
+                        structuralKey,
+                        (k, v) -> {
+                          if (v == null) v = new ConcatIterables<>();

Review Comment:
   > do we expect the same structure key to be returned by persistent state?
   
   Yes. 
   
   > should it be an error?
   
   It is expected and not an error. When a single entry is paginated, there will 
be multiple `Entry<ByteString, Iterable<V>>` pairs for the same key in the 
fetched result; see test `testMultimapEntriesPaginated`. 



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,522 @@ private Future<Iterable<Map.Entry<ByteString, V>>> 
getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, 
valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST and all 
values of this key is
+      // cached(both KeyState#values and localAdditions).
+      boolean valuesCached;
+      // represents the values in windmill. When new values are added, they 
are added to

Review Comment:
   Done



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,472 @@ private Future<Iterable<Map.Entry<ByteString, V>>> 
getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, 
valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private boolean cleared = false;
+    /**
+     * For any given key, if it's contained in {@link #cachedEntries}, then 
the complete content of
+     * this key is cached: persisted values of this key in backing store are 
cached in
+     * cachedEntries, newly added values not yet written to backing store are 
cached in
+     * localAdditions. If a key is not contained in {@link #cachedEntries} 
then we don't know if
+     * Windmill contains additional values which also maps to this key, we'll 
need to read them if
+     * the work item actually wants the content.
+     */
+    private Map<Object, Iterable<V>> cachedEntries = Maps.newHashMap();
+    // Any key presents in existKeyCache is known to exist in the multimap.
+    private Set<Object> existKeyCache = Sets.newHashSet();
+    // If true, any key not in existKeyCache is known to be nonexistent.
+    private boolean allKeysKnown = false;
+    // Any key presents in nonexistentKeyCache is known to be nonexistent in 
the multimap.
+    private Set<Object> nonexistentKeyCache = Sets.newHashSet();
+
+    private boolean complete = false;
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exist in both localRemovals 
and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+    // structuralKeyMapping maps from the structuralKeys to the actual keys. 
Any key in
+    // cachedEntries, existKeyCache, nonexistentKeyCache, localAdditions and 
localRemovals should be
+    // included in this mapping.
+    private Map<Object, K> structuralKeyMapping = Maps.newHashMap();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      existKeyCache.add(structuralKey);
+      nonexistentKeyCache.remove(structuralKey);
+      structuralKeyMapping.put(structuralKey, key);
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> 
getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, 
stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          if (nonexistentKeyCache.contains(structuralKey)
+              || (allKeysKnown && !existKeyCache.contains(structuralKey))) {
+            return Collections.emptyList();
+          }
+          if (localRemovals.contains(structuralKey)) {
+            // this key has been removed locally but the removal hasn't been 
sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care 
about local values.
+            if (localAdditions.containsKey(structuralKey)) {
+              return 
Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            } else {
+              return Collections.emptyList();
+            }
+          }
+          if (cachedEntries.containsKey(structuralKey) || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(
+                    cachedEntries.getOrDefault(structuralKey, 
Collections.emptyList()),
+                    localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<V> persistedValues = persistedData.get();
+            if (Iterables.size(persistedValues) == 0) {
+              if (!localAdditions.containsKey(structuralKey)) {
+                nonexistentKeyCache.add(structuralKey);
+                Preconditions.checkState(
+                    !existKeyCache.contains(structuralKey),
+                    "Key "
+                        + key
+                        + " exists"
+                        + " in existKeyCache but no value in neither windmill 
nor local additions.");
+              }
+              return 
Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            }
+            if (persistedValues instanceof Weighted) {
+              cachedEntries.put(structuralKey, new ConcatIterables<>());
+              ((ConcatIterables<V>) 
cachedEntries.get(structuralKey)).extendWith(persistedValues);
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, 
localAdditions.get(structuralKey)));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest 
persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && localAdditions.isEmpty() && localRemovals.isEmpty()) {
+        return WorkItemCommitRequest.newBuilder().buildPartial();

Review Comment:
   Done



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateReader.java:
##########
@@ -447,6 +493,12 @@ public Iterable<ResultT> apply(
           contStateTag =

Review Comment:
   Done



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,522 @@ private Future<Iterable<Map.Entry<ByteString, V>>> 
getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, 
valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST and all 
values of this key is
+      // cached(both KeyState#values and localAdditions).
+      boolean valuesCached;
+      // represents the values in windmill. When new values are added, they 
are added to
+      // localAdditions but not KeyState#values. New values will be added to 
KeyState#values only
+      // after they are persisted into windmill and removed from 
localAdditions.
+      ConcatIterables<V> values;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = false;
+        values = new ConcatIterables<>();
+      }
+    }
+
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to 
provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    private boolean cleared = false;
+    // We use the structural value of the keys as the key in keyStateMap, so 
that different java
+    // Objects with the same content will be treated as the same Multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // If true, all keys are cached in keyStateMap with existence == 
KNOWN_EXIST.
+    private boolean allKeysKnown = false;
+
+    private boolean complete = false;
+    // All keys that have new values pending write to windmill.
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exist in both localRemovals 
and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            return v;
+          });
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> 
getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, 
stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> 
new KeyState(key));
+          if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) {
+            return Collections.emptyList();
+          }
+          if (allKeysKnown && keyState.existence == 
KeyExistence.UNKNOWN_EXISTENCE) {
+            keyStateMap.remove(structuralKey);
+            return Collections.emptyList();
+          }
+
+          if (localRemovals.contains(structuralKey)) {
+            // this key has been removed locally but the removal hasn't been 
sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care 
about local values.
+            return 
Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+          }
+          if (keyState.valuesCached || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(keyState.values, 
localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            final Iterable<V> persistedValues = persistedData.get();
+            if (Iterables.isEmpty(persistedValues)) {
+              Collection<V> local = localAdditions.get(structuralKey);
+              if (local.isEmpty()) {
+                // empty in both cache and windmill, remove key from cache.
+                keyStateMap.remove(structuralKey);
+                return Collections.emptyList();
+              }
+              return Iterables.unmodifiableIterable(local);
+            }
+            if (persistedValues instanceof Weighted) {
+              keyState.existence = KeyExistence.KNOWN_EXIST;
+              keyState.valuesCached = true;
+              ConcatIterables<V> it = new ConcatIterables<>();
+              it.extendWith(persistedValues);
+              keyState.values = it;
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, 
localAdditions.get(structuralKey)));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest 
persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && localAdditions.isEmpty() && localRemovals.isEmpty()) {
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = 
WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = null;
+      if (cleared) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setDeleteAll(true);
+        cleared = false;
+      }
+      Set<Object> keysWithUpdates = Sets.newHashSet();
+      keysWithUpdates.addAll(localRemovals);
+      keysWithUpdates.addAll(localAdditions.keySet());
+      if (!keysWithUpdates.isEmpty() && builder == null) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setTag(stateKey).setStateFamily(stateFamily);
+      }
+      for (Object structuralKey : keysWithUpdates) {
+        KeyState keyState = keyStateMap.get(structuralKey);
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(keyState.originalKey, keyStream);
+        ByteString encodedKey = keyStream.toByteString();
+        Windmill.TagMultimapEntry.Builder entryBuilder = 
builder.addUpdatesBuilder();
+        entryBuilder.setEntryName(encodedKey);
+        entryBuilder.setDeleteAll(localRemovals.contains(structuralKey));
+        for (V value : localAdditions.get(structuralKey)) {
+          ByteStringOutputStream valueStream = new ByteStringOutputStream();
+          valueCoder.encode(value, valueStream);
+          ByteString encodedValue = valueStream.toByteString();
+          entryBuilder.addValues(encodedValue);
+        }
+        // Move newly added values from localAdditions to cachedEntries as 
those new values now are
+        // also persisted in Windmill. If a key now has no more values and is 
not KNOWN_EXIST,
+        // remove it from cache.
+        if (keyState.valuesCached) {
+          keyState.values.extendWith(localAdditions.get(structuralKey));
+        } else {
+          if (keyState.existence != KeyExistence.KNOWN_EXIST) 
keyStateMap.remove(structuralKey);
+        }
+      }
+
+      localRemovals = Sets.newHashSet();
+      localAdditions = ArrayListMultimap.create();
+
+      cache.put(namespace, address, this, 1);
+
+      return commitBuilder.buildPartial();
+    }
+
+    @Override
+    public void remove(K key) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new 
KeyState(key));
+      if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT
+          || (allKeysKnown && keyState.existence == 
KeyExistence.UNKNOWN_EXISTENCE)) {
+        return;
+      }
+      if (keyState.valuesCached || !complete) {
+        // there may be data in windmill that need to be removed.
+        localRemovals.add(structuralKey);
+        keyState.values = new ConcatIterables<>();
+        keyState.valuesCached = false;
+        keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+      } else {
+        // no data in windmill, deleting from local cache is sufficient.
+        keyStateMap.remove(structuralKey);
+      }
+      localAdditions.removeAll(structuralKey);
+    }
+
+    @Override
+    public void clear() {
+      keyStateMap = Maps.newHashMap();
+      localAdditions = ArrayListMultimap.create();
+      localRemovals = Sets.newHashSet();
+      cleared = true;
+      complete = true;
+      allKeysKnown = true;
+    }
+
+    @Override
+    public ReadableState<Iterable<K>> keys() {
+      return new ReadableState<Iterable<K>>() {
+        @Override
+        public Iterable<K> read() {
+          if (allKeysKnown) {
+            return Iterables.unmodifiableIterable(
+                Iterables.transform(
+                    Iterables.filter(
+                        keyStateMap.values(),
+                        keyState -> keyState.existence == 
KeyExistence.KNOWN_EXIST),
+                    keyState -> keyState.originalKey));
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = 
getFuture(true);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = 
persistedData.get();
+            Iterable<Entry<Object, K>> keys =
+                Iterables.transform(
+                    entries,
+                    entry -> {
+                      try {
+                        K originalKey = 
keyCoder.decode(entry.getKey().newInput());
+                        return new AbstractMap.SimpleEntry<>(
+                            keyCoder.structuralValue(originalKey), 
originalKey);
+                      } catch (IOException e) {
+                        throw new RuntimeException(e);
+                      }
+                    });
+            keys =
+                Iterables.filter(
+                    keys,
+                    entry -> {
+                      KeyState keyState = 
keyStateMap.getOrDefault(entry.getKey(), null);
+                      // this is a key that exists in windmill but is not 
cached.
+                      if (keyState == null) return true;
+                      // if the key is marked as deleted in cache, ignore it.
+                      return keyState.existence != 
KeyExistence.KNOWN_NONEXISTENT;
+                    });
+            if (entries instanceof Weighted) {
+              // This is a known amount of data, cache them all.
+              keys.forEach(
+                  entry -> {
+                    KeyState keyState =
+                        keyStateMap.computeIfAbsent(
+                            entry.getKey(), stk -> new 
KeyState(entry.getValue()));
+                    keyState.existence = KeyExistence.KNOWN_EXIST;
+                  });
+              allKeysKnown = true;
+              keyStateMap
+                  .values()
+                  .removeIf(keyState -> keyState.existence != 
KeyExistence.KNOWN_EXIST);
+              return Iterables.unmodifiableIterable(
+                  Iterables.transform(keyStateMap.values(), keyState -> 
keyState.originalKey));
+            } else {
+              return Iterables.unmodifiableIterable(
+                  Iterables.concat(
+                      // This is the part of keys that are cached.
+                      Iterables.transform(
+                          Iterables.filter(
+                              keyStateMap.values(),
+                              keyState -> keyState.existence == 
KeyExistence.KNOWN_EXIST),
+                          keyState -> keyState.originalKey),
+                      // This is the part of the keys returned from Windmill 
that are not cached.
+                      Iterables.transform(
+                          Iterables.filter(keys, e -> 
!keyStateMap.containsKey(e.getKey())),
+                          entry -> entry.getValue())));
+            }
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<K>> readLater() {
+          WindmillMultimap.this.getFuture(true);
+          return this;
+        }
+      };
+    }
+
+    private MultimapIterables<K, V> mergedCachedEntries() {
+      MultimapIterables<K, V> result = new MultimapIterables<>();
+      for (Entry<Object, Collection<V>> entry : 
localAdditions.asMap().entrySet()) {

Review Comment:
   Done. I should have thought of this.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,522 @@ private Future<Iterable<Map.Entry<ByteString, V>>> 
getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, 
valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST and all 
values of this key is
+      // cached(both KeyState#values and localAdditions).
+      boolean valuesCached;
+      // represents the values in windmill. When new values are added, they 
are added to
+      // localAdditions but not KeyState#values. New values will be added to 
KeyState#values only
+      // after they are persisted into windmill and removed from 
localAdditions.
+      ConcatIterables<V> values;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = false;
+        values = new ConcatIterables<>();
+      }
+    }
+
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to 
provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    private boolean cleared = false;
+    // We use the structural value of the keys as the key in keyStateMap, so 
that different java
+    // Objects with the same content will be treated as the same Multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // If true, all keys are cached in keyStateMap with existence == 
KNOWN_EXIST.
+    private boolean allKeysKnown = false;
+
+    private boolean complete = false;
+    // All keys that have new values pending write to windmill.
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exist in both localRemovals 
and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            return v;
+          });
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> 
getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, 
stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> 
new KeyState(key));
+          if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) {
+            return Collections.emptyList();
+          }
+          if (allKeysKnown && keyState.existence == 
KeyExistence.UNKNOWN_EXISTENCE) {
+            keyStateMap.remove(structuralKey);
+            return Collections.emptyList();
+          }
+
+          if (localRemovals.contains(structuralKey)) {
+            // this key has been removed locally but the removal hasn't been 
sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care 
about local values.
+            return 
Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+          }
+          if (keyState.valuesCached || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(keyState.values, 
localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            final Iterable<V> persistedValues = persistedData.get();
+            if (Iterables.isEmpty(persistedValues)) {
+              Collection<V> local = localAdditions.get(structuralKey);
+              if (local.isEmpty()) {
+                // empty in both cache and windmill, remove key from cache.
+                keyStateMap.remove(structuralKey);
+                return Collections.emptyList();
+              }
+              return Iterables.unmodifiableIterable(local);
+            }
+            if (persistedValues instanceof Weighted) {
+              keyState.existence = KeyExistence.KNOWN_EXIST;
+              keyState.valuesCached = true;
+              ConcatIterables<V> it = new ConcatIterables<>();
+              it.extendWith(persistedValues);
+              keyState.values = it;
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, 
localAdditions.get(structuralKey)));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest 
persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && localAdditions.isEmpty() && localRemovals.isEmpty()) {
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = 
WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = null;
+      if (cleared) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setDeleteAll(true);
+        cleared = false;
+      }
+      Set<Object> keysWithUpdates = Sets.newHashSet();
+      keysWithUpdates.addAll(localRemovals);
+      keysWithUpdates.addAll(localAdditions.keySet());
+      if (!keysWithUpdates.isEmpty() && builder == null) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setTag(stateKey).setStateFamily(stateFamily);
+      }
+      for (Object structuralKey : keysWithUpdates) {
+        KeyState keyState = keyStateMap.get(structuralKey);
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(keyState.originalKey, keyStream);
+        ByteString encodedKey = keyStream.toByteString();
+        Windmill.TagMultimapEntry.Builder entryBuilder = 
builder.addUpdatesBuilder();
+        entryBuilder.setEntryName(encodedKey);
+        entryBuilder.setDeleteAll(localRemovals.contains(structuralKey));
+        for (V value : localAdditions.get(structuralKey)) {
+          ByteStringOutputStream valueStream = new ByteStringOutputStream();
+          valueCoder.encode(value, valueStream);
+          ByteString encodedValue = valueStream.toByteString();
+          entryBuilder.addValues(encodedValue);
+        }
+        // Move newly added values from localAdditions to cachedEntries as 
those new values now are
+        // also persisted in Windmill. If a key now has no more values and is 
not KNOWN_EXIST,
+        // remove it from cache.
+        if (keyState.valuesCached) {
+          keyState.values.extendWith(localAdditions.get(structuralKey));
+        } else {
+          if (keyState.existence != KeyExistence.KNOWN_EXIST) 
keyStateMap.remove(structuralKey);
+        }
+      }
+
+      localRemovals = Sets.newHashSet();
+      localAdditions = ArrayListMultimap.create();
+
+      cache.put(namespace, address, this, 1);
+
+      return commitBuilder.buildPartial();
+    }
+
+    @Override
+    public void remove(K key) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new 
KeyState(key));
+      if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT
+          || (allKeysKnown && keyState.existence == 
KeyExistence.UNKNOWN_EXISTENCE)) {
+        return;
+      }
+      if (keyState.valuesCached || !complete) {
+        // there may be data in windmill that need to be removed.
+        localRemovals.add(structuralKey);
+        keyState.values = new ConcatIterables<>();
+        keyState.valuesCached = false;
+        keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+      } else {
+        // no data in windmill, deleting from local cache is sufficient.
+        keyStateMap.remove(structuralKey);
+      }
+      localAdditions.removeAll(structuralKey);
+    }
+
+    @Override
+    public void clear() {
+      keyStateMap = Maps.newHashMap();
+      localAdditions = ArrayListMultimap.create();
+      localRemovals = Sets.newHashSet();
+      cleared = true;
+      complete = true;
+      allKeysKnown = true;
+    }
+
+    @Override
+    public ReadableState<Iterable<K>> keys() {
+      return new ReadableState<Iterable<K>>() {
+        @Override
+        public Iterable<K> read() {
+          if (allKeysKnown) {
+            return Iterables.unmodifiableIterable(
+                Iterables.transform(
+                    Iterables.filter(
+                        keyStateMap.values(),
+                        keyState -> keyState.existence == 
KeyExistence.KNOWN_EXIST),
+                    keyState -> keyState.originalKey));
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = 
getFuture(true);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = 
persistedData.get();
+            Iterable<Entry<Object, K>> keys =
+                Iterables.transform(
+                    entries,
+                    entry -> {
+                      try {
+                        K originalKey = 
keyCoder.decode(entry.getKey().newInput());
+                        return new AbstractMap.SimpleEntry<>(
+                            keyCoder.structuralValue(originalKey), 
originalKey);
+                      } catch (IOException e) {
+                        throw new RuntimeException(e);
+                      }
+                    });
+            keys =
+                Iterables.filter(
+                    keys,
+                    entry -> {
+                      KeyState keyState = 
keyStateMap.getOrDefault(entry.getKey(), null);
+                      // this is a key that exists in windmill but is not 
cached.
+                      if (keyState == null) return true;
+                      // if the key is marked as deleted in cache, ignore it.
+                      return keyState.existence != 
KeyExistence.KNOWN_NONEXISTENT;
+                    });
+            if (entries instanceof Weighted) {
+              // This is a known amount of data, cache them all.
+              keys.forEach(
+                  entry -> {
+                    KeyState keyState =
+                        keyStateMap.computeIfAbsent(
+                            entry.getKey(), stk -> new 
KeyState(entry.getValue()));
+                    keyState.existence = KeyExistence.KNOWN_EXIST;
+                  });
+              allKeysKnown = true;
+              keyStateMap
+                  .values()
+                  .removeIf(keyState -> keyState.existence != 
KeyExistence.KNOWN_EXIST);
+              return Iterables.unmodifiableIterable(
+                  Iterables.transform(keyStateMap.values(), keyState -> 
keyState.originalKey));
+            } else {
+              return Iterables.unmodifiableIterable(
+                  Iterables.concat(
+                      // This is the part of keys that are cached.
+                      Iterables.transform(
+                          Iterables.filter(
+                              keyStateMap.values(),
+                              keyState -> keyState.existence == 
KeyExistence.KNOWN_EXIST),
+                          keyState -> keyState.originalKey),
+                      // This is the part of the keys returned from Windmill 
that are not cached.
+                      Iterables.transform(
+                          Iterables.filter(keys, e -> 
!keyStateMap.containsKey(e.getKey())),
+                          entry -> entry.getValue())));
+            }
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<K>> readLater() {
+          WindmillMultimap.this.getFuture(true);
+          return this;
+        }
+      };
+    }
+
+    private MultimapIterables<K, V> mergedCachedEntries() {
+      MultimapIterables<K, V> result = new MultimapIterables<>();
+      for (Entry<Object, Collection<V>> entry : 
localAdditions.asMap().entrySet()) {
+        K key = keyStateMap.get(entry.getKey()).originalKey;
+        result.extendWith(key, entry.getValue());
+      }
+      for (Entry<Object, KeyState> entry : keyStateMap.entrySet()) {
+        if (entry.getValue().valuesCached) {
+          result.extendWith(entry.getValue().originalKey, 
entry.getValue().values);
+        }
+      }
+      return result;
+    }
+
+    private static class MultimapIterables<K, V> implements Iterable<Entry<K, 
V>> {
+      Map<K, ConcatIterables<V>> map;
+
+      public MultimapIterables() {
+        this.map = new HashMap<>();
+      }
+
+      public void extendWith(K key, Iterable<V> iterable) {
+        map.compute(
+            key,
+            (k, v) -> {
+              if (v == null) v = new ConcatIterables<>();
+              v.extendWith(iterable);
+              return v;
+            });
+      }
+
+      @Override
+      public Iterator<Entry<K, V>> iterator() {
+        return Iterators.concat(
+            Iterables.transform(
+                    map.entrySet(),
+                    entry ->
+                        Iterables.transform(
+                                entry.getValue(),
+                                v -> new 
AbstractMap.SimpleEntry<>(entry.getKey(), v))
+                            .iterator())
+                .iterator());
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<Entry<K, V>>> entries() {
+      return new ReadableState<Iterable<Entry<K, V>>>() {
+        @Override
+        public Iterable<Entry<K, V>> read() {
+          if (complete) {
+            return Iterables.unmodifiableIterable(mergedCachedEntries());
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = 
getFuture(false);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = 
persistedData.get();
+            Map<Object, ConcatIterables<V>> entryMap = Maps.newHashMap();
+            entries.forEach(
+                entry -> {
+                  try {
+                    final K key = keyCoder.decode(entry.getKey().newInput());
+                    final Object structuralKey = keyCoder.structuralValue(key);
+                    KeyState keyState =
+                        keyStateMap.computeIfAbsent(structuralKey, k -> new 
KeyState(key));
+                    if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) 
return;
+                    entryMap.compute(
+                        structuralKey,
+                        (k, v) -> {
+                          if (v == null) v = new ConcatIterables<>();
+                          v.extendWith(entry.getValue());
+                          keyState.existence = KeyExistence.KNOWN_EXIST;
+                          return v;
+                        });
+                  } catch (IOException e) {
+                    throw new RuntimeException(e);
+                  }
+                });
+            if (entries instanceof Weighted) {
+              // This is a known amount of data, cache them all.
+              entryMap.forEach(
+                  (structuralKey, values) -> {
+                    KeyState keyState = keyStateMap.get(structuralKey);
+                    if (!keyState.valuesCached) {
+                      keyState.values.extendWith(values);
+                      keyState.valuesCached = true;
+                    }
+                  });
+              allKeysKnown = true;
+              complete = true;
+              keyStateMap
+                  .entrySet()
+                  .removeIf(
+                      entry ->
+                          entry.getValue().existence == 
KeyExistence.KNOWN_NONEXISTENT
+                              && !localRemovals.contains(entry.getKey()));
+              return Iterables.unmodifiableIterable(mergedCachedEntries());
+            } else {
+              MultimapIterables<K, V> local = mergedCachedEntries();
+              entryMap.forEach(
+                  (structuralKey, values) -> {
+                    KeyState keyState = keyStateMap.get(structuralKey);
+                    if (!keyState.valuesCached) {
+                      local.extendWith(keyState.originalKey, values);
+                    }
+                  });
+              return Iterables.unmodifiableIterable(local);
+            }
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<Entry<K, V>>> readLater() {
+          WindmillMultimap.this.getFuture(false);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    public ReadableState<Boolean> containsKey(K key) {
+      return new ReadableState<Boolean>() {
+        ReadableState<Iterable<V>> values = null;
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Boolean read() {
+          KeyState keyState = keyStateMap.getOrDefault(structuralKey, null);
+          if (keyState != null && keyState.existence != 
KeyExistence.UNKNOWN_EXISTENCE) {
+            return keyState.existence == KeyExistence.KNOWN_EXIST;
+          }
+          if (values == null) {
+            values = WindmillMultimap.this.get(key);
+          }
+          return !Iterables.isEmpty(values.read());
+        }
+
+        @Override
+        public ReadableState<Boolean> readLater() {
+          if (values == null) {
+            values = WindmillMultimap.this.get(key);
+          }
+          values.readLater();
+          return this;
+        }
+      };
+    }
+
+    @Override
+    public ReadableState<Boolean> isEmpty() {
+      return new ReadableState<Boolean>() {
+        ReadableState<Iterable<K>> keys = null;
+
+        @Override
+        public Boolean read() {
+          for (KeyState keyState : keyStateMap.values()) {
+            if (keyState.existence == KeyExistence.KNOWN_EXIST) return false;
+          }
+          if (keys == null) {
+            keys = WindmillMultimap.this.keys();
+          }
+          return Iterables.isEmpty(keys.read());
+        }
+
+        @Override
+        public ReadableState<Boolean> readLater() {
+          if (keys == null) {
+            keys = WindmillMultimap.this.keys();
+          }
+          keys.readLater();

Review Comment:
   I don't think so. `readLater` doesn't evaluate immediately; evaluation is deferred until
`read()` is called, so we can't return early based on the state as it is right now.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,522 @@ private Future<Iterable<Map.Entry<ByteString, V>>> 
getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, 
valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST and all 
values of this key is
+      // cached(both KeyState#values and localAdditions).
+      boolean valuesCached;
+      // represents the values in windmill. When new values are added, they 
are added to
+      // localAdditions but not KeyState#values. New values will be added to 
KeyState#values only
+      // after they are persisted into windmill and removed from 
localAdditions.
+      ConcatIterables<V> values;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = false;
+        values = new ConcatIterables<>();
+      }
+    }
+
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to 
provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    private boolean cleared = false;
+    // We use the structural value of the keys as the key in keyStateMap, so 
that different java
+    // Objects with the same content will be treated as the same Multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // If true, all keys are cached in keyStateMap with existence == 
KNOWN_EXIST.
+    private boolean allKeysKnown = false;
+
+    private boolean complete = false;
+    // All keys that have new values pending write to windmill.
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exist in both localRemovals 
and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            return v;
+          });
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> 
getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, 
stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> 
new KeyState(key));
+          if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) {
+            return Collections.emptyList();
+          }
+          if (allKeysKnown && keyState.existence == 
KeyExistence.UNKNOWN_EXISTENCE) {
+            keyStateMap.remove(structuralKey);
+            return Collections.emptyList();
+          }
+
+          if (localRemovals.contains(structuralKey)) {
+            // this key has been removed locally but the removal hasn't been 
sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care 
about local values.
+            return 
Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+          }
+          if (keyState.valuesCached || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(keyState.values, 
localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            final Iterable<V> persistedValues = persistedData.get();
+            if (Iterables.isEmpty(persistedValues)) {
+              Collection<V> local = localAdditions.get(structuralKey);
+              if (local.isEmpty()) {
+                // empty in both cache and windmill, remove key from cache.
+                keyStateMap.remove(structuralKey);
+                return Collections.emptyList();
+              }
+              return Iterables.unmodifiableIterable(local);
+            }
+            if (persistedValues instanceof Weighted) {
+              keyState.existence = KeyExistence.KNOWN_EXIST;
+              keyState.valuesCached = true;
+              ConcatIterables<V> it = new ConcatIterables<>();
+              it.extendWith(persistedValues);
+              keyState.values = it;
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, 
localAdditions.get(structuralKey)));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest 
persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && localAdditions.isEmpty() && localRemovals.isEmpty()) {
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = 
WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = null;
+      if (cleared) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setDeleteAll(true);
+        cleared = false;
+      }
+      Set<Object> keysWithUpdates = Sets.newHashSet();
+      keysWithUpdates.addAll(localRemovals);
+      keysWithUpdates.addAll(localAdditions.keySet());
+      if (!keysWithUpdates.isEmpty() && builder == null) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setTag(stateKey).setStateFamily(stateFamily);
+      }
+      for (Object structuralKey : keysWithUpdates) {
+        KeyState keyState = keyStateMap.get(structuralKey);
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(keyState.originalKey, keyStream);
+        ByteString encodedKey = keyStream.toByteString();
+        Windmill.TagMultimapEntry.Builder entryBuilder = 
builder.addUpdatesBuilder();
+        entryBuilder.setEntryName(encodedKey);
+        entryBuilder.setDeleteAll(localRemovals.contains(structuralKey));
+        for (V value : localAdditions.get(structuralKey)) {
+          ByteStringOutputStream valueStream = new ByteStringOutputStream();
+          valueCoder.encode(value, valueStream);
+          ByteString encodedValue = valueStream.toByteString();
+          entryBuilder.addValues(encodedValue);
+        }
+        // Move newly added values from localAdditions to cachedEntries as 
those new values now are
+        // also persisted in Windmill. If a key now has no more values and is 
not KNOWN_EXIST,
+        // remove it from cache.
+        if (keyState.valuesCached) {
+          keyState.values.extendWith(localAdditions.get(structuralKey));
+        } else {
+          if (keyState.existence != KeyExistence.KNOWN_EXIST) 
keyStateMap.remove(structuralKey);
+        }
+      }
+
+      localRemovals = Sets.newHashSet();
+      localAdditions = ArrayListMultimap.create();
+
+      cache.put(namespace, address, this, 1);
+
+      return commitBuilder.buildPartial();
+    }
+
+    @Override
+    public void remove(K key) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new 
KeyState(key));
+      if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT
+          || (allKeysKnown && keyState.existence == 
KeyExistence.UNKNOWN_EXISTENCE)) {
+        return;
+      }
+      if (keyState.valuesCached || !complete) {
+        // there may be data in windmill that need to be removed.
+        localRemovals.add(structuralKey);
+        keyState.values = new ConcatIterables<>();
+        keyState.valuesCached = false;
+        keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+      } else {
+        // no data in windmill, deleting from local cache is sufficient.
+        keyStateMap.remove(structuralKey);
+      }
+      localAdditions.removeAll(structuralKey);
+    }
+
+    @Override
+    public void clear() {
+      keyStateMap = Maps.newHashMap();
+      localAdditions = ArrayListMultimap.create();
+      localRemovals = Sets.newHashSet();
+      cleared = true;
+      complete = true;
+      allKeysKnown = true;
+    }
+
+    @Override
+    public ReadableState<Iterable<K>> keys() {
+      return new ReadableState<Iterable<K>>() {
+        @Override
+        public Iterable<K> read() {
+          if (allKeysKnown) {
+            return Iterables.unmodifiableIterable(
+                Iterables.transform(
+                    Iterables.filter(
+                        keyStateMap.values(),
+                        keyState -> keyState.existence == 
KeyExistence.KNOWN_EXIST),
+                    keyState -> keyState.originalKey));
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = 
getFuture(true);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = 
persistedData.get();
+            Iterable<Entry<Object, K>> keys =
+                Iterables.transform(
+                    entries,
+                    entry -> {
+                      try {
+                        K originalKey = 
keyCoder.decode(entry.getKey().newInput());
+                        return new AbstractMap.SimpleEntry<>(
+                            keyCoder.structuralValue(originalKey), 
originalKey);
+                      } catch (IOException e) {
+                        throw new RuntimeException(e);
+                      }
+                    });
+            keys =
+                Iterables.filter(
+                    keys,
+                    entry -> {
+                      KeyState keyState = 
keyStateMap.getOrDefault(entry.getKey(), null);
+                      // this is a key that exists in windmill but is not 
cached.
+                      if (keyState == null) return true;
+                      // if the key is marked as deleted in cache, ignore it.
+                      return keyState.existence != 
KeyExistence.KNOWN_NONEXISTENT;
+                    });
+            if (entries instanceof Weighted) {
+              // This is a known amount of data, cache them all.
+              keys.forEach(
+                  entry -> {
+                    KeyState keyState =
+                        keyStateMap.computeIfAbsent(
+                            entry.getKey(), stk -> new 
KeyState(entry.getValue()));
+                    keyState.existence = KeyExistence.KNOWN_EXIST;
+                  });
+              allKeysKnown = true;
+              keyStateMap
+                  .values()
+                  .removeIf(keyState -> keyState.existence != 
KeyExistence.KNOWN_EXIST);
+              return Iterables.unmodifiableIterable(
+                  Iterables.transform(keyStateMap.values(), keyState -> 
keyState.originalKey));
+            } else {
+              return Iterables.unmodifiableIterable(
+                  Iterables.concat(
+                      // This is the part of keys that are cached.
+                      Iterables.transform(
+                          Iterables.filter(
+                              keyStateMap.values(),
+                              keyState -> keyState.existence == 
KeyExistence.KNOWN_EXIST),
+                          keyState -> keyState.originalKey),
+                      // This is the part of the keys returned from Windmill 
that are not cached.
+                      Iterables.transform(
+                          Iterables.filter(keys, e -> 
!keyStateMap.containsKey(e.getKey())),
+                          entry -> entry.getValue())));
+            }
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<K>> readLater() {
+          WindmillMultimap.this.getFuture(true);
+          return this;
+        }
+      };
+    }
+
+    private MultimapIterables<K, V> mergedCachedEntries() {
+      MultimapIterables<K, V> result = new MultimapIterables<>();
+      for (Entry<Object, Collection<V>> entry : 
localAdditions.asMap().entrySet()) {
+        K key = keyStateMap.get(entry.getKey()).originalKey;
+        result.extendWith(key, entry.getValue());
+      }
+      for (Entry<Object, KeyState> entry : keyStateMap.entrySet()) {
+        if (entry.getValue().valuesCached) {
+          result.extendWith(entry.getValue().originalKey, 
entry.getValue().values);
+        }
+      }
+      return result;
+    }
+
+    private static class MultimapIterables<K, V> implements Iterable<Entry<K, 
V>> {
+      Map<K, ConcatIterables<V>> map;
+
+      public MultimapIterables() {
+        this.map = new HashMap<>();
+      }
+
+      public void extendWith(K key, Iterable<V> iterable) {
+        map.compute(
+            key,
+            (k, v) -> {
+              if (v == null) v = new ConcatIterables<>();
+              v.extendWith(iterable);
+              return v;
+            });
+      }
+
+      @Override
+      public Iterator<Entry<K, V>> iterator() {
+        return Iterators.concat(
+            Iterables.transform(
+                    map.entrySet(),
+                    entry ->
+                        Iterables.transform(
+                                entry.getValue(),
+                                v -> new 
AbstractMap.SimpleEntry<>(entry.getKey(), v))
+                            .iterator())
+                .iterator());
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<Entry<K, V>>> entries() {
+      return new ReadableState<Iterable<Entry<K, V>>>() {
+        @Override
+        public Iterable<Entry<K, V>> read() {
+          if (complete) {
+            return Iterables.unmodifiableIterable(mergedCachedEntries());
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = 
getFuture(false);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = 
persistedData.get();
+            Map<Object, ConcatIterables<V>> entryMap = Maps.newHashMap();
+            entries.forEach(
+                entry -> {
+                  try {
+                    final K key = keyCoder.decode(entry.getKey().newInput());
+                    final Object structuralKey = keyCoder.structuralValue(key);
+                    KeyState keyState =
+                        keyStateMap.computeIfAbsent(structuralKey, k -> new 
KeyState(key));
+                    if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) 
return;
+                    entryMap.compute(
+                        structuralKey,
+                        (k, v) -> {
+                          if (v == null) v = new ConcatIterables<>();
+                          v.extendWith(entry.getValue());
+                          keyState.existence = KeyExistence.KNOWN_EXIST;
+                          return v;
+                        });
+                  } catch (IOException e) {
+                    throw new RuntimeException(e);
+                  }
+                });
+            if (entries instanceof Weighted) {
+              // This is a known amount of data, cache them all.
+              entryMap.forEach(
+                  (structuralKey, values) -> {
+                    KeyState keyState = keyStateMap.get(structuralKey);
+                    if (!keyState.valuesCached) {
+                      keyState.values.extendWith(values);
+                      keyState.valuesCached = true;
+                    }
+                  });
+              allKeysKnown = true;
+              complete = true;
+              keyStateMap
+                  .entrySet()
+                  .removeIf(
+                      entry ->
+                          entry.getValue().existence == 
KeyExistence.KNOWN_NONEXISTENT
+                              && !localRemovals.contains(entry.getKey()));
+              return Iterables.unmodifiableIterable(mergedCachedEntries());
+            } else {
+              MultimapIterables<K, V> local = mergedCachedEntries();
+              entryMap.forEach(
+                  (structuralKey, values) -> {
+                    KeyState keyState = keyStateMap.get(structuralKey);
+                    if (!keyState.valuesCached) {
+                      local.extendWith(keyState.originalKey, values);
+                    }
+                  });
+              return Iterables.unmodifiableIterable(local);
+            }
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<Entry<K, V>>> readLater() {
+          WindmillMultimap.this.getFuture(false);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    public ReadableState<Boolean> containsKey(K key) {
+      return new ReadableState<Boolean>() {
+        ReadableState<Iterable<V>> values = null;
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Boolean read() {
+          KeyState keyState = keyStateMap.getOrDefault(structuralKey, null);
+          if (keyState != null && keyState.existence != 
KeyExistence.UNKNOWN_EXISTENCE) {
+            return keyState.existence == KeyExistence.KNOWN_EXIST;
+          }
+          if (values == null) {
+            values = WindmillMultimap.this.get(key);
+          }
+          return !Iterables.isEmpty(values.read());
+        }
+
+        @Override
+        public ReadableState<Boolean> readLater() {
+          if (values == null) {
+            values = WindmillMultimap.this.get(key);
+          }
+          values.readLater();
+          return this;
+        }
+      };
+    }
+
+    @Override
+    public ReadableState<Boolean> isEmpty() {
+      return new ReadableState<Boolean>() {
+        ReadableState<Iterable<K>> keys = null;
+
+        @Override
+        public Boolean read() {
+          for (KeyState keyState : keyStateMap.values()) {
+            if (keyState.existence == KeyExistence.KNOWN_EXIST) return false;
+          }
+          if (keys == null) {
+            keys = WindmillMultimap.this.keys();

Review Comment:
   True: with the current Windmill API this is indeed more expensive. The 
situation is the same as for BagState and MapState. I added a comment explaining this.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to