This is an automated email from the ASF dual-hosted git repository.

cvandermerwe pushed a commit to branch revert-36897-optimize_weigh
in repository https://gitbox.apache.org/repos/asf/beam.git

commit 63177cb5db6623bbfe9ba7ede1041f58dd4ddb93
Author: claudevdm <[email protected]>
AuthorDate: Fri Dec 5 15:27:11 2025 -0500

    Revert "Reduce weighing overhead for caching blocks (#36897)"
    
    This reverts commit 84684aef490e03163295050e3b6e51aa6c6a2ee6.
---
 .../org/apache/beam/sdk/fn/data/WeightedList.java  | 11 ++-
 .../fn/harness/state/StateFetchingIterators.java   | 83 ++++++++--------------
 2 files changed, 39 insertions(+), 55 deletions(-)

diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/WeightedList.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/WeightedList.java
index 5eb317fc287..ad5e131cb2d 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/WeightedList.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/WeightedList.java
@@ -20,7 +20,6 @@ package org.apache.beam.sdk.fn.data;
 import java.util.List;
 import java.util.concurrent.atomic.AtomicLong;
 import org.apache.beam.sdk.util.Weighted;
-import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.LongMath;
 
 /** Facade for a {@link List<T>} that keeps track of weight, for cache limit reasons. */
 public class WeightedList<T> implements Weighted {
@@ -72,6 +71,14 @@ public class WeightedList<T> implements Weighted {
   }
 
   public void accumulateWeight(long weight) {
-    this.weight.accumulateAndGet(weight, LongMath::saturatedAdd);
+    this.weight.accumulateAndGet(
+        weight,
+        (first, second) -> {
+          try {
+            return Math.addExact(first, second);
+          } catch (ArithmeticException e) {
+            return Long.MAX_VALUE;
+          }
+        });
   }
 }
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
index 339ddad4061..1e06c98f2e3 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
@@ -49,8 +49,6 @@ import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Throwables;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.AbstractIterator;
-import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
-import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.LongMath;
 
 /**
  * Adapters which convert a logical series of chunks using continuation tokens over the Beam Fn
@@ -251,11 +249,15 @@ public class StateFetchingIterators {
 
       @Override
       public long getWeight() {
-        long sum = 8 + blocks.size() * 8L;
-        for (Block<T> block : blocks) {
-          sum = LongMath.saturatedAdd(sum, block.getWeight());
+        try {
+          long sum = 8 + blocks.size() * 8L;
+          for (Block<T> block : blocks) {
+            sum = Math.addExact(sum, block.getWeight());
+          }
+          return sum;
+        } catch (ArithmeticException e) {
+          return Long.MAX_VALUE;
         }
-        return sum;
       }
 
       BlocksPrefix(List<Block<T>> blocks) {
@@ -280,7 +282,8 @@ public class StateFetchingIterators {
 
     @AutoValue
     abstract static class Block<T> implements Weighted {
-      private static final Block<Void> EMPTY = fromValues(ImmutableList.of(), 0, null);
+      private static final Block<Void> EMPTY =
+          fromValues(WeightedList.of(Collections.emptyList(), 0), null);
 
       @SuppressWarnings("unchecked") // Based upon as Collections.emptyList()
       public static <T> Block<T> emptyBlock() {
@@ -296,37 +299,21 @@ public class StateFetchingIterators {
       }
 
       public static <T> Block<T> fromValues(List<T> values, @Nullable ByteString nextToken) {
-        if (values.isEmpty() && nextToken == null) {
-          return emptyBlock();
-        }
-        ImmutableList<T> immutableValues = ImmutableList.copyOf(values);
-        long listWeight = immutableValues.size() * Caches.REFERENCE_SIZE;
-        for (T value : immutableValues) {
-          listWeight = LongMath.saturatedAdd(listWeight, Caches.weigh(value));
-        }
-        return fromValues(immutableValues, listWeight, nextToken);
+        return fromValues(WeightedList.of(values, Caches.weigh(values)), nextToken);
       }
 
       public static <T> Block<T> fromValues(
           WeightedList<T> values, @Nullable ByteString nextToken) {
-        if (values.isEmpty() && nextToken == null) {
-          return emptyBlock();
-        }
-        return fromValues(ImmutableList.copyOf(values.getBacking()), values.getWeight(), nextToken);
-      }
-
-      private static <T> Block<T> fromValues(
-          ImmutableList<T> values, long listWeight, @Nullable ByteString nextToken) {
-        long weight = LongMath.saturatedAdd(listWeight, 24);
+        long weight = values.getWeight() + 24;
         if (nextToken != null) {
           if (nextToken.isEmpty()) {
             nextToken = ByteString.EMPTY;
           } else {
-            weight = LongMath.saturatedAdd(weight, Caches.weigh(nextToken));
+            weight += Caches.weigh(nextToken);
           }
         }
         return new AutoValue_StateFetchingIterators_CachingStateIterable_Block<>(
-            values, nextToken, weight);
+            values.getBacking(), nextToken, weight);
       }
 
       abstract List<T> getValues();
@@ -385,12 +372,10 @@ public class StateFetchingIterators {
         totalSize += tBlock.getValues().size();
       }
 
-      ImmutableList.Builder<T> allValues = ImmutableList.builderWithExpectedSize(totalSize);
-      long weight = 0;
-      List<T> blockValuesToKeep = new ArrayList<>();
+      WeightedList<T> allValues = WeightedList.of(new ArrayList<>(totalSize), 0L);
       for (Block<T> block : blocks) {
-        blockValuesToKeep.clear();
         boolean valueRemovedFromBlock = false;
+        List<T> blockValuesToKeep = new ArrayList<>();
         for (T value : block.getValues()) {
           if (!toRemoveStructuralValues.contains(valueCoder.structuralValue(value))) {
             blockValuesToKeep.add(value);
@@ -402,19 +387,13 @@ public class StateFetchingIterators {
         // If any value was removed from this block, need to estimate the weight again.
         // Otherwise, just reuse the block's weight.
         if (valueRemovedFromBlock) {
-          allValues.addAll(blockValuesToKeep);
-          for (T value : blockValuesToKeep) {
-            weight = LongMath.saturatedAdd(weight, Caches.weigh(value));
-          }
+          allValues.addAll(blockValuesToKeep, Caches.weigh(block.getValues()));
         } else {
-          allValues.addAll(block.getValues());
-          weight = LongMath.saturatedAdd(weight, block.getWeight());
+          allValues.addAll(block.getValues(), block.getWeight());
         }
       }
 
-      cache.put(
-          IterableCacheKey.INSTANCE,
-          new MutatedBlocks<>(Block.fromValues(allValues.build(), weight, null)));
+      cache.put(IterableCacheKey.INSTANCE, new MutatedBlocks<>(Block.mutatedBlock(allValues)));
     }
 
     /**
@@ -505,24 +484,21 @@ public class StateFetchingIterators {
       for (Block<T> block : blocks) {
         totalSize += block.getValues().size();
       }
-      ImmutableList.Builder<T> allValues = ImmutableList.builderWithExpectedSize(totalSize);
-      long weight = 0;
+      WeightedList<T> allValues = WeightedList.of(new ArrayList<>(totalSize), 0L);
       for (Block<T> block : blocks) {
-        allValues.addAll(block.getValues());
-        weight = LongMath.saturatedAdd(weight, block.getWeight());
+        allValues.addAll(block.getValues(), block.getWeight());
       }
       if (newWeight < 0) {
-        newWeight = 0;
-        for (T value : newValues) {
-          newWeight = LongMath.saturatedAdd(newWeight, Caches.weigh(value));
+        if (newValues.size() == 1) {
+          // Optimize weighing of the common value state as single single-element bag state.
+          newWeight = Caches.weigh(newValues.get(0));
+        } else {
+          newWeight = Caches.weigh(newValues);
         }
       }
-      allValues.addAll(newValues);
-      weight = LongMath.saturatedAdd(weight, newWeight);
+      allValues.addAll(newValues, newWeight);
 
-      cache.put(
-          IterableCacheKey.INSTANCE,
-          new MutatedBlocks<>(Block.fromValues(allValues.build(), weight, null)));
+      cache.put(IterableCacheKey.INSTANCE, new MutatedBlocks<>(Block.mutatedBlock(allValues)));
     }
 
     class CachingStateIterator implements PrefetchableIterator<T> {
@@ -604,7 +580,8 @@ public class StateFetchingIterators {
             return false;
           }
           // Release the block while we are loading the next one.
-          currentBlock = Block.emptyBlock();
+          currentBlock =
+              Block.fromValues(WeightedList.of(Collections.emptyList(), 0L), ByteString.EMPTY);
 
           @Nullable Blocks<T> existing = cache.peek(IterableCacheKey.INSTANCE);
           boolean isFirstBlock = ByteString.EMPTY.equals(nextToken);

Reply via email to