emilymye commented on a change in pull request #16354:
URL: https://github.com/apache/beam/pull/16354#discussion_r776817563
##########
File path:
sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/CoGbkResult.java
##########
@@ -361,62 +377,332 @@ private CoGbkResult(CoGbkResultSchema schema,
List<Iterable<?>> valueMap) {
}
/**
- * Lazily filters and recasts an {@code Iterator<RawUnionValue>} into an
{@code Iterator<V>},
- * where V is the type of the raw union value's contents.
+ * A re-iterable that notifies an observer at every advance, and upon
finishing, but only once
+ * across all copies.
+ *
+ * @param <T> The value type of the underlying iterable.
*/
- private static class UnionValueIterator<V> implements Iterator<V> {
+ private static class ObservingReiterator<T> implements Reiterator<T> {
+
+ public interface Observer<T> {
+ /**
+ * Called exactly once, across all copies before advancing this iterator.
+ *
+ * <p>The iterator rather than the element is given so that the callee
can perform a copy if
+ * desired. This class offers a peek method to get at the current
element without disturbing
+ * the state of this iterator.
+ */
+ void observeAt(ObservingReiterator<T> reiterator);
+
+ /** Called exactly once, across all copies, once this iterator is
exhausted. */
+ void done();
+ }
- private final int tag;
- private final PeekingIterator<RawUnionValue> unions;
- private final Boolean[] containsTag;
+ private PeekingReiterator<IndexingReiterator.Indexed<T>> underlying;
+ private Observer<T> observer;
- private UnionValueIterator(int tag, Iterator<RawUnionValue> unions,
Boolean[] containsTag) {
- this.tag = tag;
- this.unions = Iterators.peekingIterator(unions);
- this.containsTag = containsTag;
+ // Used to keep track of what has been observed so far.
+ // These are arrays to facilitate sharing values among all copies of the
same root Reiterator.
+ private final int[] lastObserved;
+ private final boolean[] doneHasRun;
+ private final PeekingReiterator[] mostAdvanced;
+
+ public ObservingReiterator(Reiterator<T> underlying, Observer<T> observer)
{
+ this(new PeekingReiterator<>(new IndexingReiterator<>(underlying)),
observer);
+ }
+
+ public ObservingReiterator(
+ PeekingReiterator<IndexingReiterator.Indexed<T>> underlying,
Observer<T> observer) {
+ this(
+ underlying,
+ observer,
+ new int[] {-1},
+ new boolean[] {false},
+ new PeekingReiterator[] {underlying});
+ }
+
+ private ObservingReiterator(
+ PeekingReiterator<IndexingReiterator.Indexed<T>> underlying,
+ Observer<T> observer,
+ int[] lastObserved,
+ boolean[] doneHasRun,
+ PeekingReiterator[] mostAdvanced) {
+ this.underlying = underlying;
+ this.observer = observer;
+ this.lastObserved = lastObserved;
+ this.doneHasRun = doneHasRun;
+ this.mostAdvanced = mostAdvanced;
+ }
+
+ @Override
+ public Reiterator<T> copy() {
+ return new ObservingReiterator<T>(
+ underlying.copy(), observer, lastObserved, doneHasRun, mostAdvanced);
}
@Override
public boolean hasNext() {
- if (Boolean.FALSE.equals(containsTag[tag])) {
- return false;
+ boolean hasNext = underlying.hasNext();
+ if (!hasNext && !doneHasRun[0]) {
+ mostAdvanced[0] = underlying;
+ observer.done();
+ doneHasRun[0] = true;
}
- advance();
- if (unions.hasNext()) {
- return true;
- } else {
- // Now that we've iterated over all the values, we can resolve all the
"unknown" null
- // values to false.
- for (int i = 0; i < containsTag.length; i++) {
- if (containsTag[i] == null) {
- containsTag[i] = false;
- }
- }
- return false;
+ return hasNext;
+ }
+
+ @Override
+ public T next() {
+ peek(); // trigger observation *before* advancing
+ return underlying.next().value;
+ }
+
+ public T peek() {
+ IndexingReiterator.Indexed<T> next = underlying.peek();
+ if (next.index > lastObserved[0]) {
+ assert next.index == lastObserved[0] + 1;
+ mostAdvanced[0] = underlying;
+ lastObserved[0] = next.index;
+ observer.observeAt(this);
}
+ return next.value;
+ }
+
+ public void fastForward() {
+ if (underlying != mostAdvanced[0]) {
+ underlying = mostAdvanced[0].copy();
+ }
+ }
+ }
+
+ /**
+ * Assigns a monotonically increasing index to each item in the underling
Reiterator.
+ *
+ * @param <T> The value type of the underlying iterable.
+ */
+ private static class IndexingReiterator<T> implements
Reiterator<IndexingReiterator.Indexed<T>> {
+
+ private Reiterator<T> underlying;
+ private int index;
+
+ public IndexingReiterator(Reiterator<T> underlying) {
+ this(underlying, 0);
+ }
+
+ public IndexingReiterator(Reiterator<T> underlying, int start) {
+ this.underlying = underlying;
+ this.index = start;
}
@Override
- @SuppressWarnings("unchecked")
- public V next() {
- advance();
- return (V) unions.next().getValue();
+ public IndexingReiterator<T> copy() {
+ return new IndexingReiterator(underlying.copy(), index);
}
- private void advance() {
- while (unions.hasNext()) {
- int curTag = unions.peek().getUnionTag();
- containsTag[curTag] = true;
- if (curTag == tag) {
- break;
- }
- unions.next();
+ @Override
+ public boolean hasNext() {
+ return underlying.hasNext();
+ }
+
+ @Override
+ public Indexed<T> next() {
+ return new Indexed<T>(index++, underlying.next());
+ }
+
+ public static class Indexed<T> {
+ public final int index;
+ public final T value;
+
+ public Indexed(int index, T value) {
+ this.index = index;
+ this.value = value;
+ }
+ }
+ }
+
+ /**
+ * Adapts an Reiterator, giving it a peek() method that can be used to
observe the next element
+ * without consuming it.
+ *
+ * @param <T> The value type of the underlying iterable.
+ */
+ private static class PeekingReiterator<T> implements Reiterator<T> {
+ private Reiterator<T> underlying;
+ private T next;
+ private boolean nextIsValid;
+
+ public PeekingReiterator(Reiterator<T> underlying) {
+ this(underlying, null, false);
+ }
+
+ private PeekingReiterator(Reiterator<T> underlying, T next, boolean
nextIsValid) {
+ this.underlying = underlying;
+ this.next = next;
+ this.nextIsValid = nextIsValid;
+ }
+
+ @Override
+ public PeekingReiterator<T> copy() {
+ return new PeekingReiterator(underlying.copy(), next, nextIsValid);
+ }
+
+ @Override
+ public boolean hasNext() {
+ return nextIsValid || underlying.hasNext();
+ }
+
+ @Override
+ public T next() {
+ if (nextIsValid) {
+ nextIsValid = false;
+ return next;
+ } else {
+ return underlying.next();
+ }
+ }
+
+ public T peek() {
+ if (!nextIsValid) {
+ next = underlying.next();
+ nextIsValid = true;
+ }
+ return next;
+ }
+ }
+
+ /**
+ * An Iterable corresponding to a single tag.
+ *
+ * <p>The values in this iterable are populated lazily via the offer method
as tip advances for
+ * any tag.
+ *
+ * @param <T> The value type of the corresponging tag.
+ */
+ private static class TagIterable<T> implements Iterable<T> {
+ int tag;
+ int cacheSize;
+ Supplier<Boolean> forceCache;
+
+ ObservingReiterator<RawUnionValue> tip;
+
+ List<T> head;
+ Reiterator<RawUnionValue> tail;
+ boolean finished;
+
+ public TagIterable(
+ List<T> head, int tag, int cacheSize,
ObservingReiterator<RawUnionValue> tip) {
+ this.tag = tag;
+ this.cacheSize = cacheSize;
+ this.head = head;
+ this.tip = tip;
+ }
+
+ void offer(ObservingReiterator<RawUnionValue> tail) {
+ assert !finished;
+ assert tail.peek().getUnionTag() == tag;
+ if (head.size() < cacheSize) {
+ head.add((T) tail.peek().getValue());
+ } else if (this.tail == null) {
+ this.tail = tail.copy();
+ }
+ }
+
+ void finish() {
+ finished = true;
+ }
+
+ void seek(int tag) {
+ while (tip.hasNext() && tip.peek().getUnionTag() != tag) {
+ tip.next();
}
}
@Override
- public void remove() {
- throw new UnsupportedOperationException();
+ public Iterator<T> iterator() {
+ return new Iterator<T>() {
+
+ boolean isDone;
+ boolean advanced;
+ T next;
+
+ /** Keeps track of the index, in head, that this iterator points to. */
+ int index = -1;
+ /** If the index is beyond what was cached in head, this is this
iterators view of tail. */
+ Iterator<T> tailIter;
+
+ @Override
+ public boolean hasNext() {
+ if (!advanced) {
+ advance();
+ }
+ return !isDone;
+ }
+
+ @Override
+ public T next() {
+ if (!advanced) {
+ advance();
+ }
+ if (isDone) {
+ throw new NoSuchElementException();
+ }
+ advanced = false;
+ return next;
+ }
+
+ private void advance() {
+ assert !advanced;
+ assert !isDone;
+ advanced = true;
+
+ index++;
+ if (maybeAdvance()) {
+ return;
+ }
+
+ // We were unable to advance; advance tip to populate either head or
tail.
+ tip.fastForward();
+ if (tip.hasNext()) {
+ tip.next();
+ seek(tag);
+ }
+
+ // A this point, either head or tail should be sufficient to advance.
+ assert maybeAdvance();
+ }
+
+ private boolean maybeAdvance() {
+ if (index < head.size()) {
+ // First consume head.
+ assert tailIter == null;
+ next = head.get(index);
+ return true;
+ } else if (tail != null) {
+ // Next consume tail, if any.
+ if (tailIter == null) {
+ tailIter =
+ Iterators.transform(
+ Iterators.filter(
+ tail.copy(), taggedUnion ->
taggedUnion.getUnionTag() == tag),
+ taggedUnion -> (T) taggedUnion.getValue());
+ }
+ if (tailIter.hasNext()) {
+ next = tailIter.next();
+ } else {
+ isDone = true;
+ }
+ return true;
+ } else if (finished) {
+ // If there are no more elements in head, and tail was not
populated, and we are
+ // finshed, this is the end of the iteration.
+ isDone = true;
+ return true;
+ } else {
+ // We need more lements in either head or tail.
Review comment:
```suggestion
// We need more elements in either head or tail.
```
##########
File path:
sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/join/CoGbkResultTest.java
##########
@@ -57,6 +74,175 @@ public void runLazyResult(int cacheSize) {
assertThat(result.getAll(new TupleTag<>("tag0")), contains(0, 2, 4));
}
+ @Test
+ public void testLazyResults() {
+ TestUnionValues values = new TestUnionValues(0, 0, 1, 1, 0, 1, 1);
+ CoGbkResult result = new CoGbkResult(createSchema(5), values, 0, 2);
+ // Nothing is read until we try to iterate.
+ assertThat(values.maxPos(), equalTo(0));
+ Iterable<?> tag0iterable = result.getAll("tag0");
+ assertThat(values.maxPos(), equalTo(0));
+ tag0iterable.iterator();
+ assertThat(values.maxPos(), equalTo(0));
+
+ // Iterating reads (nearly) the minimal number of values.
+ Iterator<?> tag0 = tag0iterable.iterator();
+ tag0.next();
+ assertThat(values.maxPos(), lessThanOrEqualTo(2));
+ tag0.next();
+ assertThat(values.maxPos(), equalTo(2));
+ // Note that we're skipping over tag 1.
+ tag0.next();
+ assertThat(values.maxPos(), equalTo(5));
+
+ // Iterating again does not cause more reads.
+ Iterator<?> tag0iterAgain = tag0iterable.iterator();
+ tag0iterAgain.next();
+ tag0iterAgain.next();
+ tag0iterAgain.next();
+ assertThat(values.maxPos(), equalTo(5));
+
+ // Iterating over other tags does not cause more reads for values we have
seen.
+ Iterator<?> tag1 = result.getAll("tag1").iterator();
+ tag1.next();
+ tag1.next();
+ assertThat(values.maxPos(), equalTo(5));
+ // However, finding the next tag1 value does require more reads.
+ tag1.next();
+ assertThat(values.maxPos(), equalTo(6));
+ }
+
+ @Test
+ @SuppressWarnings("BoxedPrimitiveEquality")
+ public void testCachedResults() {
+ // Ensure we don't fail below due to odd VM settings.
Review comment:
Out of curiosity, what odd VM settings?
##########
File path:
sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/CoGbkResult.java
##########
@@ -361,62 +377,332 @@ private CoGbkResult(CoGbkResultSchema schema,
List<Iterable<?>> valueMap) {
}
/**
- * Lazily filters and recasts an {@code Iterator<RawUnionValue>} into an
{@code Iterator<V>},
- * where V is the type of the raw union value's contents.
+ * A re-iterable that notifies an observer at every advance, and upon
finishing, but only once
+ * across all copies.
+ *
+ * @param <T> The value type of the underlying iterable.
*/
- private static class UnionValueIterator<V> implements Iterator<V> {
+ private static class ObservingReiterator<T> implements Reiterator<T> {
+
+ public interface Observer<T> {
+ /**
+ * Called exactly once, across all copies before advancing this iterator.
+ *
+ * <p>The iterator rather than the element is given so that the callee
can perform a copy if
+ * desired. This class offers a peek method to get at the current
element without disturbing
+ * the state of this iterator.
+ */
+ void observeAt(ObservingReiterator<T> reiterator);
+
+ /** Called exactly once, across all copies, once this iterator is
exhausted. */
+ void done();
+ }
- private final int tag;
- private final PeekingIterator<RawUnionValue> unions;
- private final Boolean[] containsTag;
+ private PeekingReiterator<IndexingReiterator.Indexed<T>> underlying;
+ private Observer<T> observer;
- private UnionValueIterator(int tag, Iterator<RawUnionValue> unions,
Boolean[] containsTag) {
- this.tag = tag;
- this.unions = Iterators.peekingIterator(unions);
- this.containsTag = containsTag;
+ // Used to keep track of what has been observed so far.
+ // These are arrays to facilitate sharing values among all copies of the
same root Reiterator.
+ private final int[] lastObserved;
+ private final boolean[] doneHasRun;
+ private final PeekingReiterator[] mostAdvanced;
+
+ public ObservingReiterator(Reiterator<T> underlying, Observer<T> observer)
{
+ this(new PeekingReiterator<>(new IndexingReiterator<>(underlying)),
observer);
+ }
+
+ public ObservingReiterator(
+ PeekingReiterator<IndexingReiterator.Indexed<T>> underlying,
Observer<T> observer) {
+ this(
+ underlying,
+ observer,
+ new int[] {-1},
+ new boolean[] {false},
+ new PeekingReiterator[] {underlying});
+ }
+
+ private ObservingReiterator(
+ PeekingReiterator<IndexingReiterator.Indexed<T>> underlying,
+ Observer<T> observer,
+ int[] lastObserved,
+ boolean[] doneHasRun,
+ PeekingReiterator[] mostAdvanced) {
+ this.underlying = underlying;
+ this.observer = observer;
+ this.lastObserved = lastObserved;
+ this.doneHasRun = doneHasRun;
+ this.mostAdvanced = mostAdvanced;
+ }
+
+ @Override
+ public Reiterator<T> copy() {
+ return new ObservingReiterator<T>(
+ underlying.copy(), observer, lastObserved, doneHasRun, mostAdvanced);
}
@Override
public boolean hasNext() {
- if (Boolean.FALSE.equals(containsTag[tag])) {
- return false;
+ boolean hasNext = underlying.hasNext();
+ if (!hasNext && !doneHasRun[0]) {
+ mostAdvanced[0] = underlying;
+ observer.done();
+ doneHasRun[0] = true;
}
- advance();
- if (unions.hasNext()) {
- return true;
- } else {
- // Now that we've iterated over all the values, we can resolve all the
"unknown" null
- // values to false.
- for (int i = 0; i < containsTag.length; i++) {
- if (containsTag[i] == null) {
- containsTag[i] = false;
- }
- }
- return false;
+ return hasNext;
+ }
+
+ @Override
+ public T next() {
+ peek(); // trigger observation *before* advancing
+ return underlying.next().value;
+ }
+
+ public T peek() {
+ IndexingReiterator.Indexed<T> next = underlying.peek();
+ if (next.index > lastObserved[0]) {
+ assert next.index == lastObserved[0] + 1;
+ mostAdvanced[0] = underlying;
+ lastObserved[0] = next.index;
+ observer.observeAt(this);
}
+ return next.value;
+ }
+
+ public void fastForward() {
+ if (underlying != mostAdvanced[0]) {
+ underlying = mostAdvanced[0].copy();
+ }
+ }
+ }
+
+ /**
+ * Assigns a monotonically increasing index to each item in the underling
Reiterator.
+ *
+ * @param <T> The value type of the underlying iterable.
+ */
+ private static class IndexingReiterator<T> implements
Reiterator<IndexingReiterator.Indexed<T>> {
+
+ private Reiterator<T> underlying;
+ private int index;
+
+ public IndexingReiterator(Reiterator<T> underlying) {
+ this(underlying, 0);
+ }
+
+ public IndexingReiterator(Reiterator<T> underlying, int start) {
+ this.underlying = underlying;
+ this.index = start;
}
@Override
- @SuppressWarnings("unchecked")
- public V next() {
- advance();
- return (V) unions.next().getValue();
+ public IndexingReiterator<T> copy() {
+ return new IndexingReiterator(underlying.copy(), index);
}
- private void advance() {
- while (unions.hasNext()) {
- int curTag = unions.peek().getUnionTag();
- containsTag[curTag] = true;
- if (curTag == tag) {
- break;
- }
- unions.next();
+ @Override
+ public boolean hasNext() {
+ return underlying.hasNext();
+ }
+
+ @Override
+ public Indexed<T> next() {
+ return new Indexed<T>(index++, underlying.next());
+ }
+
+ public static class Indexed<T> {
+ public final int index;
+ public final T value;
+
+ public Indexed(int index, T value) {
+ this.index = index;
+ this.value = value;
+ }
+ }
+ }
+
+ /**
+ * Adapts an Reiterator, giving it a peek() method that can be used to
observe the next element
+ * without consuming it.
+ *
+ * @param <T> The value type of the underlying iterable.
+ */
+ private static class PeekingReiterator<T> implements Reiterator<T> {
+ private Reiterator<T> underlying;
+ private T next;
+ private boolean nextIsValid;
+
+ public PeekingReiterator(Reiterator<T> underlying) {
+ this(underlying, null, false);
+ }
+
+ private PeekingReiterator(Reiterator<T> underlying, T next, boolean
nextIsValid) {
+ this.underlying = underlying;
+ this.next = next;
+ this.nextIsValid = nextIsValid;
+ }
+
+ @Override
+ public PeekingReiterator<T> copy() {
+ return new PeekingReiterator(underlying.copy(), next, nextIsValid);
+ }
+
+ @Override
+ public boolean hasNext() {
+ return nextIsValid || underlying.hasNext();
+ }
+
+ @Override
+ public T next() {
+ if (nextIsValid) {
+ nextIsValid = false;
+ return next;
+ } else {
+ return underlying.next();
+ }
+ }
+
+ public T peek() {
+ if (!nextIsValid) {
+ next = underlying.next();
+ nextIsValid = true;
+ }
+ return next;
+ }
+ }
+
+ /**
+ * An Iterable corresponding to a single tag.
+ *
+ * <p>The values in this iterable are populated lazily via the offer method
as tip advances for
+ * any tag.
+ *
+ * @param <T> The value type of the corresponging tag.
+ */
+ private static class TagIterable<T> implements Iterable<T> {
+ int tag;
+ int cacheSize;
+ Supplier<Boolean> forceCache;
+
+ ObservingReiterator<RawUnionValue> tip;
+
+ List<T> head;
+ Reiterator<RawUnionValue> tail;
+ boolean finished;
+
+ public TagIterable(
+ List<T> head, int tag, int cacheSize,
ObservingReiterator<RawUnionValue> tip) {
+ this.tag = tag;
+ this.cacheSize = cacheSize;
+ this.head = head;
+ this.tip = tip;
+ }
+
+ void offer(ObservingReiterator<RawUnionValue> tail) {
+ assert !finished;
+ assert tail.peek().getUnionTag() == tag;
+ if (head.size() < cacheSize) {
+ head.add((T) tail.peek().getValue());
+ } else if (this.tail == null) {
+ this.tail = tail.copy();
+ }
+ }
+
+ void finish() {
+ finished = true;
+ }
+
+ void seek(int tag) {
+ while (tip.hasNext() && tip.peek().getUnionTag() != tag) {
+ tip.next();
}
}
@Override
- public void remove() {
- throw new UnsupportedOperationException();
+ public Iterator<T> iterator() {
+ return new Iterator<T>() {
+
+ boolean isDone;
+ boolean advanced;
+ T next;
+
+ /** Keeps track of the index, in head, that this iterator points to. */
+ int index = -1;
+ /** If the index is beyond what was cached in head, this is this
iterators view of tail. */
+ Iterator<T> tailIter;
+
+ @Override
+ public boolean hasNext() {
+ if (!advanced) {
+ advance();
+ }
+ return !isDone;
+ }
+
+ @Override
+ public T next() {
+ if (!advanced) {
+ advance();
+ }
+ if (isDone) {
+ throw new NoSuchElementException();
+ }
+ advanced = false;
+ return next;
+ }
+
+ private void advance() {
+ assert !advanced;
+ assert !isDone;
+ advanced = true;
+
+ index++;
+ if (maybeAdvance()) {
+ return;
+ }
+
+ // We were unable to advance; advance tip to populate either head or
tail.
+ tip.fastForward();
+ if (tip.hasNext()) {
+ tip.next();
+ seek(tag);
+ }
+
+ // A this point, either head or tail should be sufficient to advance.
+ assert maybeAdvance();
+ }
+
+ private boolean maybeAdvance() {
+ if (index < head.size()) {
+ // First consume head.
+ assert tailIter == null;
+ next = head.get(index);
+ return true;
+ } else if (tail != null) {
+ // Next consume tail, if any.
+ if (tailIter == null) {
+ tailIter =
+ Iterators.transform(
+ Iterators.filter(
+ tail.copy(), taggedUnion ->
taggedUnion.getUnionTag() == tag),
+ taggedUnion -> (T) taggedUnion.getValue());
+ }
+ if (tailIter.hasNext()) {
+ next = tailIter.next();
+ } else {
+ isDone = true;
+ }
+ return true;
+ } else if (finished) {
+ // If there are no more elements in head, and tail was not
populated, and we are
+ // finshed, this is the end of the iteration.
Review comment:
```suggestion
// finished, this is the end of the iteration.
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]