This is an automated email from the ASF dual-hosted git repository.
scwhittle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 6ce104769c4 Change StateBackedIterable to implement ElementByteSizeObservableIterable, avoiding iteration to estimate observed bytes (#29517)
6ce104769c4 is described below
commit 6ce104769c45b56a01760f8e6574e2290cd7c4e8
Author: Sam Whittle <[email protected]>
AuthorDate: Thu Nov 23 12:13:05 2023 +0100
Change StateBackedIterable to implement ElementByteSizeObservableIterable,
avoiding iteration to estimate observed bytes. (#29517)
* Change StateBackedIterable to implement ElementByteSizeObservableIterable,
reducing byte estimation costs.
---
.../beam/fn/harness/state/StateBackedIterable.java | 87 +++++++++++++++++++++-
.../fn/harness/state/StateBackedIterableTest.java | 58 +++++++++++++++
2 files changed, 142 insertions(+), 3 deletions(-)
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateBackedIterable.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateBackedIterable.java
index 9c95e9ad90e..22e0822b619 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateBackedIterable.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateBackedIterable.java
@@ -43,12 +43,17 @@ import org.apache.beam.sdk.fn.stream.PrefetchableIterable;
import org.apache.beam.sdk.fn.stream.PrefetchableIterators;
import org.apache.beam.sdk.util.BufferedElementCountingOutputStream;
import org.apache.beam.sdk.util.VarInt;
+import org.apache.beam.sdk.util.common.ElementByteSizeObservableIterable;
+import org.apache.beam.sdk.util.common.ElementByteSizeObservableIterator;
+import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
/**
* A {@link BeamFnStateClient state} backed iterable which allows for fetching
elements over the
@@ -62,12 +67,17 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams
@SuppressWarnings({
"rawtypes" // TODO(https://github.com/apache/beam/issues/20447)
})
-public class StateBackedIterable<T> implements Iterable<T>, Serializable {
+public class StateBackedIterable<T>
+ extends ElementByteSizeObservableIterable<T,
ElementByteSizeObservableIterator<T>>
+ implements Serializable {
+ private static final Logger LOG =
LoggerFactory.getLogger(StateBackedIterable.class);
@VisibleForTesting final StateRequest request;
@VisibleForTesting final List<T> prefix;
private final transient PrefetchableIterable<T> suffix;
+ private final org.apache.beam.sdk.coders.Coder<T> elemCoder;
+
public StateBackedIterable(
Cache<?, ?> cache,
BeamFnStateClient beamFnStateClient,
@@ -81,11 +91,82 @@ public class StateBackedIterable<T> implements Iterable<T>, Serializable {
this.suffix =
StateFetchingIterators.readAllAndDecodeStartingFrom(
Caches.subCache(cache, stateKey), beamFnStateClient, request,
elemCoder);
+ this.elemCoder = elemCoder;
+ }
+
+ /**
+ * Iterator that delegates to a wrapped iterator and notifies its own observers of the byte size
+ * of each returned element, as measured by {@code elementCoder.registerByteSizeObserver}.
+ *
+ * <p>When the coder leaves the observer lazy, advancing it is deferred until the next call to
+ * {@link #hasNext()} or {@link #next()}, so bytes observed while the caller uses the element are
+ * attributed to that element.
+ */
+ @SuppressWarnings("nullness")
+ private static class WrappedObservingIterator<T> extends ElementByteSizeObservableIterator<T> {
+ private final Iterator<T> wrappedIterator;
+ private final org.apache.beam.sdk.coders.Coder<T> elementCoder;
+
+ // Logically final and non-null but initialized after construction by factory method for
+ // initialization ordering.
+ private ElementByteSizeObserver observerProxy = null;
+
+ // True when the last returned element was observed lazily and the observer has not been
+ // advanced for it yet.
+ private boolean observerNeedsAdvance = false;
+ // Limits the size-estimation failure warning to a single log line per iterator.
+ private boolean exceptionLogged = false;
+
+ /**
+ * Creates an iterator over {@code iterator} whose observers are notified of each returned
+ * element's byte size as computed by {@code elementCoder}.
+ */
+ static <T> WrappedObservingIterator<T> create(
+ Iterator<T> iterator, org.apache.beam.sdk.coders.Coder<T> elementCoder) {
+ WrappedObservingIterator<T> result = new WrappedObservingIterator<>(iterator, elementCoder);
+ // The proxy forwards every size it is advanced with to this iterator's own observers.
+ result.observerProxy =
+ new ElementByteSizeObserver() {
+ @Override
+ protected void reportElementSize(long elementByteSize) {
+ result.notifyValueReturned(elementByteSize);
+ }
+ };
+ return result;
+ }
+
+ private WrappedObservingIterator(
+ Iterator<T> iterator, org.apache.beam.sdk.coders.Coder<T> elementCoder) {
+ this.wrappedIterator = iterator;
+ this.elementCoder = elementCoder;
+ }
+
+ /** Advances the observer if the previously returned element is still pending a lazy advance. */
+ private void advancePendingObserver() {
+ if (observerNeedsAdvance) {
+ observerProxy.advance();
+ observerNeedsAdvance = false;
+ }
+ }
+
+ @Override
+ public boolean hasNext() {
+ advancePendingObserver();
+ return wrappedIterator.hasNext();
+ }
+
+ @Override
+ public T next() {
+ // Callers may call next() without an intervening hasNext(); flush the previous element's
+ // lazily observed bytes first so they are not merged into this element's report.
+ advancePendingObserver();
+ T value = wrappedIterator.next();
+ try {
+ elementCoder.registerByteSizeObserver(value, observerProxy);
+ if (observerProxy.getIsLazy()) {
+ // The observer will only be notified of bytes as the result
+ // is used. We defer advancing the observer until the next hasNext or
+ // next call in an attempt to capture those bytes.
+ observerNeedsAdvance = true;
+ } else {
+ observerNeedsAdvance = false;
+ observerProxy.advance();
+ }
+ } catch (Exception e) {
+ // Size observation is best effort: never fail iteration because of it.
+ if (!exceptionLogged) {
+ LOG.warn("Lazily observed byte size will be under reported due to exception", e);
+ exceptionLogged = true;
+ }
+ }
+ return value;
+ }
+
+ @Override
+ public void remove() {
+ super.remove();
+ }
}
+ /**
+ * Creates an iterator over the in-memory prefix followed by the state-backed suffix. The iterator
+ * is wrapped so that observers registered on it are notified of each returned element's byte size
+ * as computed by the element coder.
+ */
@Override
- public Iterator<T> iterator() {
- return PrefetchableIterators.concat(prefix.iterator(), suffix.iterator());
+ protected ElementByteSizeObservableIterator<T> createIterator() {
+ return WrappedObservingIterator.create(
+ PrefetchableIterators.concat(prefix.iterator(), suffix.iterator()),
elemCoder);
}
protected Object writeReplace() throws ObjectStreamException {
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateBackedIterableTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateBackedIterableTest.java
index 4d53bcaef11..f758c367f73 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateBackedIterableTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateBackedIterableTest.java
@@ -19,6 +19,7 @@ package org.apache.beam.fn.harness.state;
import static java.util.Arrays.asList;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
@@ -36,11 +37,13 @@ import org.apache.beam.fn.harness.Caches;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.util.ByteStringOutputStream;
+import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.FluentIterable;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Streams;
import org.junit.Test;
import org.junit.experimental.runners.Enclosed;
import org.junit.runner.RunWith;
@@ -213,6 +216,61 @@ public class StateBackedIterableTest {
}
}
}
+
+ /** Observer that accumulates every reported element byte size into {@link #total}. */
+ private static class TestByteObserver extends ElementByteSizeObserver {
+ public long total = 0;
+
+ @Override
+ protected void reportElementSize(long elementByteSize) {
+ total += elementByteSize;
+ }
+ }
+
+ /**
+ * Checks that byte size observation on a {@link StateBackedIterable} is reported as cheap and
+ * lazy. It also checks that consuming the iterable yields a total equal to the per-element encoded
+ * sizes plus the iterable framing bytes.
+ */
+ @Test
+ public void testByteObservingStateBackedIterable() throws Exception {
+ // Fake state holding a non-empty and an empty suffix, keyed by name.
+ FakeBeamFnStateClient fakeBeamFnStateClient =
+ new FakeBeamFnStateClient(
+ StringUtf8Coder.of(),
+ ImmutableMap.of(
+ key("nonEmptySuffix"), asList("C", "D", "E", "F", "G", "H",
"I", "J", "K"),
+ key("emptySuffix"), asList()));
+
+ // NOTE(review): suffixKey and prefix are not defined in this method; they are presumably
+ // parameters of the enclosing test class selecting the suffix entry and in-memory prefix.
+ StateBackedIterable<String> iterable =
+ new StateBackedIterable<>(
+ Caches.noop(),
+ fakeBeamFnStateClient,
+ "instruction",
+ key(suffixKey),
+ StringUtf8Coder.of(),
+ prefix);
+ StateBackedIterable.Coder<String> coder =
+ new StateBackedIterable.Coder<>(
+ () -> Caches.noop(),
+ fakeBeamFnStateClient,
+ () -> "instructionId",
+ StringUtf8Coder.of());
+
+ // Registration must be reported as cheap and must leave the observer in lazy mode, so bytes
+ // are accounted for as the iterable is consumed below.
+ assertTrue(coder.isRegisterByteSizeObserverCheap(iterable));
+ TestByteObserver observer = new TestByteObserver();
+ coder.registerByteSizeObserver(iterable, observer);
+ assertTrue(observer.getIsLazy());
+
+ // Expected per-element bytes, computed independently of the observer.
+ long iterateBytes =
+ Streams.stream(iterable)
+ .mapToLong(
+ s -> {
+ try {
+ // 1 comes from hasNext = true flag (see IterableLikeCoder)
+ return 1 +
StringUtf8Coder.of().getEncodedElementByteSize(s);
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ })
+ .sum();
+ // Flush the lazily accumulated size into reportElementSize.
+ observer.advance();
+ // 5 comes from size and hasNext (see IterableLikeCoder)
+ assertEquals(iterateBytes + 5, observer.total);
+ }
}
@RunWith(JUnit4.class)