This is an automated email from the ASF dual-hosted git repository.

scwhittle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 84684aef490 Reduce weighing overhead for caching blocks (#36897)
84684aef490 is described below

commit 84684aef490e03163295050e3b6e51aa6c6a2ee6
Author: Sam Whittle <[email protected]>
AuthorDate: Fri Dec 5 12:46:43 2025 +0100

    Reduce weighing overhead for caching blocks (#36897)
---
 .../org/apache/beam/sdk/fn/data/WeightedList.java  | 11 +--
 .../fn/harness/state/StateFetchingIterators.java   | 83 ++++++++++++++--------
 2 files changed, 55 insertions(+), 39 deletions(-)

diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/WeightedList.java 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/WeightedList.java
index ad5e131cb2d..5eb317fc287 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/WeightedList.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/WeightedList.java
@@ -20,6 +20,7 @@ package org.apache.beam.sdk.fn.data;
 import java.util.List;
 import java.util.concurrent.atomic.AtomicLong;
 import org.apache.beam.sdk.util.Weighted;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.LongMath;
 
 /** Facade for a {@link List<T>} that keeps track of weight, for cache limit 
reasons. */
 public class WeightedList<T> implements Weighted {
@@ -71,14 +72,6 @@ public class WeightedList<T> implements Weighted {
   }
 
   public void accumulateWeight(long weight) {
-    this.weight.accumulateAndGet(
-        weight,
-        (first, second) -> {
-          try {
-            return Math.addExact(first, second);
-          } catch (ArithmeticException e) {
-            return Long.MAX_VALUE;
-          }
-        });
+    this.weight.accumulateAndGet(weight, LongMath::saturatedAdd);
   }
 }
diff --git 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
index 1e06c98f2e3..339ddad4061 100644
--- 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
+++ 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
@@ -49,6 +49,8 @@ import 
org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Throwables;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.AbstractIterator;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.LongMath;
 
 /**
  * Adapters which convert a logical series of chunks using continuation tokens 
over the Beam Fn
@@ -249,15 +251,11 @@ public class StateFetchingIterators {
 
       @Override
       public long getWeight() {
-        try {
-          long sum = 8 + blocks.size() * 8L;
-          for (Block<T> block : blocks) {
-            sum = Math.addExact(sum, block.getWeight());
-          }
-          return sum;
-        } catch (ArithmeticException e) {
-          return Long.MAX_VALUE;
+        long sum = 8 + blocks.size() * 8L;
+        for (Block<T> block : blocks) {
+          sum = LongMath.saturatedAdd(sum, block.getWeight());
         }
+        return sum;
       }
 
       BlocksPrefix(List<Block<T>> blocks) {
@@ -282,8 +280,7 @@ public class StateFetchingIterators {
 
     @AutoValue
     abstract static class Block<T> implements Weighted {
-      private static final Block<Void> EMPTY =
-          fromValues(WeightedList.of(Collections.emptyList(), 0), null);
+      private static final Block<Void> EMPTY = fromValues(ImmutableList.of(), 
0, null);
 
       @SuppressWarnings("unchecked") // Based upon as Collections.emptyList()
       public static <T> Block<T> emptyBlock() {
@@ -299,21 +296,37 @@ public class StateFetchingIterators {
       }
 
       public static <T> Block<T> fromValues(List<T> values, @Nullable 
ByteString nextToken) {
-        return fromValues(WeightedList.of(values, Caches.weigh(values)), 
nextToken);
+        if (values.isEmpty() && nextToken == null) {
+          return emptyBlock();
+        }
+        ImmutableList<T> immutableValues = ImmutableList.copyOf(values);
+        long listWeight = immutableValues.size() * Caches.REFERENCE_SIZE;
+        for (T value : immutableValues) {
+          listWeight = LongMath.saturatedAdd(listWeight, Caches.weigh(value));
+        }
+        return fromValues(immutableValues, listWeight, nextToken);
       }
 
       public static <T> Block<T> fromValues(
           WeightedList<T> values, @Nullable ByteString nextToken) {
-        long weight = values.getWeight() + 24;
+        if (values.isEmpty() && nextToken == null) {
+          return emptyBlock();
+        }
+        return fromValues(ImmutableList.copyOf(values.getBacking()), 
values.getWeight(), nextToken);
+      }
+
+      private static <T> Block<T> fromValues(
+          ImmutableList<T> values, long listWeight, @Nullable ByteString 
nextToken) {
+        long weight = LongMath.saturatedAdd(listWeight, 24);
         if (nextToken != null) {
           if (nextToken.isEmpty()) {
             nextToken = ByteString.EMPTY;
           } else {
-            weight += Caches.weigh(nextToken);
+            weight = LongMath.saturatedAdd(weight, Caches.weigh(nextToken));
           }
         }
         return new 
AutoValue_StateFetchingIterators_CachingStateIterable_Block<>(
-            values.getBacking(), nextToken, weight);
+            values, nextToken, weight);
       }
 
       abstract List<T> getValues();
@@ -372,10 +385,12 @@ public class StateFetchingIterators {
         totalSize += tBlock.getValues().size();
       }
 
-      WeightedList<T> allValues = WeightedList.of(new ArrayList<>(totalSize), 
0L);
+      ImmutableList.Builder<T> allValues = 
ImmutableList.builderWithExpectedSize(totalSize);
+      long weight = 0;
+      List<T> blockValuesToKeep = new ArrayList<>();
       for (Block<T> block : blocks) {
+        blockValuesToKeep.clear();
         boolean valueRemovedFromBlock = false;
-        List<T> blockValuesToKeep = new ArrayList<>();
         for (T value : block.getValues()) {
           if 
(!toRemoveStructuralValues.contains(valueCoder.structuralValue(value))) {
             blockValuesToKeep.add(value);
@@ -387,13 +402,19 @@ public class StateFetchingIterators {
         // If any value was removed from this block, need to estimate the 
weight again.
         // Otherwise, just reuse the block's weight.
         if (valueRemovedFromBlock) {
-          allValues.addAll(blockValuesToKeep, Caches.weigh(block.getValues()));
+          allValues.addAll(blockValuesToKeep);
+          for (T value : blockValuesToKeep) {
+            weight = LongMath.saturatedAdd(weight, Caches.weigh(value));
+          }
         } else {
-          allValues.addAll(block.getValues(), block.getWeight());
+          allValues.addAll(block.getValues());
+          weight = LongMath.saturatedAdd(weight, block.getWeight());
         }
       }
 
-      cache.put(IterableCacheKey.INSTANCE, new 
MutatedBlocks<>(Block.mutatedBlock(allValues)));
+      cache.put(
+          IterableCacheKey.INSTANCE,
+          new MutatedBlocks<>(Block.fromValues(allValues.build(), weight, 
null)));
     }
 
     /**
@@ -484,21 +505,24 @@ public class StateFetchingIterators {
       for (Block<T> block : blocks) {
         totalSize += block.getValues().size();
       }
-      WeightedList<T> allValues = WeightedList.of(new ArrayList<>(totalSize), 
0L);
+      ImmutableList.Builder<T> allValues = 
ImmutableList.builderWithExpectedSize(totalSize);
+      long weight = 0;
       for (Block<T> block : blocks) {
-        allValues.addAll(block.getValues(), block.getWeight());
+        allValues.addAll(block.getValues());
+        weight = LongMath.saturatedAdd(weight, block.getWeight());
       }
       if (newWeight < 0) {
-        if (newValues.size() == 1) {
-          // Optimize weighing of the common value state as single 
single-element bag state.
-          newWeight = Caches.weigh(newValues.get(0));
-        } else {
-          newWeight = Caches.weigh(newValues);
+        newWeight = 0;
+        for (T value : newValues) {
+          newWeight = LongMath.saturatedAdd(newWeight, Caches.weigh(value));
         }
       }
-      allValues.addAll(newValues, newWeight);
+      allValues.addAll(newValues);
+      weight = LongMath.saturatedAdd(weight, newWeight);
 
-      cache.put(IterableCacheKey.INSTANCE, new 
MutatedBlocks<>(Block.mutatedBlock(allValues)));
+      cache.put(
+          IterableCacheKey.INSTANCE,
+          new MutatedBlocks<>(Block.fromValues(allValues.build(), weight, 
null)));
     }
 
     class CachingStateIterator implements PrefetchableIterator<T> {
@@ -580,8 +604,7 @@ public class StateFetchingIterators {
             return false;
           }
           // Release the block while we are loading the next one.
-          currentBlock =
-              Block.fromValues(WeightedList.of(Collections.emptyList(), 0L), 
ByteString.EMPTY);
+          currentBlock = Block.emptyBlock();
 
           @Nullable Blocks<T> existing = cache.peek(IterableCacheKey.INSTANCE);
           boolean isFirstBlock = ByteString.EMPTY.equals(nextToken);

Reply via email to