[ 
https://issues.apache.org/jira/browse/BEAM-13015?focusedWorklogId=770023&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-770023
 ]

ASF GitHub Bot logged work on BEAM-13015:
-----------------------------------------

                Author: ASF GitHub Bot
            Created on: 13/May/22 04:51
            Start Date: 13/May/22 04:51
    Worklog Time Spent: 10m 
      Work Description: y1chi commented on code in PR #17327:
URL: https://github.com/apache/beam/pull/17327#discussion_r871975236


##########
sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PrecombineGroupingTable.java:
##########
@@ -17,434 +17,283 @@
  */
 package org.apache.beam.fn.harness;
 
-import java.util.HashMap;
 import java.util.Iterator;
-import java.util.Map;
+import java.util.LinkedHashMap;
 import java.util.Random;
+import java.util.concurrent.atomic.AtomicLong;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.NotThreadSafe;
+import org.apache.beam.fn.harness.Cache.Shrinkable;
 import org.apache.beam.runners.core.GlobalCombineFnRunner;
 import org.apache.beam.runners.core.GlobalCombineFnRunners;
 import org.apache.beam.runners.core.NullSideInputReader;
-import org.apache.beam.runners.core.SideInputReader;
 import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.fn.data.FnDataReceiver;
 import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.options.SdkHarnessOptions;
 import org.apache.beam.sdk.transforms.Combine.CombineFn;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.Weighted;
 import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
 import org.apache.beam.sdk.values.KV;
 import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.ByteStreams;
-import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.CountingOutputStream;
 import org.joda.time.Instant;
 
-/** Static utility methods that provide {@link GroupingTable} implementations. 
*/
+/** Static utility methods that provide a grouping table implementation. */
 @SuppressWarnings({
   "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
 })
+@NotThreadSafe

Review Comment:
   Could you document why this class is marked `@NotThreadSafe`? The annotation also seems to contradict the requirement of `Shrinkable`.



##########
sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PrecombineGroupingTable.java:
##########
@@ -17,434 +17,283 @@
  */
 package org.apache.beam.fn.harness;
 
-import java.util.HashMap;
 import java.util.Iterator;
-import java.util.Map;
+import java.util.LinkedHashMap;
 import java.util.Random;
+import java.util.concurrent.atomic.AtomicLong;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.NotThreadSafe;
+import org.apache.beam.fn.harness.Cache.Shrinkable;
 import org.apache.beam.runners.core.GlobalCombineFnRunner;
 import org.apache.beam.runners.core.GlobalCombineFnRunners;
 import org.apache.beam.runners.core.NullSideInputReader;
-import org.apache.beam.runners.core.SideInputReader;
 import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.fn.data.FnDataReceiver;
 import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.options.SdkHarnessOptions;
 import org.apache.beam.sdk.transforms.Combine.CombineFn;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.Weighted;
 import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
 import org.apache.beam.sdk.values.KV;
 import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.ByteStreams;
-import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.CountingOutputStream;
 import org.joda.time.Instant;
 
-/** Static utility methods that provide {@link GroupingTable} implementations. 
*/
+/** Static utility methods that provide a grouping table implementation. */
 @SuppressWarnings({
   "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
 })
+@NotThreadSafe
 public class PrecombineGroupingTable<K, InputT, AccumT>
-    implements GroupingTable<K, InputT, AccumT> {
-  private static long getGroupingTableSizeBytes(PipelineOptions options) {
-    return options.as(SdkHarnessOptions.class).getGroupingTableMaxSizeMb() * 
1024L * 1024L;
-  }
+    implements Shrinkable<PrecombineGroupingTable<K, InputT, AccumT>>, 
Weighted {
+
+  private static final Instant IGNORED = BoundedWindow.TIMESTAMP_MIN_VALUE;
 
-  /** Returns a {@link GroupingTable} that combines inputs into a accumulator. 
*/
-  public static <K, InputT, AccumT> GroupingTable<WindowedValue<K>, InputT, 
AccumT> combining(
+  /**
+   * Returns a grouping table that combines inputs into an accumulator. The 
grouping table uses the
+   * cache to defer flushing output until the cache evicts the table.
+   */
+  public static <K, InputT, AccumT> PrecombineGroupingTable<K, InputT, AccumT> 
combining(
       PipelineOptions options,
+      Cache<Object, Object> cache,
       CombineFn<InputT, AccumT, ?> combineFn,
-      Coder<K> keyCoder,
-      Coder<? super AccumT> accumulatorCoder) {
-    Combiner<WindowedValue<K>, InputT, AccumT, ?> valueCombiner =
-        new ValueCombiner<>(
-            GlobalCombineFnRunners.create(combineFn), 
NullSideInputReader.empty(), options);
+      Coder<K> keyCoder) {
     return new PrecombineGroupingTable<>(
-        getGroupingTableSizeBytes(options),
-        new WindowingCoderGroupingKeyCreator<>(keyCoder),
-        WindowedPairInfo.create(),
-        valueCombiner,
-        new CoderSizeEstimator<>(WindowedValue.getValueOnlyCoder(keyCoder)),
-        new CoderSizeEstimator<>(accumulatorCoder));
+        options,
+        cache,
+        keyCoder,
+        GlobalCombineFnRunners.create(combineFn),
+        Caches::weigh,
+        Caches::weigh);
   }
 
   /**
-   * Returns a {@link GroupingTable} that combines inputs into a accumulator 
with sampling {@link
-   * SizeEstimator SizeEstimators}.
+   * Returns a grouping table that combines inputs into an accumulator with 
sampling {@link
+   * SizeEstimator SizeEstimators}. The grouping table uses the cache to defer 
flushing output until
+   * the cache evicts the table.
    */
-  public static <K, InputT, AccumT>
-      GroupingTable<WindowedValue<K>, InputT, AccumT> combiningAndSampling(
-          PipelineOptions options,
-          CombineFn<InputT, AccumT, ?> combineFn,
-          Coder<K> keyCoder,
-          Coder<? super AccumT> accumulatorCoder,
-          double sizeEstimatorSampleRate) {
-    Combiner<WindowedValue<K>, InputT, AccumT, ?> valueCombiner =
-        new ValueCombiner<>(
-            GlobalCombineFnRunners.create(combineFn), 
NullSideInputReader.empty(), options);
+  public static <K, InputT, AccumT> PrecombineGroupingTable<K, InputT, AccumT> 
combiningAndSampling(
+      PipelineOptions options,
+      Cache<Object, Object> cache,
+      CombineFn<InputT, AccumT, ?> combineFn,
+      Coder<K> keyCoder,
+      double sizeEstimatorSampleRate) {
     return new PrecombineGroupingTable<>(
-        getGroupingTableSizeBytes(options),
-        new WindowingCoderGroupingKeyCreator<>(keyCoder),
-        WindowedPairInfo.create(),
-        valueCombiner,
-        new SamplingSizeEstimator<>(
-            new 
CoderSizeEstimator<>(WindowedValue.getValueOnlyCoder(keyCoder)),
-            sizeEstimatorSampleRate,
-            1.0),
-        new SamplingSizeEstimator<>(
-            new CoderSizeEstimator<>(accumulatorCoder), 
sizeEstimatorSampleRate, 1.0));
-  }
-
-  /** Provides client-specific operations for grouping keys. */
-  public interface GroupingKeyCreator<K> {
-    Object createGroupingKey(K key) throws Exception;
+        options,
+        cache,
+        keyCoder,
+        GlobalCombineFnRunners.create(combineFn),
+        new SamplingSizeEstimator<>(Caches::weigh, sizeEstimatorSampleRate, 
1.0),
+        new SamplingSizeEstimator<>(Caches::weigh, sizeEstimatorSampleRate, 
1.0));
   }
 
-  /** Implements Precombine GroupingKeyCreator via Coder. */
-  public static class WindowingCoderGroupingKeyCreator<K>
-      implements GroupingKeyCreator<WindowedValue<K>> {
-
-    private static final Instant ignored = BoundedWindow.TIMESTAMP_MIN_VALUE;
-
-    private final Coder<K> coder;
-
-    WindowingCoderGroupingKeyCreator(Coder<K> coder) {
-      this.coder = coder;
+  @Nullable
+  @Override
+  public PrecombineGroupingTable<K, InputT, AccumT> shrink() {
+    long currentWeight = maxWeight.updateAndGet(operand -> operand >> 1);
+    // It is possible that we are shrunk multiple times until the requested 
max weight is too small.
+    // In this case we want to effectively stop shrinking since we can't 
effectively cache much
+    // at this time and the next insertion will likely evict all records.
+    if (currentWeight <= 100L) {
+      return null;
     }
+    return this;
+  }
 
-    @Override
-    public Object createGroupingKey(WindowedValue<K> key) {
-      // Ignore timestamp for grouping purposes.
-      // The Precombine output will inherit the timestamp of one of its inputs.
-      return WindowedValue.of(
-          coder.structuralValue(key.getValue()), ignored, key.getWindows(), 
key.getPane());
-    }
+  @Override
+  public long getWeight() {
+    return maxWeight.get();
   }
 
   /** Provides client-specific operations for size estimates. */
+  @FunctionalInterface
   public interface SizeEstimator<T> {
-    long estimateSize(T element) throws Exception;
+    long estimateSize(T element);
   }
 
-  /** Implements SizeEstimator via Coder. */
-  public static class CoderSizeEstimator<T> implements SizeEstimator<T> {
-    /** Basic implementation of {@link ElementByteSizeObserver} for use in 
size estimation. */
-    private static class Observer extends ElementByteSizeObserver {
-      private long observedSize = 0;
-
-      @Override
-      protected void reportElementSize(long elementSize) {
-        observedSize += elementSize;
-      }
-    }
-
-    final Coder<T> coder;
+  private final Coder<K> keyCoder;
+  private final GlobalCombineFnRunner<InputT, AccumT, ?> combineFn;
+  private final PipelineOptions options;
+  private final SizeEstimator<K> keySizer;
+  private final SizeEstimator<AccumT> accumulatorSizer;
+  private final Cache<Key, PrecombineGroupingTable<K, InputT, AccumT>> cache;
+  private final LinkedHashMap<WindowedValue<Object>, GroupingTableEntry> 
lruMap;
+  private final AtomicLong maxWeight;
+  private long weight;
 
-    CoderSizeEstimator(Coder<T> coder) {
-      this.coder = coder;
-    }
+  private static final class Key implements Weighted {
+    private static final Key INSTANCE = new Key();
 
     @Override
-    public long estimateSize(T value) throws Exception {
-      // First try using byte size observer
-      CoderSizeEstimator.Observer observer = new CoderSizeEstimator.Observer();
-      coder.registerByteSizeObserver(value, observer);
-
-      if (!observer.getIsLazy()) {
-        observer.advance();
-        return observer.observedSize;
-      } else {
-        // Coder byte size observation is lazy (requires iteration for 
observation) so fall back to
-        // counting output stream
-        CountingOutputStream os = new 
CountingOutputStream(ByteStreams.nullOutputStream());
-        coder.encode(value, os);
-        return os.getCount();
-      }
+    public long getWeight() {
+      // Ignore the actual size of this singleton because it is trivial and 
because
+      // the weight reported here will be counted many times as it is present 
in
+      // many different state subcaches.
+      return 0;
     }
   }
 
-  /**
-   * Provides client-specific operations for working with elements that are 
key/value or key/values
-   * pairs.
-   */
-  public interface PairInfo {
-    Object getKeyFromInputPair(Object pair);
-
-    Object getValueFromInputPair(Object pair);
-
-    Object makeOutputPair(Object key, Object value);
-  }
-
-  /** Implements Precombine PairInfo via KVs. */
-  public static class WindowedPairInfo implements PairInfo {
-    private static WindowedPairInfo theInstance = new WindowedPairInfo();
-
-    public static WindowedPairInfo create() {
-      return theInstance;
-    }
-
-    private WindowedPairInfo() {}
-
-    @Override
-    public Object getKeyFromInputPair(Object pair) {
-      @SuppressWarnings("unchecked")
-      WindowedValue<KV<?, ?>> windowedKv = (WindowedValue<KV<?, ?>>) pair;
-      return windowedKv.withValue(windowedKv.getValue().getKey());
-    }
-
-    @Override
-    public Object getValueFromInputPair(Object pair) {
-      @SuppressWarnings("unchecked")
-      WindowedValue<KV<?, ?>> windowedKv = (WindowedValue<KV<?, ?>>) pair;
-      return windowedKv.getValue().getValue();
-    }
-
-    @Override
-    public Object makeOutputPair(Object key, Object values) {
-      WindowedValue<?> windowedKey = (WindowedValue<?>) key;
-      return windowedKey.withValue(KV.of(windowedKey.getValue(), values));
-    }
-  }
-
-  /** Provides client-specific operations for combining values. */
-  public interface Combiner<K, InputT, AccumT, OutputT> {
-    AccumT createAccumulator(K key);
-
-    AccumT add(K key, AccumT accumulator, InputT value);
-
-    AccumT merge(K key, Iterable<AccumT> accumulators);
-
-    AccumT compact(K key, AccumT accumulator);
-
-    OutputT extract(K key, AccumT accumulator);
+  PrecombineGroupingTable(
+      PipelineOptions options,
+      Cache<?, ?> cache,
+      Coder<K> keyCoder,
+      GlobalCombineFnRunner<InputT, AccumT, ?> combineFn,
+      SizeEstimator<K> keySizer,
+      SizeEstimator<AccumT> accumulatorSizer) {
+    this.options = options;
+    this.cache = (Cache<Key, PrecombineGroupingTable<K, InputT, AccumT>>) 
cache;
+    this.keyCoder = keyCoder;
+    this.combineFn = combineFn;
+    this.keySizer = keySizer;
+    this.accumulatorSizer = accumulatorSizer;
+    this.lruMap = new LinkedHashMap<>(16, 0.75f, true);
+    this.maxWeight = new AtomicLong();
+    this.weight = 0L;
+    this.cache.put(Key.INSTANCE, this);
   }
 
-  /** Implements Precombine Combiner via Combine.KeyedCombineFn. */
-  public static class ValueCombiner<K, InputT, AccumT, OutputT>
-      implements Combiner<WindowedValue<K>, InputT, AccumT, OutputT> {
-    private final GlobalCombineFnRunner<InputT, AccumT, OutputT> combineFn;
-    private final SideInputReader sideInputReader;
-    private final PipelineOptions options;
-
-    private ValueCombiner(
-        GlobalCombineFnRunner<InputT, AccumT, OutputT> combineFn,
-        SideInputReader sideInputReader,
-        PipelineOptions options) {
-      this.combineFn = combineFn;
-      this.sideInputReader = sideInputReader;
-      this.options = options;
-    }
-
-    @Override
-    public AccumT createAccumulator(WindowedValue<K> windowedKey) {
-      return this.combineFn.createAccumulator(options, sideInputReader, 
windowedKey.getWindows());
+  private class GroupingTableEntry implements Weighted {
+    private final WindowedValue<Object> groupingKey;
+    private final K userKey;
+    private final long keySize;
+    private long accumulatorSize;
+    private AccumT accumulator;
+    private boolean dirty;
+
+    private GroupingTableEntry(
+        WindowedValue<Object> groupingKey, K userKey, InputT 
initialInputValue) {
+      this.groupingKey = groupingKey;
+      this.userKey = userKey;
+      if (groupingKey.getValue() == userKey) {
+        // This object is only storing references to the same objects that are 
being stored
+        // by the cache so the accounting of the size of the key is occurring 
already.
+        this.keySize = Caches.REFERENCE_SIZE * 2;
+      } else {
+        this.keySize = Caches.REFERENCE_SIZE + keySizer.estimateSize(userKey);
+      }
+      this.accumulator =
+          combineFn.createAccumulator(
+              options, NullSideInputReader.empty(), groupingKey.getWindows());
+      add(initialInputValue);
+      this.accumulatorSize = accumulatorSizer.estimateSize(accumulator);
     }
 
-    @Override
-    public AccumT add(WindowedValue<K> windowedKey, AccumT accumulator, InputT 
value) {
-      return this.combineFn.addInput(
-          accumulator, value, options, sideInputReader, 
windowedKey.getWindows());
+    public WindowedValue<Object> getGroupingKey() {
+      return groupingKey;
     }
 
-    @Override
-    public AccumT merge(WindowedValue<K> windowedKey, Iterable<AccumT> 
accumulators) {
-      return this.combineFn.mergeAccumulators(
-          accumulators, options, sideInputReader, windowedKey.getWindows());
+    public K getKey() {
+      return userKey;
     }
 
-    @Override
-    public AccumT compact(WindowedValue<K> windowedKey, AccumT accumulator) {
-      return this.combineFn.compact(
-          accumulator, options, sideInputReader, windowedKey.getWindows());
+    public AccumT getValue() {
+      return accumulator;
     }
 
     @Override
-    public OutputT extract(WindowedValue<K> windowedKey, AccumT accumulator) {
-      return this.combineFn.extractOutput(
-          accumulator, options, sideInputReader, windowedKey.getWindows());
+    public long getWeight() {
+      return keySize + accumulatorSize;
     }
-  }
-
-  // How many bytes a word in the JVM has.
-  private static final int BYTES_PER_JVM_WORD = getBytesPerJvmWord();
-  /**
-   * The number of bytes of overhead to store an entry in the grouping table 
(a {@code
-   * HashMap<StructuralByteArray, KeyAndValues>}), ignoring the actual number 
of bytes in the keys
-   * and values:
-   *
-   * <ul>
-   *   <li>an array element (1 word),
-   *   <li>a HashMap.Entry (4 words),
-   *   <li>a StructuralByteArray (1 words),
-   *   <li>a backing array (guessed at 1 word for the length),
-   *   <li>a KeyAndValues (2 words),
-   *   <li>an ArrayList (2 words),
-   *   <li>a backing array (1 word),
-   *   <li>per-object overhead (JVM-specific, guessed at 2 words * 6 objects).
-   * </ul>
-   */
-  private static final int PER_KEY_OVERHEAD = 24 * BYTES_PER_JVM_WORD;
-
-  /** A {@link GroupingTable} that uses the given combiner to combine values 
in place. */
-  // Keep the table relatively full to increase the chance of collisions.
-  private static final double TARGET_LOAD = 0.9;
-
-  private long maxSize;
-  private final GroupingKeyCreator<? super K> groupingKeyCreator;
-  private final PairInfo pairInfo;
-  private final Combiner<? super K, InputT, AccumT, ?> combiner;
-  private final SizeEstimator<? super K> keySizer;
-  private final SizeEstimator<? super AccumT> accumulatorSizer;
-
-  private long size = 0;
-  private Map<Object, GroupingTableEntry<K, InputT, AccumT>> table;
-
-  PrecombineGroupingTable(
-      long maxSize,
-      GroupingKeyCreator<? super K> groupingKeyCreator,
-      PairInfo pairInfo,
-      Combiner<? super K, InputT, AccumT, ?> combineFn,
-      SizeEstimator<? super K> keySizer,
-      SizeEstimator<? super AccumT> accumulatorSizer) {
-    this.maxSize = maxSize;
-    this.groupingKeyCreator = groupingKeyCreator;
-    this.pairInfo = pairInfo;
-    this.combiner = combineFn;
-    this.keySizer = keySizer;
-    this.accumulatorSizer = accumulatorSizer;
-    this.table = new HashMap<>();
-  }
-
-  interface GroupingTableEntry<K, InputT, AccumT> {
-    K getKey();
-
-    AccumT getValue();
-
-    void add(InputT value) throws Exception;
-
-    long getSize();
-
-    void compact() throws Exception;
-  }
-
-  private GroupingTableEntry<K, InputT, AccumT> createTableEntry(final K key) 
throws Exception {
-    return new GroupingTableEntry<K, InputT, AccumT>() {
-      final long keySize = keySizer.estimateSize(key);
-      AccumT accumulator = combiner.createAccumulator(key);
-      long accumulatorSize = 0; // never used before a value is added...
-
-      @Override
-      public K getKey() {
-        return key;
-      }
-
-      @Override
-      public AccumT getValue() {
-        return accumulator;
-      }
-
-      @Override
-      public long getSize() {
-        return keySize + accumulatorSize;
-      }
-
-      @Override
-      public void compact() throws Exception {
-        AccumT newAccumulator = combiner.compact(key, accumulator);
-        if (newAccumulator != accumulator) {
-          accumulator = newAccumulator;
-          accumulatorSize = accumulatorSizer.estimateSize(newAccumulator);
-        }
-      }
 
-      @Override
-      public void add(InputT value) throws Exception {
-        accumulator = combiner.add(key, accumulator, value);
+    public void compact() {
+      if (dirty) {
+        accumulator =
+            combineFn.compact(
+                accumulator, options, NullSideInputReader.empty(), 
groupingKey.getWindows());
         accumulatorSize = accumulatorSizer.estimateSize(accumulator);
+        dirty = false;
       }
-    };
-  }
+    }
 
-  /** Adds a pair to this table, possibly flushing some entries to output if 
the table is full. */
-  @SuppressWarnings("unchecked")
-  @Override
-  public void put(Object pair, Receiver receiver) throws Exception {
-    put(
-        (K) pairInfo.getKeyFromInputPair(pair),
-        (InputT) pairInfo.getValueFromInputPair(pair),
-        receiver);
+    public void add(InputT value) {
+      dirty = true;
+      accumulator =
+          combineFn.addInput(
+              accumulator, value, options, NullSideInputReader.empty(), 
groupingKey.getWindows());
+      accumulatorSize = accumulatorSizer.estimateSize(accumulator);
+    }
   }
 
   /**
    * Adds the key and value to this table, possibly flushing some entries to 
output if the table is
    * full.
    */
-  public void put(K key, InputT value, Receiver receiver) throws Exception {
-    Object groupingKey = groupingKeyCreator.createGroupingKey(key);
-    GroupingTableEntry<K, InputT, AccumT> entry = table.get(groupingKey);
-    if (entry == null) {
-      entry = createTableEntry(key);
-      table.put(groupingKey, entry);
-      size += PER_KEY_OVERHEAD;
-    } else {
-      size -= entry.getSize();
-    }
-    entry.add(value);
-    size += entry.getSize();
-
-    if (size >= maxSize) {
-      long targetSize = (long) (TARGET_LOAD * maxSize);
-      Iterator<GroupingTableEntry<K, InputT, AccumT>> entries = 
table.values().iterator();
-      while (size >= targetSize) {
-        if (!entries.hasNext()) {
-          // Should never happen, but sizes may be estimates...
-          size = 0;
-          break;
+  @VisibleForTesting
+  public void put(
+      WindowedValue<KV<K, InputT>> value, FnDataReceiver<WindowedValue<KV<K, 
AccumT>>> receiver)
+      throws Exception {
+    // Ignore timestamp for grouping purposes.
+    // The Pre-combine output will inherit the timestamp of one of its inputs.
+    WindowedValue<Object> groupingKey =
+        WindowedValue.of(
+            keyCoder.structuralValue(value.getValue().getKey()),
+            IGNORED,
+            value.getWindows(),
+            value.getPane());
+
+    GroupingTableEntry entry =
+        lruMap.compute(
+            groupingKey,
+            (key, tableEntry) -> {
+              if (tableEntry == null) {
+                tableEntry =
+                    new GroupingTableEntry(
+                        groupingKey, value.getValue().getKey(), 
value.getValue().getValue());
+              } else {
+                tableEntry.add(value.getValue().getValue());
+              }
+              return tableEntry;
+            });
+    weight += entry.getWeight();

Review Comment:
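   Is this accurate if the entry is not new? For an existing entry, `weight += entry.getWeight()` appears to add the entry's full weight again rather than only the change in its size.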



##########
sdks/java/harness/src/test/java/org/apache/beam/fn/harness/PrecombineGroupingTableTest.java:
##########
@@ -46,79 +60,207 @@
 @RunWith(JUnit4.class)
 public class PrecombineGroupingTableTest {
 
-  private static class TestOutputReceiver implements Receiver {
-    final List<Object> outputElems = new ArrayList<>();
+  @Rule
+  public TestExecutorService executorService = 
TestExecutors.from(Executors.newCachedThreadPool());
+
+  private static class TestOutputReceiver<T> implements FnDataReceiver<T> {
+    final List<T> outputElems = new ArrayList<>();
 
     @Override
-    public void process(Object elem) {
+    public void accept(T elem) {
       outputElems.add(elem);
     }
   }
 
-  @Test
-  public void testCombiningGroupingTable() throws Exception {
-    Combiner<Object, Integer, Long, Long> summingCombineFn =
-        new Combiner<Object, Integer, Long, Long>() {
+  private static final CombineFn<Integer, Long, Long> COMBINE_FN =
+      new CombineFn<Integer, Long, Long>() {
 
-          @Override
-          public Long createAccumulator(Object key) {
-            return 0L;
-          }
+        @Override
+        public Long createAccumulator() {
+          return 0L;
+        }
 
-          @Override
-          public Long add(Object key, Long accumulator, Integer value) {
-            return accumulator + value;
-          }
+        @Override
+        public Long addInput(Long accumulator, Integer value) {
+          return accumulator + value;
+        }
 
-          @Override
-          public Long merge(Object key, Iterable<Long> accumulators) {
-            long sum = 0;
-            for (Long part : accumulators) {
-              sum += part;
-            }
-            return sum;
+        @Override
+        public Long mergeAccumulators(Iterable<Long> accumulators) {
+          long sum = 0;
+          for (Long part : accumulators) {
+            sum += part;
           }
+          return sum;
+        }
 
-          @Override
-          public Long compact(Object key, Long accumulator) {
-            return accumulator;
+        @Override
+        public Long compact(Long accumulator) {
+          if (accumulator % 2 == 0) {
+            return accumulator / 4;
           }
+          return accumulator;
+        }
 
-          @Override
-          public Long extract(Object key, Long accumulator) {
-            return accumulator;
-          }
-        };
+        @Override
+        public Long extractOutput(Long accumulator) {
+          return accumulator;
+        }
+      };
 
+  @Test
+  public void testCombiningGroupingTableEvictsAllOnLargeEntry() throws 
Exception {
     PrecombineGroupingTable<String, Integer, Long> table =
         new PrecombineGroupingTable<>(
-            100_000_000L,
-            new IdentityGroupingKeyCreator(),
-            new KvPairInfo(),
-            summingCombineFn,
+            PipelineOptionsFactory.create(),
+            Caches.forMaximumBytes(2500L),
+            StringUtf8Coder.of(),
+            GlobalCombineFnRunners.create(COMBINE_FN),
             new StringPowerSizeEstimator(),
             new IdentitySizeEstimator());
-    table.setMaxSize(1000);
 
-    TestOutputReceiver receiver = new TestOutputReceiver();
+    TestOutputReceiver<WindowedValue<KV<String, Long>>> receiver = new 
TestOutputReceiver<>();
 
-    table.put("A", 1, receiver);
-    table.put("B", 2, receiver);
-    table.put("B", 3, receiver);
-    table.put("C", 4, receiver);
+    table.put(valueInGlobalWindow(KV.of("A", 1)), receiver);
+    table.put(valueInGlobalWindow(KV.of("B", 3)), receiver);
+    table.put(valueInGlobalWindow(KV.of("B", 6)), receiver);
+    table.put(valueInGlobalWindow(KV.of("C", 7)), receiver);
     assertThat(receiver.outputElems, empty());
 
-    table.put("C", 5000, receiver);
-    assertThat(receiver.outputElems, hasItem((Object) KV.of("C", 5004L)));
+    // Add beyond the size which causes compaction which still leads to 
evicting all since the
+    // largest is most recent.
+    table.put(valueInGlobalWindow(KV.of("C", 9999)), receiver);
+    assertThat(
+        receiver.outputElems,
+        containsInAnyOrder(
+            valueInGlobalWindow(KV.of("A", 1L)),
+            valueInGlobalWindow(KV.of("B", 9L)),
+            valueInGlobalWindow(KV.of("C", (9999L + 7) / 4))));
+
+    table.flush(receiver);
+    assertThat(
+        receiver.outputElems,
+        containsInAnyOrder(
+            valueInGlobalWindow(KV.of("A", 1L)),
+            valueInGlobalWindow(KV.of("B", 3L + 6)),
+            valueInGlobalWindow(KV.of("C", (9999L + 7) / 4))));
+  }
+
+  @Test
+  public void testCombiningGroupingTableCompactionSaves() throws Exception {
+    PrecombineGroupingTable<String, Integer, Long> table =
+        new PrecombineGroupingTable<>(
+            PipelineOptionsFactory.create(),
+            Caches.forMaximumBytes(2500L),
+            StringUtf8Coder.of(),
+            GlobalCombineFnRunners.create(COMBINE_FN),
+            new StringPowerSizeEstimator(),
+            new IdentitySizeEstimator());
+
+    TestOutputReceiver<WindowedValue<KV<String, Long>>> receiver = new 
TestOutputReceiver<>();
+
+    // Insert three compactable values which shouldn't lead to eviction even 
though we are over
+    // the maximum size.
+    table.put(valueInGlobalWindow(KV.of("A", 1004)), receiver);
+    table.put(valueInGlobalWindow(KV.of("B", 1004)), receiver);
+    table.put(valueInGlobalWindow(KV.of("C", 1004)), receiver);
+    assertThat(receiver.outputElems, empty());
+
+    table.flush(receiver);
+    assertThat(
+        receiver.outputElems,
+        containsInAnyOrder(
+            valueInGlobalWindow(KV.of("A", 1004L / 4)),
+            valueInGlobalWindow(KV.of("B", 1004L / 4)),
+            valueInGlobalWindow(KV.of("C", 1004L / 4))));
+  }
+
+  @Test
+  public void testCombiningGroupingTablePartialEviction() throws Exception {
+    PrecombineGroupingTable<String, Integer, Long> table =
+        new PrecombineGroupingTable<>(
+            PipelineOptionsFactory.create(),
+            Caches.forMaximumBytes(2500L),
+            StringUtf8Coder.of(),
+            GlobalCombineFnRunners.create(COMBINE_FN),
+            new StringPowerSizeEstimator(),
+            new IdentitySizeEstimator());
+
+    TestOutputReceiver<WindowedValue<KV<String, Long>>> receiver = new 
TestOutputReceiver<>();
 
-    table.put("DDDD", 6, receiver);
-    assertThat(receiver.outputElems, hasItem((Object) KV.of("DDDD", 6L)));
+    // Insert three values which even with compaction isn't enough so we evict 
D & E to get

Review Comment:
   s/D & E/A & B/



##########
sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PrecombineGroupingTable.java:
##########
@@ -17,434 +17,283 @@
  */
 package org.apache.beam.fn.harness;
 
-import java.util.HashMap;
 import java.util.Iterator;
-import java.util.Map;
+import java.util.LinkedHashMap;
 import java.util.Random;
+import java.util.concurrent.atomic.AtomicLong;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.NotThreadSafe;
+import org.apache.beam.fn.harness.Cache.Shrinkable;
 import org.apache.beam.runners.core.GlobalCombineFnRunner;
 import org.apache.beam.runners.core.GlobalCombineFnRunners;
 import org.apache.beam.runners.core.NullSideInputReader;
-import org.apache.beam.runners.core.SideInputReader;
 import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.fn.data.FnDataReceiver;
 import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.options.SdkHarnessOptions;
 import org.apache.beam.sdk.transforms.Combine.CombineFn;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.Weighted;
 import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
 import org.apache.beam.sdk.values.KV;
 import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.ByteStreams;
-import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.CountingOutputStream;
 import org.joda.time.Instant;
 
-/** Static utility methods that provide {@link GroupingTable} implementations. 
*/
+/** Static utility methods that provide a grouping table implementation. */
 @SuppressWarnings({
   "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
 })
+@NotThreadSafe
 public class PrecombineGroupingTable<K, InputT, AccumT>
-    implements GroupingTable<K, InputT, AccumT> {
-  private static long getGroupingTableSizeBytes(PipelineOptions options) {
-    return options.as(SdkHarnessOptions.class).getGroupingTableMaxSizeMb() * 
1024L * 1024L;
-  }
+    implements Shrinkable<PrecombineGroupingTable<K, InputT, AccumT>>, 
Weighted {
+
+  private static final Instant IGNORED = BoundedWindow.TIMESTAMP_MIN_VALUE;
 
-  /** Returns a {@link GroupingTable} that combines inputs into a accumulator. 
*/
-  public static <K, InputT, AccumT> GroupingTable<WindowedValue<K>, InputT, 
AccumT> combining(
+  /**
+   * Returns a grouping table that combines inputs into an accumulator. The 
grouping table uses the
+   * cache to defer flushing output until the cache evicts the table.
+   */
+  public static <K, InputT, AccumT> PrecombineGroupingTable<K, InputT, AccumT> 
combining(
       PipelineOptions options,
+      Cache<Object, Object> cache,
       CombineFn<InputT, AccumT, ?> combineFn,
-      Coder<K> keyCoder,
-      Coder<? super AccumT> accumulatorCoder) {
-    Combiner<WindowedValue<K>, InputT, AccumT, ?> valueCombiner =
-        new ValueCombiner<>(
-            GlobalCombineFnRunners.create(combineFn), 
NullSideInputReader.empty(), options);
+      Coder<K> keyCoder) {
     return new PrecombineGroupingTable<>(
-        getGroupingTableSizeBytes(options),
-        new WindowingCoderGroupingKeyCreator<>(keyCoder),
-        WindowedPairInfo.create(),
-        valueCombiner,
-        new CoderSizeEstimator<>(WindowedValue.getValueOnlyCoder(keyCoder)),
-        new CoderSizeEstimator<>(accumulatorCoder));
+        options,
+        cache,
+        keyCoder,
+        GlobalCombineFnRunners.create(combineFn),
+        Caches::weigh,
+        Caches::weigh);
   }
 
   /**
-   * Returns a {@link GroupingTable} that combines inputs into a accumulator 
with sampling {@link
-   * SizeEstimator SizeEstimators}.
+   * Returns a grouping table that combines inputs into an accumulator with 
sampling {@link
+   * SizeEstimator SizeEstimators}. The grouping table uses the cache to defer 
flushing output until
+   * the cache evicts the table.
    */
-  public static <K, InputT, AccumT>
-      GroupingTable<WindowedValue<K>, InputT, AccumT> combiningAndSampling(
-          PipelineOptions options,
-          CombineFn<InputT, AccumT, ?> combineFn,
-          Coder<K> keyCoder,
-          Coder<? super AccumT> accumulatorCoder,
-          double sizeEstimatorSampleRate) {
-    Combiner<WindowedValue<K>, InputT, AccumT, ?> valueCombiner =
-        new ValueCombiner<>(
-            GlobalCombineFnRunners.create(combineFn), 
NullSideInputReader.empty(), options);
+  public static <K, InputT, AccumT> PrecombineGroupingTable<K, InputT, AccumT> 
combiningAndSampling(
+      PipelineOptions options,
+      Cache<Object, Object> cache,
+      CombineFn<InputT, AccumT, ?> combineFn,
+      Coder<K> keyCoder,
+      double sizeEstimatorSampleRate) {
     return new PrecombineGroupingTable<>(
-        getGroupingTableSizeBytes(options),
-        new WindowingCoderGroupingKeyCreator<>(keyCoder),
-        WindowedPairInfo.create(),
-        valueCombiner,
-        new SamplingSizeEstimator<>(
-            new 
CoderSizeEstimator<>(WindowedValue.getValueOnlyCoder(keyCoder)),
-            sizeEstimatorSampleRate,
-            1.0),
-        new SamplingSizeEstimator<>(
-            new CoderSizeEstimator<>(accumulatorCoder), 
sizeEstimatorSampleRate, 1.0));
-  }
-
-  /** Provides client-specific operations for grouping keys. */
-  public interface GroupingKeyCreator<K> {
-    Object createGroupingKey(K key) throws Exception;
+        options,
+        cache,
+        keyCoder,
+        GlobalCombineFnRunners.create(combineFn),
+        new SamplingSizeEstimator<>(Caches::weigh, sizeEstimatorSampleRate, 
1.0),
+        new SamplingSizeEstimator<>(Caches::weigh, sizeEstimatorSampleRate, 
1.0));
   }
 
-  /** Implements Precombine GroupingKeyCreator via Coder. */
-  public static class WindowingCoderGroupingKeyCreator<K>
-      implements GroupingKeyCreator<WindowedValue<K>> {
-
-    private static final Instant ignored = BoundedWindow.TIMESTAMP_MIN_VALUE;
-
-    private final Coder<K> coder;
-
-    WindowingCoderGroupingKeyCreator(Coder<K> coder) {
-      this.coder = coder;
+  @Nullable
+  @Override
+  public PrecombineGroupingTable<K, InputT, AccumT> shrink() {
+    long currentWeight = maxWeight.updateAndGet(operand -> operand >> 1);
+    // It is possible that we are shrunk multiple times until the requested 
max weight is too small.
+    // In this case we want to effectively stop shrinking since we can't 
effectively cache much
+    // at this time and the next insertion will likely evict all records.
+    if (currentWeight <= 100L) {
+      return null;
     }
+    return this;
+  }
 
-    @Override
-    public Object createGroupingKey(WindowedValue<K> key) {
-      // Ignore timestamp for grouping purposes.
-      // The Precombine output will inherit the timestamp of one of its inputs.
-      return WindowedValue.of(
-          coder.structuralValue(key.getValue()), ignored, key.getWindows(), 
key.getPane());
-    }
+  @Override
+  public long getWeight() {
+    return maxWeight.get();
   }
 
   /** Provides client-specific operations for size estimates. */
+  @FunctionalInterface
   public interface SizeEstimator<T> {
-    long estimateSize(T element) throws Exception;
+    long estimateSize(T element);
   }
 
-  /** Implements SizeEstimator via Coder. */
-  public static class CoderSizeEstimator<T> implements SizeEstimator<T> {
-    /** Basic implementation of {@link ElementByteSizeObserver} for use in 
size estimation. */
-    private static class Observer extends ElementByteSizeObserver {
-      private long observedSize = 0;
-
-      @Override
-      protected void reportElementSize(long elementSize) {
-        observedSize += elementSize;
-      }
-    }
-
-    final Coder<T> coder;
+  private final Coder<K> keyCoder;
+  private final GlobalCombineFnRunner<InputT, AccumT, ?> combineFn;
+  private final PipelineOptions options;
+  private final SizeEstimator<K> keySizer;
+  private final SizeEstimator<AccumT> accumulatorSizer;
+  private final Cache<Key, PrecombineGroupingTable<K, InputT, AccumT>> cache;
+  private final LinkedHashMap<WindowedValue<Object>, GroupingTableEntry> 
lruMap;
+  private final AtomicLong maxWeight;
+  private long weight;
 
-    CoderSizeEstimator(Coder<T> coder) {
-      this.coder = coder;
-    }
+  private static final class Key implements Weighted {
+    private static final Key INSTANCE = new Key();
 
     @Override
-    public long estimateSize(T value) throws Exception {
-      // First try using byte size observer
-      CoderSizeEstimator.Observer observer = new CoderSizeEstimator.Observer();
-      coder.registerByteSizeObserver(value, observer);
-
-      if (!observer.getIsLazy()) {
-        observer.advance();
-        return observer.observedSize;
-      } else {
-        // Coder byte size observation is lazy (requires iteration for 
observation) so fall back to
-        // counting output stream
-        CountingOutputStream os = new 
CountingOutputStream(ByteStreams.nullOutputStream());
-        coder.encode(value, os);
-        return os.getCount();
-      }
+    public long getWeight() {
+      // Ignore the actual size of this singleton because it is trivial and 
because
+      // the weight reported here will be counted many times as it is present 
in
+      // many different state subcaches.
+      return 0;
     }
   }
 
-  /**
-   * Provides client-specific operations for working with elements that are 
key/value or key/values
-   * pairs.
-   */
-  public interface PairInfo {
-    Object getKeyFromInputPair(Object pair);
-
-    Object getValueFromInputPair(Object pair);
-
-    Object makeOutputPair(Object key, Object value);
-  }
-
-  /** Implements Precombine PairInfo via KVs. */
-  public static class WindowedPairInfo implements PairInfo {
-    private static WindowedPairInfo theInstance = new WindowedPairInfo();
-
-    public static WindowedPairInfo create() {
-      return theInstance;
-    }
-
-    private WindowedPairInfo() {}
-
-    @Override
-    public Object getKeyFromInputPair(Object pair) {
-      @SuppressWarnings("unchecked")
-      WindowedValue<KV<?, ?>> windowedKv = (WindowedValue<KV<?, ?>>) pair;
-      return windowedKv.withValue(windowedKv.getValue().getKey());
-    }
-
-    @Override
-    public Object getValueFromInputPair(Object pair) {
-      @SuppressWarnings("unchecked")
-      WindowedValue<KV<?, ?>> windowedKv = (WindowedValue<KV<?, ?>>) pair;
-      return windowedKv.getValue().getValue();
-    }
-
-    @Override
-    public Object makeOutputPair(Object key, Object values) {
-      WindowedValue<?> windowedKey = (WindowedValue<?>) key;
-      return windowedKey.withValue(KV.of(windowedKey.getValue(), values));
-    }
-  }
-
-  /** Provides client-specific operations for combining values. */
-  public interface Combiner<K, InputT, AccumT, OutputT> {
-    AccumT createAccumulator(K key);
-
-    AccumT add(K key, AccumT accumulator, InputT value);
-
-    AccumT merge(K key, Iterable<AccumT> accumulators);
-
-    AccumT compact(K key, AccumT accumulator);
-
-    OutputT extract(K key, AccumT accumulator);
+  PrecombineGroupingTable(
+      PipelineOptions options,
+      Cache<?, ?> cache,
+      Coder<K> keyCoder,
+      GlobalCombineFnRunner<InputT, AccumT, ?> combineFn,
+      SizeEstimator<K> keySizer,
+      SizeEstimator<AccumT> accumulatorSizer) {
+    this.options = options;
+    this.cache = (Cache<Key, PrecombineGroupingTable<K, InputT, AccumT>>) 
cache;
+    this.keyCoder = keyCoder;
+    this.combineFn = combineFn;
+    this.keySizer = keySizer;
+    this.accumulatorSizer = accumulatorSizer;
+    this.lruMap = new LinkedHashMap<>(16, 0.75f, true);
+    this.maxWeight = new AtomicLong();
+    this.weight = 0L;
+    this.cache.put(Key.INSTANCE, this);
   }
 
-  /** Implements Precombine Combiner via Combine.KeyedCombineFn. */
-  public static class ValueCombiner<K, InputT, AccumT, OutputT>
-      implements Combiner<WindowedValue<K>, InputT, AccumT, OutputT> {
-    private final GlobalCombineFnRunner<InputT, AccumT, OutputT> combineFn;
-    private final SideInputReader sideInputReader;
-    private final PipelineOptions options;
-
-    private ValueCombiner(
-        GlobalCombineFnRunner<InputT, AccumT, OutputT> combineFn,
-        SideInputReader sideInputReader,
-        PipelineOptions options) {
-      this.combineFn = combineFn;
-      this.sideInputReader = sideInputReader;
-      this.options = options;
-    }
-
-    @Override
-    public AccumT createAccumulator(WindowedValue<K> windowedKey) {
-      return this.combineFn.createAccumulator(options, sideInputReader, 
windowedKey.getWindows());
+  private class GroupingTableEntry implements Weighted {
+    private final WindowedValue<Object> groupingKey;
+    private final K userKey;
+    private final long keySize;
+    private long accumulatorSize;
+    private AccumT accumulator;
+    private boolean dirty;
+
+    private GroupingTableEntry(
+        WindowedValue<Object> groupingKey, K userKey, InputT 
initialInputValue) {
+      this.groupingKey = groupingKey;
+      this.userKey = userKey;
+      if (groupingKey.getValue() == userKey) {
+        // This object is only storing references to the same objects that are 
being stored
+        // by the cache so the accounting of the size of the key is occurring 
already.
+        this.keySize = Caches.REFERENCE_SIZE * 2;
+      } else {
+        this.keySize = Caches.REFERENCE_SIZE + keySizer.estimateSize(userKey);
+      }
+      this.accumulator =
+          combineFn.createAccumulator(
+              options, NullSideInputReader.empty(), groupingKey.getWindows());
+      add(initialInputValue);
+      this.accumulatorSize = accumulatorSizer.estimateSize(accumulator);
     }
 
-    @Override
-    public AccumT add(WindowedValue<K> windowedKey, AccumT accumulator, InputT 
value) {
-      return this.combineFn.addInput(
-          accumulator, value, options, sideInputReader, 
windowedKey.getWindows());
+    public WindowedValue<Object> getGroupingKey() {
+      return groupingKey;
     }
 
-    @Override
-    public AccumT merge(WindowedValue<K> windowedKey, Iterable<AccumT> 
accumulators) {
-      return this.combineFn.mergeAccumulators(
-          accumulators, options, sideInputReader, windowedKey.getWindows());
+    public K getKey() {
+      return userKey;
     }
 
-    @Override
-    public AccumT compact(WindowedValue<K> windowedKey, AccumT accumulator) {
-      return this.combineFn.compact(
-          accumulator, options, sideInputReader, windowedKey.getWindows());
+    public AccumT getValue() {
+      return accumulator;
     }
 
     @Override
-    public OutputT extract(WindowedValue<K> windowedKey, AccumT accumulator) {
-      return this.combineFn.extractOutput(
-          accumulator, options, sideInputReader, windowedKey.getWindows());
+    public long getWeight() {
+      return keySize + accumulatorSize;
     }
-  }
-
-  // How many bytes a word in the JVM has.
-  private static final int BYTES_PER_JVM_WORD = getBytesPerJvmWord();
-  /**
-   * The number of bytes of overhead to store an entry in the grouping table 
(a {@code
-   * HashMap<StructuralByteArray, KeyAndValues>}), ignoring the actual number 
of bytes in the keys
-   * and values:
-   *
-   * <ul>
-   *   <li>an array element (1 word),
-   *   <li>a HashMap.Entry (4 words),
-   *   <li>a StructuralByteArray (1 words),
-   *   <li>a backing array (guessed at 1 word for the length),
-   *   <li>a KeyAndValues (2 words),
-   *   <li>an ArrayList (2 words),
-   *   <li>a backing array (1 word),
-   *   <li>per-object overhead (JVM-specific, guessed at 2 words * 6 objects).
-   * </ul>
-   */
-  private static final int PER_KEY_OVERHEAD = 24 * BYTES_PER_JVM_WORD;
-
-  /** A {@link GroupingTable} that uses the given combiner to combine values 
in place. */
-  // Keep the table relatively full to increase the chance of collisions.
-  private static final double TARGET_LOAD = 0.9;
-
-  private long maxSize;
-  private final GroupingKeyCreator<? super K> groupingKeyCreator;
-  private final PairInfo pairInfo;
-  private final Combiner<? super K, InputT, AccumT, ?> combiner;
-  private final SizeEstimator<? super K> keySizer;
-  private final SizeEstimator<? super AccumT> accumulatorSizer;
-
-  private long size = 0;
-  private Map<Object, GroupingTableEntry<K, InputT, AccumT>> table;
-
-  PrecombineGroupingTable(
-      long maxSize,
-      GroupingKeyCreator<? super K> groupingKeyCreator,
-      PairInfo pairInfo,
-      Combiner<? super K, InputT, AccumT, ?> combineFn,
-      SizeEstimator<? super K> keySizer,
-      SizeEstimator<? super AccumT> accumulatorSizer) {
-    this.maxSize = maxSize;
-    this.groupingKeyCreator = groupingKeyCreator;
-    this.pairInfo = pairInfo;
-    this.combiner = combineFn;
-    this.keySizer = keySizer;
-    this.accumulatorSizer = accumulatorSizer;
-    this.table = new HashMap<>();
-  }
-
-  interface GroupingTableEntry<K, InputT, AccumT> {
-    K getKey();
-
-    AccumT getValue();
-
-    void add(InputT value) throws Exception;
-
-    long getSize();
-
-    void compact() throws Exception;
-  }
-
-  private GroupingTableEntry<K, InputT, AccumT> createTableEntry(final K key) 
throws Exception {
-    return new GroupingTableEntry<K, InputT, AccumT>() {
-      final long keySize = keySizer.estimateSize(key);
-      AccumT accumulator = combiner.createAccumulator(key);
-      long accumulatorSize = 0; // never used before a value is added...
-
-      @Override
-      public K getKey() {
-        return key;
-      }
-
-      @Override
-      public AccumT getValue() {
-        return accumulator;
-      }
-
-      @Override
-      public long getSize() {
-        return keySize + accumulatorSize;
-      }
-
-      @Override
-      public void compact() throws Exception {
-        AccumT newAccumulator = combiner.compact(key, accumulator);
-        if (newAccumulator != accumulator) {
-          accumulator = newAccumulator;
-          accumulatorSize = accumulatorSizer.estimateSize(newAccumulator);
-        }
-      }
 
-      @Override
-      public void add(InputT value) throws Exception {
-        accumulator = combiner.add(key, accumulator, value);
+    public void compact() {
+      if (dirty) {
+        accumulator =
+            combineFn.compact(
+                accumulator, options, NullSideInputReader.empty(), 
groupingKey.getWindows());
         accumulatorSize = accumulatorSizer.estimateSize(accumulator);
+        dirty = false;
       }
-    };
-  }
+    }
 
-  /** Adds a pair to this table, possibly flushing some entries to output if 
the table is full. */
-  @SuppressWarnings("unchecked")
-  @Override
-  public void put(Object pair, Receiver receiver) throws Exception {
-    put(
-        (K) pairInfo.getKeyFromInputPair(pair),
-        (InputT) pairInfo.getValueFromInputPair(pair),
-        receiver);
+    public void add(InputT value) {
+      dirty = true;
+      accumulator =
+          combineFn.addInput(
+              accumulator, value, options, NullSideInputReader.empty(), 
groupingKey.getWindows());
+      accumulatorSize = accumulatorSizer.estimateSize(accumulator);
+    }
   }
 
   /**
    * Adds the key and value to this table, possibly flushing some entries to 
output if the table is
    * full.
    */
-  public void put(K key, InputT value, Receiver receiver) throws Exception {
-    Object groupingKey = groupingKeyCreator.createGroupingKey(key);
-    GroupingTableEntry<K, InputT, AccumT> entry = table.get(groupingKey);
-    if (entry == null) {
-      entry = createTableEntry(key);
-      table.put(groupingKey, entry);
-      size += PER_KEY_OVERHEAD;
-    } else {
-      size -= entry.getSize();
-    }
-    entry.add(value);
-    size += entry.getSize();
-
-    if (size >= maxSize) {
-      long targetSize = (long) (TARGET_LOAD * maxSize);
-      Iterator<GroupingTableEntry<K, InputT, AccumT>> entries = 
table.values().iterator();
-      while (size >= targetSize) {
-        if (!entries.hasNext()) {
-          // Should never happen, but sizes may be estimates...
-          size = 0;
-          break;
+  @VisibleForTesting
+  public void put(
+      WindowedValue<KV<K, InputT>> value, FnDataReceiver<WindowedValue<KV<K, 
AccumT>>> receiver)
+      throws Exception {
+    // Ignore timestamp for grouping purposes.
+    // The Pre-combine output will inherit the timestamp of one of its inputs.
+    WindowedValue<Object> groupingKey =
+        WindowedValue.of(
+            keyCoder.structuralValue(value.getValue().getKey()),
+            IGNORED,
+            value.getWindows(),
+            value.getPane());
+
+    GroupingTableEntry entry =
+        lruMap.compute(
+            groupingKey,
+            (key, tableEntry) -> {
+              if (tableEntry == null) {
+                tableEntry =
+                    new GroupingTableEntry(
+                        groupingKey, value.getValue().getKey(), 
value.getValue().getValue());
+              } else {
+                tableEntry.add(value.getValue().getValue());
+              }
+              return tableEntry;
+            });
+    weight += entry.getWeight();
+    // Increase the maximum only if we require it
+    maxWeight.accumulateAndGet(weight, (current, update) -> current < update ? 
update : current);
+
+    // Update the cache to ensure that LRU is handled appropriately and for 
the cache to have an
+    // opportunity to shrink the maxWeight if necessary.
+    cache.put(Key.INSTANCE, this);
+
+    // Get the updated weight now that the cache may have been shrunk and 
respect it
+    long currentMax = maxWeight.get();
+    if (weight > currentMax) {

Review Comment:
   If this is triggered by shrink(), why not do it in shrink() itself instead of 
relying on new input?





Issue Time Tracking
-------------------

    Worklog Id:     (was: 770023)
    Time Spent: 79h 50m  (was: 79h 40m)

> Optimize Java SDK harness
> -------------------------
>
>                 Key: BEAM-13015
>                 URL: https://issues.apache.org/jira/browse/BEAM-13015
>             Project: Beam
>          Issue Type: Improvement
>          Components: sdk-java-harness
>            Reporter: Luke Cwik
>            Assignee: Luke Cwik
>            Priority: P2
>          Time Spent: 79h 50m
>  Remaining Estimate: 0h
>
> Use profiling tools to remove bundle processing overhead in the SDK harness.



--
This message was sent by Atlassian Jira
(v8.20.7#820007)

Reply via email to