This is an automated email from the ASF dual-hosted git repository.

emilyye pushed a commit to branch release-2.36.0
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/release-2.36.0 by this push:
     new bd46e7b  [release-2.36.0][BEAM-13541] More intelligent caching of 
CoGBK values. (#16354, #16407) (#16421)
bd46e7b is described below

commit bd46e7b01678e3e32a4b779f02e8745728acbee8
Author: emily <[email protected]>
AuthorDate: Thu Jan 6 10:24:36 2022 -0800

    [release-2.36.0][BEAM-13541] More intelligent caching of CoGBK values. 
(#16354, #16407) (#16421)
    
    Co-authored-by: Robert Bradshaw <[email protected]>
    Co-authored-by: Lukasz Cwik <[email protected]>
---
 .../beam/sdk/transforms/join/CoGbkResult.java      | 434 +++++++++++++++++----
 .../beam/sdk/transforms/join/CoGbkResultTest.java  | 201 +++++++++-
 2 files changed, 555 insertions(+), 80 deletions(-)

diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/CoGbkResult.java
 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/CoGbkResult.java
index 26f84a7..18f31ac 100644
--- 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/CoGbkResult.java
+++ 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/CoGbkResult.java
@@ -23,6 +23,7 @@ import java.io.OutputStream;
 import java.util.ArrayList;
 import java.util.Iterator;
 import java.util.List;
+import java.util.NoSuchElementException;
 import java.util.Objects;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.CoderException;
@@ -35,7 +36,6 @@ import org.apache.beam.sdk.values.TupleTagList;
 import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
 import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterators;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
-import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.PeekingIterator;
 import org.checkerframework.checker.nullness.qual.Nullable;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -59,6 +59,12 @@ public class CoGbkResult {
 
   private static final int DEFAULT_IN_MEMORY_ELEMENT_COUNT = 10_000;
 
+  /**
+   * Always try to cache at least this many elements per tag, even if it 
requires caching more than
+   * the total in memory count.
+   */
+  private static final int DEFAULT_MIN_ELEMENTS_PER_TAG = 100;
+
   private static final Logger LOG = LoggerFactory.getLogger(CoGbkResult.class);
 
   /**
@@ -69,16 +75,19 @@ public class CoGbkResult {
    * @param taggedValues the raw results from a group-by-key
    */
   public CoGbkResult(CoGbkResultSchema schema, Iterable<RawUnionValue> 
taggedValues) {
-    this(schema, taggedValues, DEFAULT_IN_MEMORY_ELEMENT_COUNT);
+    this(schema, taggedValues, DEFAULT_IN_MEMORY_ELEMENT_COUNT, 
DEFAULT_MIN_ELEMENTS_PER_TAG);
   }
 
   @SuppressWarnings("unchecked")
   public CoGbkResult(
-      CoGbkResultSchema schema, Iterable<RawUnionValue> taggedValues, int 
inMemoryElementCount) {
+      CoGbkResultSchema schema,
+      Iterable<RawUnionValue> taggedValues,
+      int inMemoryElementCount,
+      int minElementsPerTag) {
     this.schema = schema;
-    valueMap = new ArrayList<>();
+    List<List<Object>> valuesByTag = new ArrayList<>();
     for (int unionTag = 0; unionTag < schema.size(); unionTag++) {
-      valueMap.add(new ArrayList<>());
+      valuesByTag.add(new ArrayList<>());
     }
 
     // Demultiplex the first imMemoryElementCount tagged union values
@@ -98,38 +107,48 @@ public class CoGbkResult {
         throw new IllegalStateException(
             "union tag " + unionTag + " has no corresponding tuple tag in the 
result schema");
       }
-      List<Object> valueList = (List<Object>) valueMap.get(unionTag);
-      valueList.add(value.getValue());
-    }
-
-    if (taggedIter.hasNext()) {
-      // If we get here, there were more elements than we can afford to
-      // keep in memory, so we copy the re-iterable of remaining items
-      // and append filtered views to each of the sorted lists computed 
earlier.
-      LOG.info(
-          "CoGbkResult has more than {} elements, reiteration (which may be 
slow) is required.",
-          inMemoryElementCount);
-      final Reiterator<RawUnionValue> tail = (Reiterator<RawUnionValue>) 
taggedIter;
-      // This is a trinary-state array recording whether a given tag is 
present in the tail. The
-      // initial value is null (unknown) for all tags, and the first iteration 
through the entire
-      // list will set these values to true or false to avoid needlessly 
iterating if filtering
-      // against a given tag would not match anything.
-      final Boolean[] containsTag = new Boolean[schema.size()];
-      for (int unionTag = 0; unionTag < schema.size(); unionTag++) {
-        updateUnionTag(tail, containsTag, unionTag);
-      }
+      valuesByTag.get(unionTag).add(value.getValue());
     }
-  }
 
-  private <T> void updateUnionTag(
-      final Reiterator<RawUnionValue> tail, final Boolean[] containsTag, final 
int unionTag) {
-    @SuppressWarnings("unchecked")
-    final Iterable<T> head = (Iterable<T>) valueMap.get(unionTag);
-    valueMap.set(
-        unionTag,
-        () ->
-            Iterators.concat(
-                head.iterator(), new UnionValueIterator<T>(unionTag, 
tail.copy(), containsTag)));
+    if (!taggedIter.hasNext()) {
+      valueMap = (List) valuesByTag;
+      return;
+    }
+
+    // If we get here, there were more elements than we can afford to
+    // keep in memory, so we copy the re-iterable of remaining items
+    // and append filtered views to each of the sorted lists computed earlier.
+    LOG.info(
+        "CoGbkResult has more than {} elements, reiteration (which may be 
slow) is required.",
+        inMemoryElementCount);
+    final Reiterator<RawUnionValue> tail = (Reiterator<RawUnionValue>) 
taggedIter;
+
+    // As we iterate over this re-iterable (e.g. while iterating for one tag) 
we populate values
+    // for other observed tags, if any.
+    ObservingReiterator<RawUnionValue> tip =
+        new ObservingReiterator<>(
+            tail,
+            new ObservingReiterator.Observer<RawUnionValue>() {
+              @Override
+              public void observeAt(ObservingReiterator<RawUnionValue> 
reiterator) {
+                ((TagIterable<?>) 
valueMap.get(reiterator.peek().getUnionTag())).offer(reiterator);
+              }
+
+              @Override
+              public void done() {
+                // Inform all tags that we have reached the end of the 
iterable, so anything that
+                // can be observed has been observed.
+                for (Iterable<?> iter : valueMap) {
+                  ((TagIterable<?>) iter).finish();
+                }
+              }
+            });
+
+    valueMap = new ArrayList<>();
+    for (int unionTag = 0; unionTag < schema.size(); unionTag++) {
+      valueMap.add(
+          new TagIterable<Object>(valuesByTag.get(unionTag), unionTag, 
minElementsPerTag, tip));
+    }
   }
 
   public boolean isEmpty() {
@@ -361,62 +380,331 @@ public class CoGbkResult {
   }
 
   /**
-   * Lazily filters and recasts an {@code Iterator<RawUnionValue>} into an 
{@code Iterator<V>},
-   * where V is the type of the raw union value's contents.
+   * A re-iterable that notifies an observer at every advance, and upon 
finishing, but only once
+   * across all copies.
+   *
+   * @param <T> The value type of the underlying iterable.
    */
-  private static class UnionValueIterator<V> implements Iterator<V> {
+  private static class ObservingReiterator<T> implements Reiterator<T> {
+
+    public interface Observer<T> {
+      /**
+       * Called exactly once, across all copies, before advancing this iterator.
+       *
+       * <p>The iterator rather than the element is given so that the callee 
can perform a copy if
+       * desired. This class offers a peek method to get at the current 
element without disturbing
+       * the state of this iterator.
+       */
+      void observeAt(ObservingReiterator<T> reiterator);
+
+      /** Called exactly once, across all copies, once this iterator is 
exhausted. */
+      void done();
+    }
 
-    private final int tag;
-    private final PeekingIterator<RawUnionValue> unions;
-    private final Boolean[] containsTag;
+    private PeekingReiterator<IndexingReiterator.Indexed<T>> underlying;
+    private Observer<T> observer;
 
-    private UnionValueIterator(int tag, Iterator<RawUnionValue> unions, 
Boolean[] containsTag) {
-      this.tag = tag;
-      this.unions = Iterators.peekingIterator(unions);
-      this.containsTag = containsTag;
+    // Used to keep track of what has been observed so far.
+    // These are arrays to facilitate sharing values among all copies of the 
same root Reiterator.
+    private final int[] lastObserved;
+    private final boolean[] doneHasRun;
+    private final PeekingReiterator[] mostAdvanced;
+
+    public ObservingReiterator(Reiterator<T> underlying, Observer<T> observer) 
{
+      this(new PeekingReiterator<>(new IndexingReiterator<>(underlying)), 
observer);
+    }
+
+    public ObservingReiterator(
+        PeekingReiterator<IndexingReiterator.Indexed<T>> underlying, 
Observer<T> observer) {
+      this(
+          underlying,
+          observer,
+          new int[] {-1},
+          new boolean[] {false},
+          new PeekingReiterator[] {underlying});
+    }
+
+    private ObservingReiterator(
+        PeekingReiterator<IndexingReiterator.Indexed<T>> underlying,
+        Observer<T> observer,
+        int[] lastObserved,
+        boolean[] doneHasRun,
+        PeekingReiterator[] mostAdvanced) {
+      this.underlying = underlying;
+      this.observer = observer;
+      this.lastObserved = lastObserved;
+      this.doneHasRun = doneHasRun;
+      this.mostAdvanced = mostAdvanced;
+    }
+
+    @Override
+    public Reiterator<T> copy() {
+      return new ObservingReiterator<T>(
+          underlying.copy(), observer, lastObserved, doneHasRun, mostAdvanced);
     }
 
     @Override
     public boolean hasNext() {
-      if (Boolean.FALSE.equals(containsTag[tag])) {
-        return false;
+      boolean hasNext = underlying.hasNext();
+      if (!hasNext && !doneHasRun[0]) {
+        mostAdvanced[0] = underlying;
+        observer.done();
+        doneHasRun[0] = true;
       }
-      advance();
-      if (unions.hasNext()) {
-        return true;
-      } else {
-        // Now that we've iterated over all the values, we can resolve all the 
"unknown" null
-        // values to false.
-        for (int i = 0; i < containsTag.length; i++) {
-          if (containsTag[i] == null) {
-            containsTag[i] = false;
-          }
-        }
-        return false;
+      return hasNext;
+    }
+
+    @Override
+    public T next() {
+      peek(); // trigger observation *before* advancing
+      return underlying.next().value;
+    }
+
+    public T peek() {
+      IndexingReiterator.Indexed<T> next = underlying.peek();
+      if (next.index > lastObserved[0]) {
+        assert next.index == lastObserved[0] + 1;
+        mostAdvanced[0] = underlying;
+        lastObserved[0] = next.index;
+        observer.observeAt(this);
       }
+      return next.value;
+    }
+
+    public void fastForward() {
+      if (underlying != mostAdvanced[0]) {
+        underlying = mostAdvanced[0].copy();
+      }
+    }
+  }
+
+  /**
+   * Assigns a monotonically increasing index to each item in the underlying 
Reiterator.
+   *
+   * @param <T> The value type of the underlying iterable.
+   */
+  private static class IndexingReiterator<T> implements 
Reiterator<IndexingReiterator.Indexed<T>> {
+
+    private Reiterator<T> underlying;
+    private int index;
+
+    public IndexingReiterator(Reiterator<T> underlying) {
+      this(underlying, 0);
+    }
+
+    public IndexingReiterator(Reiterator<T> underlying, int start) {
+      this.underlying = underlying;
+      this.index = start;
     }
 
     @Override
-    @SuppressWarnings("unchecked")
-    public V next() {
-      advance();
-      return (V) unions.next().getValue();
+    public IndexingReiterator<T> copy() {
+      return new IndexingReiterator(underlying.copy(), index);
     }
 
-    private void advance() {
-      while (unions.hasNext()) {
-        int curTag = unions.peek().getUnionTag();
-        containsTag[curTag] = true;
-        if (curTag == tag) {
-          break;
-        }
-        unions.next();
+    @Override
+    public boolean hasNext() {
+      return underlying.hasNext();
+    }
+
+    @Override
+    public Indexed<T> next() {
+      return new Indexed<T>(index++, underlying.next());
+    }
+
+    public static class Indexed<T> {
+      public final int index;
+      public final T value;
+
+      public Indexed(int index, T value) {
+        this.index = index;
+        this.value = value;
       }
     }
+  }
+
+  /**
+   * Adapts a Reiterator, giving it a peek() method that can be used to 
observe the next element
+   * without consuming it.
+   *
+   * @param <T> The value type of the underlying iterable.
+   */
+  private static class PeekingReiterator<T> implements Reiterator<T> {
+    private Reiterator<T> underlying;
+    private T next;
+    private boolean nextIsValid;
+
+    public PeekingReiterator(Reiterator<T> underlying) {
+      this(underlying, null, false);
+    }
+
+    private PeekingReiterator(Reiterator<T> underlying, T next, boolean 
nextIsValid) {
+      this.underlying = underlying;
+      this.next = next;
+      this.nextIsValid = nextIsValid;
+    }
+
+    @Override
+    public PeekingReiterator<T> copy() {
+      return new PeekingReiterator(underlying.copy(), next, nextIsValid);
+    }
+
+    @Override
+    public boolean hasNext() {
+      return nextIsValid || underlying.hasNext();
+    }
 
     @Override
-    public void remove() {
-      throw new UnsupportedOperationException();
+    public T next() {
+      if (nextIsValid) {
+        nextIsValid = false;
+        return next;
+      } else {
+        return underlying.next();
+      }
+    }
+
+    public T peek() {
+      if (!nextIsValid) {
+        next = underlying.next();
+        nextIsValid = true;
+      }
+      return next;
+    }
+  }
+
+  /**
+   * An Iterable corresponding to a single tag.
+   *
+   * <p>The values in this iterable are populated lazily via the offer method 
as tip advances for
+   * any tag.
+   *
+   * @param <T> The value type of the corresponding tag.
+   */
+  private static class TagIterable<T> implements Iterable<T> {
+    int tag;
+    int cacheSize;
+
+    ObservingReiterator<RawUnionValue> tip;
+
+    List<T> head;
+    Reiterator<RawUnionValue> tail;
+    boolean finished;
+
+    public TagIterable(
+        List<T> head, int tag, int cacheSize, 
ObservingReiterator<RawUnionValue> tip) {
+      this.tag = tag;
+      this.cacheSize = cacheSize;
+      this.head = head;
+      this.tip = tip;
+    }
+
+    void offer(ObservingReiterator<RawUnionValue> tail) {
+      assert !finished;
+      assert tail.peek().getUnionTag() == tag;
+      if (head.size() < cacheSize) {
+        head.add((T) tail.peek().getValue());
+      } else if (this.tail == null) {
+        this.tail = tail.copy();
+      }
+    }
+
+    void finish() {
+      finished = true;
+    }
+
+    void seek(int tag) {
+      while (tip.hasNext() && tip.peek().getUnionTag() != tag) {
+        tip.next();
+      }
+    }
+
+    @Override
+    public Iterator<T> iterator() {
+      return new Iterator<T>() {
+
+        boolean isDone;
+        boolean advanced;
+        T next;
+
+        /** Keeps track of the index, in head, that this iterator points to. */
+        int index = -1;
+        /** If the index is beyond what was cached in head, this is this 
iterator's view of tail. */
+        Iterator<T> tailIter;
+
+        @Override
+        public boolean hasNext() {
+          if (!advanced) {
+            advance();
+          }
+          return !isDone;
+        }
+
+        @Override
+        public T next() {
+          if (!advanced) {
+            advance();
+          }
+          if (isDone) {
+            throw new NoSuchElementException();
+          }
+          advanced = false;
+          return next;
+        }
+
+        private void advance() {
+          assert !advanced;
+          assert !isDone;
+          advanced = true;
+
+          index++;
+          if (maybeAdvance()) {
+            return;
+          }
+
+          // We were unable to advance; advance tip to populate either head or 
tail.
+          tip.fastForward();
+          if (tip.hasNext()) {
+            tip.next();
+            seek(tag);
+          }
+
+          // At this point, either head or tail should be sufficient to advance.
+          assert maybeAdvance();
+        }
+
+        private boolean maybeAdvance() {
+          if (index < head.size()) {
+            // First consume head.
+            assert tailIter == null;
+            next = head.get(index);
+            return true;
+          } else if (tail != null) {
+            // Next consume tail, if any.
+            if (tailIter == null) {
+              tailIter =
+                  Iterators.transform(
+                      Iterators.filter(
+                          tail.copy(), taggedUnion -> 
taggedUnion.getUnionTag() == tag),
+                      taggedUnion -> (T) taggedUnion.getValue());
+            }
+            if (tailIter.hasNext()) {
+              next = tailIter.next();
+            } else {
+              isDone = true;
+            }
+            return true;
+          } else if (finished) {
+            // If there are no more elements in head, and tail was not 
populated, and we are
+            // finished, this is the end of the iteration.
+            isDone = true;
+            return true;
+          } else {
+            // We need more elements in either head or tail.
+            return false;
+          }
+        }
+      };
     }
   }
 }
diff --git 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/join/CoGbkResultTest.java
 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/join/CoGbkResultTest.java
index 0598999..53ddc17 100644
--- 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/join/CoGbkResultTest.java
+++ 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/join/CoGbkResultTest.java
@@ -21,33 +21,50 @@ import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.emptyIterable;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.lessThanOrEqualTo;
+import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.sameInstance;
 
 import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Iterator;
 import java.util.List;
+import java.util.Map;
+import java.util.Random;
 import org.apache.beam.sdk.util.common.Reiterable;
 import org.apache.beam.sdk.util.common.Reiterator;
+import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.sdk.values.TupleTagList;
+import org.apache.commons.compress.utils.Lists;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 /** Tests the CoGbkResult. */
 @RunWith(JUnit4.class)
 public class CoGbkResultTest {
 
+  private static final Logger LOG = 
LoggerFactory.getLogger(CoGbkResultTest.class);
+
+  private static final int TEST_CACHE_SIZE = 5;
+
   @Test
-  public void testLazyResults() {
-    runLazyResult(0);
-    runLazyResult(1);
-    runLazyResult(3);
-    runLazyResult(10);
+  public void testExpectedResults() {
+    runExpectedResult(0);
+    runExpectedResult(1);
+    runExpectedResult(3);
+    runExpectedResult(10);
   }
 
-  public void runLazyResult(int cacheSize) {
+  public void runExpectedResult(int cacheSize) {
     int valueLen = 7;
     TestUnionValues values = new TestUnionValues(0, 1, 0, 3, 0, 3, 3);
-    CoGbkResult result = new CoGbkResult(createSchema(5), values, cacheSize);
+    CoGbkResult result = new CoGbkResult(createSchema(5), values, cacheSize, 
0);
     assertThat(values.maxPos(), equalTo(Math.min(cacheSize, valueLen)));
     assertThat(result.getAll(new TupleTag<>("tag0")), contains(0, 2, 4));
     assertThat(values.maxPos(), equalTo(valueLen));
@@ -57,6 +74,176 @@ public class CoGbkResultTest {
     assertThat(result.getAll(new TupleTag<>("tag0")), contains(0, 2, 4));
   }
 
+  @Test
+  public void testLazyResults() {
+    TestUnionValues values = new TestUnionValues(0, 0, 1, 1, 0, 1, 1);
+    CoGbkResult result = new CoGbkResult(createSchema(5), values, 0, 2);
+    // Nothing is read until we try to iterate.
+    assertThat(values.maxPos(), equalTo(0));
+    Iterable<?> tag0iterable = result.getAll("tag0");
+    assertThat(values.maxPos(), equalTo(0));
+    tag0iterable.iterator();
+    assertThat(values.maxPos(), equalTo(0));
+
+    // Iterating reads (nearly) the minimal number of values.
+    Iterator<?> tag0 = tag0iterable.iterator();
+    tag0.next();
+    assertThat(values.maxPos(), lessThanOrEqualTo(2));
+    tag0.next();
+    assertThat(values.maxPos(), equalTo(2));
+    // Note that we're skipping over tag 1.
+    tag0.next();
+    assertThat(values.maxPos(), equalTo(5));
+
+    // Iterating again does not cause more reads.
+    Iterator<?> tag0iterAgain = tag0iterable.iterator();
+    tag0iterAgain.next();
+    tag0iterAgain.next();
+    tag0iterAgain.next();
+    assertThat(values.maxPos(), equalTo(5));
+
+    // Iterating over other tags does not cause more reads for values we have 
seen.
+    Iterator<?> tag1 = result.getAll("tag1").iterator();
+    tag1.next();
+    tag1.next();
+    assertThat(values.maxPos(), equalTo(5));
+    // However, finding the next tag1 value does require more reads.
+    tag1.next();
+    assertThat(values.maxPos(), equalTo(6));
+  }
+
+  @Test
+  @SuppressWarnings("BoxedPrimitiveEquality")
+  public void testCachedResults() {
+    // Ensure we don't fail below due to a non-default 
java.lang.Integer.IntegerCache.high setting,
+    // as we want to test our cache is working as expected, unimpeded by a 
higher-level cache.
+    int integerCacheLimit = 128;
+    assertThat(
+        Integer.valueOf(integerCacheLimit), 
not(sameInstance(Integer.valueOf(integerCacheLimit))));
+
+    int perTagCache = 10;
+    int crossTagCache = 2 * integerCacheLimit;
+    int[] tags = new int[crossTagCache + 8 * perTagCache];
+    for (int i = 0; i < 2 * perTagCache; i++) {
+      tags[crossTagCache + 4 * i] = 1;
+      tags[crossTagCache + 4 * i + 1] = 2;
+    }
+
+    TestUnionValues values = new TestUnionValues(tags);
+    CoGbkResult result = new CoGbkResult(createSchema(5), values, 
crossTagCache, perTagCache);
+
+    // More than perTagCache values should be cached for the first tag, as 
they came first.
+    List<Object> tag0 = Lists.newArrayList(result.getAll("tag0").iterator());
+    List<Object> tag0again = 
Lists.newArrayList(result.getAll("tag0").iterator());
+    assertThat(tag0.get(0), sameInstance(tag0again.get(0)));
+    assertThat(tag0.get(integerCacheLimit), 
sameInstance(tag0again.get(integerCacheLimit)));
+    assertThat(tag0.get(crossTagCache - 1), 
sameInstance(tag0again.get(crossTagCache - 1)));
+    // However, not all elements are cached.
+    assertThat(tag0.get(tag0.size() - 1), 
not(sameInstance(tag0again.get(tag0.size() - 1))));
+
+    // For tag 1 and tag 2, we cache perTagCache elements, plus possibly one 
more due to peeking
+    // iterators.
+    List<Object> tag1 = Lists.newArrayList(result.getAll("tag1").iterator());
+    List<Object> tag1again = 
Lists.newArrayList(result.getAll("tag1").iterator());
+    assertThat(tag1.get(0), sameInstance(tag1again.get(0)));
+    assertThat(tag1.get(perTagCache - 1), 
sameInstance(tag1again.get(perTagCache - 1)));
+    assertThat(tag1.get(perTagCache + 1), 
not(sameInstance(tag1again.get(perTagCache + 1))));
+
+    List<Object> tag2 = Lists.newArrayList(result.getAll("tag1").iterator());
+    List<Object> tag2again = 
Lists.newArrayList(result.getAll("tag1").iterator());
+    assertThat(tag2.get(0), sameInstance(tag2again.get(0)));
+    assertThat(tag2.get(perTagCache - 1), 
sameInstance(tag2again.get(perTagCache - 1)));
+    assertThat(tag2.get(perTagCache + 1), 
not(sameInstance(tag2again.get(perTagCache + 1))));
+  }
+
+  @Test
+  public void testSingleTag() {
+    runCrazyIteration(1, TEST_CACHE_SIZE / 2);
+    runCrazyIteration(1, TEST_CACHE_SIZE * 2);
+    runCrazyIteration(2, TEST_CACHE_SIZE / 2);
+    runCrazyIteration(2, TEST_CACHE_SIZE * 2);
+    runCrazyIteration(10, TEST_CACHE_SIZE * 10);
+  }
+
+  @Test
+  public void testTwoTags() {
+    runCrazyIteration(1, TEST_CACHE_SIZE / 2, TEST_CACHE_SIZE / 2);
+    runCrazyIteration(1, TEST_CACHE_SIZE * 2, TEST_CACHE_SIZE * 2);
+    runCrazyIteration(2, TEST_CACHE_SIZE / 2, TEST_CACHE_SIZE / 2);
+    runCrazyIteration(2, TEST_CACHE_SIZE * 2, TEST_CACHE_SIZE * 2);
+    runCrazyIteration(10, TEST_CACHE_SIZE * 10);
+  }
+
+  @Test
+  public void testLargeSmall() {
+    runCrazyIteration(1, TEST_CACHE_SIZE / 2, TEST_CACHE_SIZE * 2);
+    runCrazyIteration(1, TEST_CACHE_SIZE / 2, TEST_CACHE_SIZE * 20);
+    runCrazyIteration(2, TEST_CACHE_SIZE / 2, TEST_CACHE_SIZE * 20);
+    runCrazyIteration(10, TEST_CACHE_SIZE / 2, TEST_CACHE_SIZE * 20);
+  }
+
+  @Test
+  public void testManyTags() {
+    runCrazyIteration(1, 2, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100);
+    runCrazyIteration(2, 2, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100);
+    runCrazyIteration(10, 2, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100);
+  }
+
+  public void runCrazyIteration(int numIterations, int... tagSizes) {
+    for (int trial = 0; trial < 10; trial++) {
+      // Populate this with a constant to reproduce failures.
+      int seed = (int) (Integer.MAX_VALUE * Math.random());
+      LOG.info("Running " + Arrays.toString(tagSizes) + " with seed " + seed);
+      Random random = new Random(seed);
+      List<Integer> tags = new ArrayList<>();
+      for (int tagNum = 0; tagNum < tagSizes.length; tagNum++) {
+        for (int i = 0; i < tagSizes[tagNum]; i++) {
+          tags.add(tagNum);
+        }
+      }
+      Collections.shuffle(tags, random);
+
+      Map<TupleTag<Integer>, List<Integer>> expected = new HashMap<>();
+      for (int tagNum = 0; tagNum < tagSizes.length; tagNum++) {
+        expected.put(new TupleTag<>("tag" + tagNum), new ArrayList<>());
+      }
+      for (int i = 0; i < tags.size(); i++) {
+        expected.get(new TupleTag<>("tag" + tags.get(i))).add(i);
+      }
+
+      List<KV<Integer, TupleTag<Integer>>> callOrder = new ArrayList<>();
+      for (int i = 0; i < numIterations; i++) {
+        for (int tagNum = 0; tagNum < tagSizes.length; tagNum++) {
+          for (int k = 0; k < tagSizes[tagNum]; k++) {
+            callOrder.add(KV.of(i, new TupleTag<>("tag" + tagNum)));
+          }
+        }
+      }
+      Collections.shuffle(callOrder, random);
+
+      Map<KV<Integer, TupleTag<Integer>>, Iterator<Integer>> iters = new 
HashMap<>();
+      Map<KV<Integer, TupleTag<Integer>>, List<Integer>> actual = new 
HashMap<>();
+      TestUnionValues values = new TestUnionValues(tags.stream().mapToInt(i -> 
i).toArray());
+      CoGbkResult coGbkResult =
+          new CoGbkResult(createSchema(tagSizes.length), values, 0, 
TEST_CACHE_SIZE);
+
+      for (KV<Integer, TupleTag<Integer>> call : callOrder) {
+        if (!iters.containsKey(call)) {
+          iters.put(call, coGbkResult.getAll(call.getValue()).iterator());
+          actual.put(call, new ArrayList<>());
+        }
+        actual.get(call).add(iters.get(call).next());
+        if (random.nextDouble() < 0.5 / numIterations) {
+          iters.get(call).hasNext();
+        }
+      }
+
+      for (Map.Entry<KV<Integer, TupleTag<Integer>>, List<Integer>> result : 
actual.entrySet()) {
+        assertThat(result.getValue(), 
contains(expected.get(result.getKey().getValue()).toArray()));
+      }
+    }
+  }
+
   private CoGbkResultSchema createSchema(int size) {
     List<TupleTag<?>> tags = new ArrayList<>();
     for (int i = 0; i < size; i++) {

Reply via email to