scwhittle commented on code in PR #23492:
URL: https://github.com/apache/beam/pull/23492#discussion_r1098594021


##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1947,10 +1952,11 @@ public ReadableState<Iterable<K>> readLater() {
     private MultimapIterables<K, V> mergedCachedEntries() {
       MultimapIterables<K, V> result = new MultimapIterables<>();
       for (Entry<Object, KeyState> entry : keyStateMap.entrySet()) {
-        if (!entry.getValue().localAdditions.isEmpty()) {
+        KeyState keyState = entry.getValue();
+        if (!keyState.localAdditions.isEmpty()) {
           result.extendWith(entry.getValue().originalKey, 
entry.getValue().localAdditions);

Review Comment:
   nit: use keyState in place of the other entry.getValue() calls in this loop



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1999,56 +2005,79 @@ public Iterable<Entry<K, V>> read() {
           Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = 
getFuture(false);
           try (Closeable scope = scopedReadState()) {
             Iterable<Entry<ByteString, Iterable<V>>> entries = 
persistedData.get();
-            Map<Object, ConcatIterables<V>> entryMap = Maps.newHashMap();
-            entries.forEach(
-                entry -> {
-                  try {
-                    final K key = keyCoder.decode(entry.getKey().newInput());
-                    final Object structuralKey = keyCoder.structuralValue(key);
-                    KeyState keyState =
-                        keyStateMap.computeIfAbsent(structuralKey, k -> new 
KeyState(key));
-                    if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) 
return;
-                    entryMap.compute(
-                        structuralKey,
-                        (k, v) -> {
-                          if (v == null) v = new ConcatIterables<>();
-                          v.extendWith(entry.getValue());
-                          keyState.existence = KeyExistence.KNOWN_EXIST;
-                          return v;
-                        });
-                  } catch (IOException e) {
-                    throw new RuntimeException(e);
-                  }
-                });
+            // If a key returned by windmill is known to be no longer exist or 
is completely cached
+            // locally, we can safely ignore the content of this key in 
windmill. Thus, we filter
+            // entries to filteredEntries which contains only keys that are 
known to exist and not
+            // fully cached.
+            Iterable<Entry<Object, Iterable<V>>> filteredEntries =
+                Iterables.filter(
+                    Iterables.transform(
+                        entries,
+                        entry -> {
+                          try {
+                            final K key = 
keyCoder.decode(entry.getKey().newInput());
+                            final Object structuralKey = 
keyCoder.structuralValue(key);
+                            KeyState keyState =
+                                keyStateMap.computeIfAbsent(structuralKey, k 
-> new KeyState(key));
+                            if (keyState.existence == 
KeyExistence.UNKNOWN_EXISTENCE) {
+                              keyState.existence = KeyExistence.KNOWN_EXIST;
+                            }
+                            return new 
AbstractMap.SimpleEntry<>(structuralKey, entry.getValue());
+                          } catch (IOException e) {
+                            throw new RuntimeException(e);
+                          }
+                        }),
+                    entry -> {
+                      KeyState keyState = keyStateMap.get(entry.getKey());

Review Comment:
   Can you combine the filter+transform to avoid looking up the key twice?
   One idea would be to return null from the transform for entries you want to 
filter out, and then filter to remove the nulls. But perhaps there is a better 
way to do this with the stream APIs; I'm not super familiar with them.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1999,56 +2005,79 @@ public Iterable<Entry<K, V>> read() {
           Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = 
getFuture(false);
           try (Closeable scope = scopedReadState()) {
             Iterable<Entry<ByteString, Iterable<V>>> entries = 
persistedData.get();
-            Map<Object, ConcatIterables<V>> entryMap = Maps.newHashMap();
-            entries.forEach(
-                entry -> {
-                  try {
-                    final K key = keyCoder.decode(entry.getKey().newInput());
-                    final Object structuralKey = keyCoder.structuralValue(key);
-                    KeyState keyState =
-                        keyStateMap.computeIfAbsent(structuralKey, k -> new 
KeyState(key));
-                    if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) 
return;
-                    entryMap.compute(
-                        structuralKey,
-                        (k, v) -> {
-                          if (v == null) v = new ConcatIterables<>();
-                          v.extendWith(entry.getValue());
-                          keyState.existence = KeyExistence.KNOWN_EXIST;
-                          return v;
-                        });
-                  } catch (IOException e) {
-                    throw new RuntimeException(e);
-                  }
-                });
+            // If a key returned by windmill is known to be no longer exist or 
is completely cached
+            // locally, we can safely ignore the content of this key in 
windmill. Thus, we filter
+            // entries to filteredEntries which contains only keys that are 
known to exist and not
+            // fully cached.
+            Iterable<Entry<Object, Iterable<V>>> filteredEntries =
+                Iterables.filter(
+                    Iterables.transform(
+                        entries,
+                        entry -> {
+                          try {
+                            final K key = 
keyCoder.decode(entry.getKey().newInput());
+                            final Object structuralKey = 
keyCoder.structuralValue(key);
+                            KeyState keyState =
+                                keyStateMap.computeIfAbsent(structuralKey, k 
-> new KeyState(key));
+                            if (keyState.existence == 
KeyExistence.UNKNOWN_EXISTENCE) {
+                              keyState.existence = KeyExistence.KNOWN_EXIST;
+                            }
+                            return new 
AbstractMap.SimpleEntry<>(structuralKey, entry.getValue());
+                          } catch (IOException e) {
+                            throw new RuntimeException(e);
+                          }
+                        }),
+                    entry -> {
+                      KeyState keyState = keyStateMap.get(entry.getKey());
+                      return keyState.existence != 
KeyExistence.KNOWN_NONEXISTENT
+                          && !(keyState.existence == KeyExistence.KNOWN_EXIST
+                              && keyState.valuesCached);
+                    });
+
             if (entries instanceof Weighted) {
               // This is a known amount of data, cache them all.
-              entryMap.forEach(
-                  (structuralKey, values) -> {
+              filteredEntries.forEach(
+                  entry -> {
+                    final Object structuralKey = entry.getKey();
                     KeyState keyState = keyStateMap.get(structuralKey);
-                    if (!keyState.valuesCached) {
-                      keyState.values.extendWith(values);
-                      keyState.valuesCached = true;
-                    }
+                    keyState.existence = KeyExistence.KNOWN_EXIST;
+                    keyState.values.extendWith(entry.getValue());
+                    // We can't set keyState.valuesCached to true here, 
because there may be more
+                    // paginated values that should not be filtered out in 
filteredEntries.
                   });
               allKeysKnown = true;
               complete = true;
+              // Unload keys that are not known exist from cache, set 
valuesCached of all cached
+              // entries to true.
               keyStateMap
                   .entrySet()
                   .removeIf(
-                      entry ->
-                          entry.getValue().existence == 
KeyExistence.KNOWN_NONEXISTENT
-                              && !entry.getValue().removedLocally);
+                      entry -> {
+                        KeyState keyState = entry.getValue();
+                        boolean shouldRemove =
+                            (keyState.existence == 
KeyExistence.KNOWN_NONEXISTENT
+                                    && !keyState.removedLocally)
+                                || keyState.existence == 
KeyExistence.UNKNOWN_EXISTENCE;
+                        keyState.valuesCached = !shouldRemove;
+                        return shouldRemove;
+                      });
               return Iterables.unmodifiableIterable(mergedCachedEntries());
             } else {
-              MultimapIterables<K, V> local = mergedCachedEntries();
-              entryMap.forEach(
-                  (structuralKey, values) -> {
-                    KeyState keyState = keyStateMap.get(structuralKey);
-                    if (!keyState.valuesCached) {
-                      local.extendWith(keyState.originalKey, values);
-                    }
-                  });
-              return Iterables.unmodifiableIterable(local);
+              Iterable<Entry<K, V>> fromWindmill =
+                  Iterables.concat(
+                      Iterables.transform(
+                          Iterables.transform(
+                              filteredEntries,
+                              entry ->
+                                  new AbstractMap.SimpleEntry<>(
+                                      
keyStateMap.get(entry.getKey()).originalKey,
+                                      entry.getValue())),
+                          entry ->
+                              Iterables.transform(
+                                  entry.getValue(),
+                                  v -> new 
AbstractMap.SimpleEntry<>(entry.getKey(), v))));
+              return Iterables.unmodifiableIterable(
+                  Iterables.concat(mergedCachedEntries(), fromWindmill));

Review Comment:
   Are the keys expected to be returned in order by this iterable? If so, you 
might need some kind of merging iterable.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1999,56 +2005,79 @@ public Iterable<Entry<K, V>> read() {
           Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = 
getFuture(false);
           try (Closeable scope = scopedReadState()) {
             Iterable<Entry<ByteString, Iterable<V>>> entries = 
persistedData.get();
-            Map<Object, ConcatIterables<V>> entryMap = Maps.newHashMap();
-            entries.forEach(
-                entry -> {
-                  try {
-                    final K key = keyCoder.decode(entry.getKey().newInput());
-                    final Object structuralKey = keyCoder.structuralValue(key);
-                    KeyState keyState =
-                        keyStateMap.computeIfAbsent(structuralKey, k -> new 
KeyState(key));
-                    if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) 
return;
-                    entryMap.compute(
-                        structuralKey,
-                        (k, v) -> {
-                          if (v == null) v = new ConcatIterables<>();
-                          v.extendWith(entry.getValue());
-                          keyState.existence = KeyExistence.KNOWN_EXIST;
-                          return v;
-                        });
-                  } catch (IOException e) {
-                    throw new RuntimeException(e);
-                  }
-                });
+            // If a key returned by windmill is known to be no longer exist or 
is completely cached
+            // locally, we can safely ignore the content of this key in 
windmill. Thus, we filter
+            // entries to filteredEntries which contains only keys that are 
known to exist and not
+            // fully cached.
+            Iterable<Entry<Object, Iterable<V>>> filteredEntries =
+                Iterables.filter(
+                    Iterables.transform(
+                        entries,
+                        entry -> {
+                          try {
+                            final K key = 
keyCoder.decode(entry.getKey().newInput());
+                            final Object structuralKey = 
keyCoder.structuralValue(key);
+                            KeyState keyState =
+                                keyStateMap.computeIfAbsent(structuralKey, k 
-> new KeyState(key));
+                            if (keyState.existence == 
KeyExistence.UNKNOWN_EXISTENCE) {
+                              keyState.existence = KeyExistence.KNOWN_EXIST;
+                            }
+                            return new 
AbstractMap.SimpleEntry<>(structuralKey, entry.getValue());
+                          } catch (IOException e) {
+                            throw new RuntimeException(e);
+                          }
+                        }),
+                    entry -> {
+                      KeyState keyState = keyStateMap.get(entry.getKey());
+                      return keyState.existence != 
KeyExistence.KNOWN_NONEXISTENT
+                          && !(keyState.existence == KeyExistence.KNOWN_EXIST
+                              && keyState.valuesCached);
+                    });
+
             if (entries instanceof Weighted) {
               // This is a known amount of data, cache them all.
-              entryMap.forEach(
-                  (structuralKey, values) -> {
+              filteredEntries.forEach(
+                  entry -> {
+                    final Object structuralKey = entry.getKey();
                     KeyState keyState = keyStateMap.get(structuralKey);
-                    if (!keyState.valuesCached) {
-                      keyState.values.extendWith(values);
-                      keyState.valuesCached = true;
-                    }
+                    keyState.existence = KeyExistence.KNOWN_EXIST;
+                    keyState.values.extendWith(entry.getValue());
+                    // We can't set keyState.valuesCached to true here, 
because there may be more

Review Comment:
   I think the weighted case could be simpler by just iterating over entries, 
merging them with keyStateMap, and then returning keyStateMap.
   
   I think the transform/filter above could be moved to just the lazy-loading 
case. I think you could end up with just a single lookup in keyStateMap for 
each entry.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to