This is an automated email from the ASF dual-hosted git repository.
scwhittle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 84684aef490 Reduce weighing overhead for caching blocks (#36897)
84684aef490 is described below
commit 84684aef490e03163295050e3b6e51aa6c6a2ee6
Author: Sam Whittle <[email protected]>
AuthorDate: Fri Dec 5 12:46:43 2025 +0100
Reduce weighing overhead for caching blocks (#36897)
---
.../org/apache/beam/sdk/fn/data/WeightedList.java | 11 +--
.../fn/harness/state/StateFetchingIterators.java | 83 ++++++++++++++--------
2 files changed, 55 insertions(+), 39 deletions(-)
diff --git
a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/WeightedList.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/WeightedList.java
index ad5e131cb2d..5eb317fc287 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/WeightedList.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/WeightedList.java
@@ -20,6 +20,7 @@ package org.apache.beam.sdk.fn.data;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.beam.sdk.util.Weighted;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.LongMath;
/** Facade for a {@link List<T>} that keeps track of weight, for cache limit
reasons. */
public class WeightedList<T> implements Weighted {
@@ -71,14 +72,6 @@ public class WeightedList<T> implements Weighted {
}
public void accumulateWeight(long weight) {
- this.weight.accumulateAndGet(
- weight,
- (first, second) -> {
- try {
- return Math.addExact(first, second);
- } catch (ArithmeticException e) {
- return Long.MAX_VALUE;
- }
- });
+ this.weight.accumulateAndGet(weight, LongMath::saturatedAdd);
}
}
diff --git
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
index 1e06c98f2e3..339ddad4061 100644
---
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
+++
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
@@ -49,6 +49,8 @@ import
org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Throwables;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.AbstractIterator;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.LongMath;
/**
* Adapters which convert a logical series of chunks using continuation tokens
over the Beam Fn
@@ -249,15 +251,11 @@ public class StateFetchingIterators {
@Override
public long getWeight() {
- try {
- long sum = 8 + blocks.size() * 8L;
- for (Block<T> block : blocks) {
- sum = Math.addExact(sum, block.getWeight());
- }
- return sum;
- } catch (ArithmeticException e) {
- return Long.MAX_VALUE;
+ long sum = 8 + blocks.size() * 8L;
+ for (Block<T> block : blocks) {
+ sum = LongMath.saturatedAdd(sum, block.getWeight());
}
+ return sum;
}
BlocksPrefix(List<Block<T>> blocks) {
@@ -282,8 +280,7 @@ public class StateFetchingIterators {
@AutoValue
abstract static class Block<T> implements Weighted {
- private static final Block<Void> EMPTY =
- fromValues(WeightedList.of(Collections.emptyList(), 0), null);
+ private static final Block<Void> EMPTY = fromValues(ImmutableList.of(),
0, null);
@SuppressWarnings("unchecked") // Based upon as Collections.emptyList()
public static <T> Block<T> emptyBlock() {
@@ -299,21 +296,37 @@ public class StateFetchingIterators {
}
public static <T> Block<T> fromValues(List<T> values, @Nullable
ByteString nextToken) {
- return fromValues(WeightedList.of(values, Caches.weigh(values)),
nextToken);
+ if (values.isEmpty() && nextToken == null) {
+ return emptyBlock();
+ }
+ ImmutableList<T> immutableValues = ImmutableList.copyOf(values);
+ long listWeight = immutableValues.size() * Caches.REFERENCE_SIZE;
+ for (T value : immutableValues) {
+ listWeight = LongMath.saturatedAdd(listWeight, Caches.weigh(value));
+ }
+ return fromValues(immutableValues, listWeight, nextToken);
}
public static <T> Block<T> fromValues(
WeightedList<T> values, @Nullable ByteString nextToken) {
- long weight = values.getWeight() + 24;
+ if (values.isEmpty() && nextToken == null) {
+ return emptyBlock();
+ }
+ return fromValues(ImmutableList.copyOf(values.getBacking()),
values.getWeight(), nextToken);
+ }
+
+ private static <T> Block<T> fromValues(
+ ImmutableList<T> values, long listWeight, @Nullable ByteString
nextToken) {
+ long weight = LongMath.saturatedAdd(listWeight, 24);
if (nextToken != null) {
if (nextToken.isEmpty()) {
nextToken = ByteString.EMPTY;
} else {
- weight += Caches.weigh(nextToken);
+ weight = LongMath.saturatedAdd(weight, Caches.weigh(nextToken));
}
}
return new
AutoValue_StateFetchingIterators_CachingStateIterable_Block<>(
- values.getBacking(), nextToken, weight);
+ values, nextToken, weight);
}
abstract List<T> getValues();
@@ -372,10 +385,12 @@ public class StateFetchingIterators {
totalSize += tBlock.getValues().size();
}
- WeightedList<T> allValues = WeightedList.of(new ArrayList<>(totalSize),
0L);
+ ImmutableList.Builder<T> allValues =
ImmutableList.builderWithExpectedSize(totalSize);
+ long weight = 0;
+ List<T> blockValuesToKeep = new ArrayList<>();
for (Block<T> block : blocks) {
+ blockValuesToKeep.clear();
boolean valueRemovedFromBlock = false;
- List<T> blockValuesToKeep = new ArrayList<>();
for (T value : block.getValues()) {
if
(!toRemoveStructuralValues.contains(valueCoder.structuralValue(value))) {
blockValuesToKeep.add(value);
@@ -387,13 +402,19 @@ public class StateFetchingIterators {
// If any value was removed from this block, need to estimate the
weight again.
// Otherwise, just reuse the block's weight.
if (valueRemovedFromBlock) {
- allValues.addAll(blockValuesToKeep, Caches.weigh(block.getValues()));
+ allValues.addAll(blockValuesToKeep);
+ for (T value : blockValuesToKeep) {
+ weight = LongMath.saturatedAdd(weight, Caches.weigh(value));
+ }
} else {
- allValues.addAll(block.getValues(), block.getWeight());
+ allValues.addAll(block.getValues());
+ weight = LongMath.saturatedAdd(weight, block.getWeight());
}
}
- cache.put(IterableCacheKey.INSTANCE, new
MutatedBlocks<>(Block.mutatedBlock(allValues)));
+ cache.put(
+ IterableCacheKey.INSTANCE,
+ new MutatedBlocks<>(Block.fromValues(allValues.build(), weight,
null)));
}
/**
@@ -484,21 +505,24 @@ public class StateFetchingIterators {
for (Block<T> block : blocks) {
totalSize += block.getValues().size();
}
- WeightedList<T> allValues = WeightedList.of(new ArrayList<>(totalSize),
0L);
+ ImmutableList.Builder<T> allValues =
ImmutableList.builderWithExpectedSize(totalSize);
+ long weight = 0;
for (Block<T> block : blocks) {
- allValues.addAll(block.getValues(), block.getWeight());
+ allValues.addAll(block.getValues());
+ weight = LongMath.saturatedAdd(weight, block.getWeight());
}
if (newWeight < 0) {
- if (newValues.size() == 1) {
- // Optimize weighing of the common value state as single
single-element bag state.
- newWeight = Caches.weigh(newValues.get(0));
- } else {
- newWeight = Caches.weigh(newValues);
+ newWeight = 0;
+ for (T value : newValues) {
+ newWeight = LongMath.saturatedAdd(newWeight, Caches.weigh(value));
}
}
- allValues.addAll(newValues, newWeight);
+ allValues.addAll(newValues);
+ weight = LongMath.saturatedAdd(weight, newWeight);
- cache.put(IterableCacheKey.INSTANCE, new
MutatedBlocks<>(Block.mutatedBlock(allValues)));
+ cache.put(
+ IterableCacheKey.INSTANCE,
+ new MutatedBlocks<>(Block.fromValues(allValues.build(), weight,
null)));
}
class CachingStateIterator implements PrefetchableIterator<T> {
@@ -580,8 +604,7 @@ public class StateFetchingIterators {
return false;
}
// Release the block while we are loading the next one.
- currentBlock =
- Block.fromValues(WeightedList.of(Collections.emptyList(), 0L),
ByteString.EMPTY);
+ currentBlock = Block.emptyBlock();
@Nullable Blocks<T> existing = cache.peek(IterableCacheKey.INSTANCE);
boolean isFirstBlock = ByteString.EMPTY.equals(nextToken);