This is an automated email from the ASF dual-hosted git repository.
emilyye pushed a commit to branch release-2.36.0
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/release-2.36.0 by this push:
new bd46e7b [release-2.36.0][BEAM-13541] More intelligent caching of
CoGBK values. (#16354, #16407) (#16421)
bd46e7b is described below
commit bd46e7b01678e3e32a4b779f02e8745728acbee8
Author: emily <[email protected]>
AuthorDate: Thu Jan 6 10:24:36 2022 -0800
[release-2.36.0][BEAM-13541] More intelligent caching of CoGBK values.
(#16354, #16407) (#16421)
Co-authored-by: Robert Bradshaw <[email protected]>
Co-authored-by: Lukasz Cwik <[email protected]>
---
.../beam/sdk/transforms/join/CoGbkResult.java | 434 +++++++++++++++++----
.../beam/sdk/transforms/join/CoGbkResultTest.java | 201 +++++++++-
2 files changed, 555 insertions(+), 80 deletions(-)
diff --git
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/CoGbkResult.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/CoGbkResult.java
index 26f84a7..18f31ac 100644
---
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/CoGbkResult.java
+++
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/CoGbkResult.java
@@ -23,6 +23,7 @@ import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
+import java.util.NoSuchElementException;
import java.util.Objects;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderException;
@@ -35,7 +36,6 @@ import org.apache.beam.sdk.values.TupleTagList;
import
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterators;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
-import
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.PeekingIterator;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -59,6 +59,12 @@ public class CoGbkResult {
private static final int DEFAULT_IN_MEMORY_ELEMENT_COUNT = 10_000;
+ /**
+ * Always try to cache at least this many elements per tag, even if it
requires caching more than
+ * the total in memory count.
+ */
+ private static final int DEFAULT_MIN_ELEMENTS_PER_TAG = 100;
+
private static final Logger LOG = LoggerFactory.getLogger(CoGbkResult.class);
/**
@@ -69,16 +75,19 @@ public class CoGbkResult {
* @param taggedValues the raw results from a group-by-key
*/
public CoGbkResult(CoGbkResultSchema schema, Iterable<RawUnionValue>
taggedValues) {
- this(schema, taggedValues, DEFAULT_IN_MEMORY_ELEMENT_COUNT);
+ this(schema, taggedValues, DEFAULT_IN_MEMORY_ELEMENT_COUNT,
DEFAULT_MIN_ELEMENTS_PER_TAG);
}
@SuppressWarnings("unchecked")
public CoGbkResult(
- CoGbkResultSchema schema, Iterable<RawUnionValue> taggedValues, int
inMemoryElementCount) {
+ CoGbkResultSchema schema,
+ Iterable<RawUnionValue> taggedValues,
+ int inMemoryElementCount,
+ int minElementsPerTag) {
this.schema = schema;
- valueMap = new ArrayList<>();
+ List<List<Object>> valuesByTag = new ArrayList<>();
for (int unionTag = 0; unionTag < schema.size(); unionTag++) {
- valueMap.add(new ArrayList<>());
+ valuesByTag.add(new ArrayList<>());
}
// Demultiplex the first imMemoryElementCount tagged union values
@@ -98,38 +107,48 @@ public class CoGbkResult {
throw new IllegalStateException(
"union tag " + unionTag + " has no corresponding tuple tag in the
result schema");
}
- List<Object> valueList = (List<Object>) valueMap.get(unionTag);
- valueList.add(value.getValue());
- }
-
- if (taggedIter.hasNext()) {
- // If we get here, there were more elements than we can afford to
- // keep in memory, so we copy the re-iterable of remaining items
- // and append filtered views to each of the sorted lists computed
earlier.
- LOG.info(
- "CoGbkResult has more than {} elements, reiteration (which may be
slow) is required.",
- inMemoryElementCount);
- final Reiterator<RawUnionValue> tail = (Reiterator<RawUnionValue>)
taggedIter;
- // This is a trinary-state array recording whether a given tag is
present in the tail. The
- // initial value is null (unknown) for all tags, and the first iteration
through the entire
- // list will set these values to true or false to avoid needlessly
iterating if filtering
- // against a given tag would not match anything.
- final Boolean[] containsTag = new Boolean[schema.size()];
- for (int unionTag = 0; unionTag < schema.size(); unionTag++) {
- updateUnionTag(tail, containsTag, unionTag);
- }
+ valuesByTag.get(unionTag).add(value.getValue());
}
- }
- private <T> void updateUnionTag(
- final Reiterator<RawUnionValue> tail, final Boolean[] containsTag, final
int unionTag) {
- @SuppressWarnings("unchecked")
- final Iterable<T> head = (Iterable<T>) valueMap.get(unionTag);
- valueMap.set(
- unionTag,
- () ->
- Iterators.concat(
- head.iterator(), new UnionValueIterator<T>(unionTag,
tail.copy(), containsTag)));
+ if (!taggedIter.hasNext()) {
+ valueMap = (List) valuesByTag;
+ return;
+ }
+
+ // If we get here, there were more elements than we can afford to
+ // keep in memory, so we copy the re-iterable of remaining items
+ // and append filtered views to each of the sorted lists computed earlier.
+ LOG.info(
+ "CoGbkResult has more than {} elements, reiteration (which may be
slow) is required.",
+ inMemoryElementCount);
+ final Reiterator<RawUnionValue> tail = (Reiterator<RawUnionValue>)
taggedIter;
+
+ // As we iterate over this re-iterable (e.g. while iterating for one tag)
we populate values
+ // for other observed tags, if any.
+ ObservingReiterator<RawUnionValue> tip =
+ new ObservingReiterator<>(
+ tail,
+ new ObservingReiterator.Observer<RawUnionValue>() {
+ @Override
+ public void observeAt(ObservingReiterator<RawUnionValue>
reiterator) {
+ ((TagIterable<?>)
valueMap.get(reiterator.peek().getUnionTag())).offer(reiterator);
+ }
+
+ @Override
+ public void done() {
+ // Inform all tags that we have reached the end of the
iterable, so anything that
+ // can be observed has been observed.
+ for (Iterable<?> iter : valueMap) {
+ ((TagIterable<?>) iter).finish();
+ }
+ }
+ });
+
+ valueMap = new ArrayList<>();
+ for (int unionTag = 0; unionTag < schema.size(); unionTag++) {
+ valueMap.add(
+ new TagIterable<Object>(valuesByTag.get(unionTag), unionTag,
minElementsPerTag, tip));
+ }
}
public boolean isEmpty() {
@@ -361,62 +380,331 @@ public class CoGbkResult {
}
/**
- * Lazily filters and recasts an {@code Iterator<RawUnionValue>} into an
{@code Iterator<V>},
- * where V is the type of the raw union value's contents.
+ * A re-iterable that notifies an observer at every advance, and upon
finishing, but only once
+ * across all copies.
+ *
+ * @param <T> The value type of the underlying iterable.
*/
- private static class UnionValueIterator<V> implements Iterator<V> {
+ private static class ObservingReiterator<T> implements Reiterator<T> {
+
+ public interface Observer<T> {
+ /**
+ * Called exactly once, across all copies, before advancing this iterator.
+ *
+ * <p>The iterator rather than the element is given so that the callee
can perform a copy if
+ * desired. This class offers a peek method to get at the current
element without disturbing
+ * the state of this iterator.
+ */
+ void observeAt(ObservingReiterator<T> reiterator);
+
+ /** Called exactly once, across all copies, once this iterator is
exhausted. */
+ void done();
+ }
- private final int tag;
- private final PeekingIterator<RawUnionValue> unions;
- private final Boolean[] containsTag;
+ private PeekingReiterator<IndexingReiterator.Indexed<T>> underlying;
+ private Observer<T> observer;
- private UnionValueIterator(int tag, Iterator<RawUnionValue> unions,
Boolean[] containsTag) {
- this.tag = tag;
- this.unions = Iterators.peekingIterator(unions);
- this.containsTag = containsTag;
+ // Used to keep track of what has been observed so far.
+ // These are arrays to facilitate sharing values among all copies of the
same root Reiterator.
+ private final int[] lastObserved;
+ private final boolean[] doneHasRun;
+ private final PeekingReiterator[] mostAdvanced;
+
+ public ObservingReiterator(Reiterator<T> underlying, Observer<T> observer)
{
+ this(new PeekingReiterator<>(new IndexingReiterator<>(underlying)),
observer);
+ }
+
+ public ObservingReiterator(
+ PeekingReiterator<IndexingReiterator.Indexed<T>> underlying,
Observer<T> observer) {
+ this(
+ underlying,
+ observer,
+ new int[] {-1},
+ new boolean[] {false},
+ new PeekingReiterator[] {underlying});
+ }
+
+ private ObservingReiterator(
+ PeekingReiterator<IndexingReiterator.Indexed<T>> underlying,
+ Observer<T> observer,
+ int[] lastObserved,
+ boolean[] doneHasRun,
+ PeekingReiterator[] mostAdvanced) {
+ this.underlying = underlying;
+ this.observer = observer;
+ this.lastObserved = lastObserved;
+ this.doneHasRun = doneHasRun;
+ this.mostAdvanced = mostAdvanced;
+ }
+
+ @Override
+ public Reiterator<T> copy() {
+ return new ObservingReiterator<T>(
+ underlying.copy(), observer, lastObserved, doneHasRun, mostAdvanced);
}
@Override
public boolean hasNext() {
- if (Boolean.FALSE.equals(containsTag[tag])) {
- return false;
+ boolean hasNext = underlying.hasNext();
+ if (!hasNext && !doneHasRun[0]) {
+ mostAdvanced[0] = underlying;
+ observer.done();
+ doneHasRun[0] = true;
}
- advance();
- if (unions.hasNext()) {
- return true;
- } else {
- // Now that we've iterated over all the values, we can resolve all the
"unknown" null
- // values to false.
- for (int i = 0; i < containsTag.length; i++) {
- if (containsTag[i] == null) {
- containsTag[i] = false;
- }
- }
- return false;
+ return hasNext;
+ }
+
+ @Override
+ public T next() {
+ peek(); // trigger observation *before* advancing
+ return underlying.next().value;
+ }
+
+ public T peek() {
+ IndexingReiterator.Indexed<T> next = underlying.peek();
+ if (next.index > lastObserved[0]) {
+ assert next.index == lastObserved[0] + 1;
+ mostAdvanced[0] = underlying;
+ lastObserved[0] = next.index;
+ observer.observeAt(this);
}
+ return next.value;
+ }
+
+ public void fastForward() {
+ if (underlying != mostAdvanced[0]) {
+ underlying = mostAdvanced[0].copy();
+ }
+ }
+ }
+
+ /**
+ * Assigns a monotonically increasing index to each item in the underlying
Reiterator.
+ *
+ * @param <T> The value type of the underlying iterable.
+ */
+ private static class IndexingReiterator<T> implements
Reiterator<IndexingReiterator.Indexed<T>> {
+
+ private Reiterator<T> underlying;
+ private int index;
+
+ public IndexingReiterator(Reiterator<T> underlying) {
+ this(underlying, 0);
+ }
+
+ public IndexingReiterator(Reiterator<T> underlying, int start) {
+ this.underlying = underlying;
+ this.index = start;
}
@Override
- @SuppressWarnings("unchecked")
- public V next() {
- advance();
- return (V) unions.next().getValue();
+ public IndexingReiterator<T> copy() {
+ return new IndexingReiterator(underlying.copy(), index);
}
- private void advance() {
- while (unions.hasNext()) {
- int curTag = unions.peek().getUnionTag();
- containsTag[curTag] = true;
- if (curTag == tag) {
- break;
- }
- unions.next();
+ @Override
+ public boolean hasNext() {
+ return underlying.hasNext();
+ }
+
+ @Override
+ public Indexed<T> next() {
+ return new Indexed<T>(index++, underlying.next());
+ }
+
+ public static class Indexed<T> {
+ public final int index;
+ public final T value;
+
+ public Indexed(int index, T value) {
+ this.index = index;
+ this.value = value;
}
}
+ }
+
+ /**
+ * Adapts a Reiterator, giving it a peek() method that can be used to
observe the next element
+ * without consuming it.
+ *
+ * @param <T> The value type of the underlying iterable.
+ */
+ private static class PeekingReiterator<T> implements Reiterator<T> {
+ private Reiterator<T> underlying;
+ private T next;
+ private boolean nextIsValid;
+
+ public PeekingReiterator(Reiterator<T> underlying) {
+ this(underlying, null, false);
+ }
+
+ private PeekingReiterator(Reiterator<T> underlying, T next, boolean
nextIsValid) {
+ this.underlying = underlying;
+ this.next = next;
+ this.nextIsValid = nextIsValid;
+ }
+
+ @Override
+ public PeekingReiterator<T> copy() {
+ return new PeekingReiterator(underlying.copy(), next, nextIsValid);
+ }
+
+ @Override
+ public boolean hasNext() {
+ return nextIsValid || underlying.hasNext();
+ }
@Override
- public void remove() {
- throw new UnsupportedOperationException();
+ public T next() {
+ if (nextIsValid) {
+ nextIsValid = false;
+ return next;
+ } else {
+ return underlying.next();
+ }
+ }
+
+ public T peek() {
+ if (!nextIsValid) {
+ next = underlying.next();
+ nextIsValid = true;
+ }
+ return next;
+ }
+ }
+
+ /**
+ * An Iterable corresponding to a single tag.
+ *
+ * <p>The values in this iterable are populated lazily via the offer method
as tip advances for
+ * any tag.
+ *
+ * @param <T> The value type of the corresponding tag.
+ */
+ private static class TagIterable<T> implements Iterable<T> {
+ int tag;
+ int cacheSize;
+
+ ObservingReiterator<RawUnionValue> tip;
+
+ List<T> head;
+ Reiterator<RawUnionValue> tail;
+ boolean finished;
+
+ public TagIterable(
+ List<T> head, int tag, int cacheSize,
ObservingReiterator<RawUnionValue> tip) {
+ this.tag = tag;
+ this.cacheSize = cacheSize;
+ this.head = head;
+ this.tip = tip;
+ }
+
+ void offer(ObservingReiterator<RawUnionValue> tail) {
+ assert !finished;
+ assert tail.peek().getUnionTag() == tag;
+ if (head.size() < cacheSize) {
+ head.add((T) tail.peek().getValue());
+ } else if (this.tail == null) {
+ this.tail = tail.copy();
+ }
+ }
+
+ void finish() {
+ finished = true;
+ }
+
+ void seek(int tag) {
+ while (tip.hasNext() && tip.peek().getUnionTag() != tag) {
+ tip.next();
+ }
+ }
+
+ @Override
+ public Iterator<T> iterator() {
+ return new Iterator<T>() {
+
+ boolean isDone;
+ boolean advanced;
+ T next;
+
+ /** Keeps track of the index, in head, that this iterator points to. */
+ int index = -1;
+ /** If the index is beyond what was cached in head, this is this
iterator's view of tail. */
+ Iterator<T> tailIter;
+
+ @Override
+ public boolean hasNext() {
+ if (!advanced) {
+ advance();
+ }
+ return !isDone;
+ }
+
+ @Override
+ public T next() {
+ if (!advanced) {
+ advance();
+ }
+ if (isDone) {
+ throw new NoSuchElementException();
+ }
+ advanced = false;
+ return next;
+ }
+
+ private void advance() {
+ assert !advanced;
+ assert !isDone;
+ advanced = true;
+
+ index++;
+ if (maybeAdvance()) {
+ return;
+ }
+
+ // We were unable to advance; advance tip to populate either head or
tail.
+ tip.fastForward();
+ if (tip.hasNext()) {
+ tip.next();
+ seek(tag);
+ }
+
+ // At this point, either head or tail should be sufficient to advance.
+ assert maybeAdvance();
+ }
+
+ private boolean maybeAdvance() {
+ if (index < head.size()) {
+ // First consume head.
+ assert tailIter == null;
+ next = head.get(index);
+ return true;
+ } else if (tail != null) {
+ // Next consume tail, if any.
+ if (tailIter == null) {
+ tailIter =
+ Iterators.transform(
+ Iterators.filter(
+ tail.copy(), taggedUnion ->
taggedUnion.getUnionTag() == tag),
+ taggedUnion -> (T) taggedUnion.getValue());
+ }
+ if (tailIter.hasNext()) {
+ next = tailIter.next();
+ } else {
+ isDone = true;
+ }
+ return true;
+ } else if (finished) {
+ // If there are no more elements in head, and tail was not
populated, and we are
+ // finished, this is the end of the iteration.
+ isDone = true;
+ return true;
+ } else {
+ // We need more elements in either head or tail.
+ return false;
+ }
+ }
+ };
}
}
}
diff --git
a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/join/CoGbkResultTest.java
b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/join/CoGbkResultTest.java
index 0598999..53ddc17 100644
---
a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/join/CoGbkResultTest.java
+++
b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/join/CoGbkResultTest.java
@@ -21,33 +21,50 @@ import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.emptyIterable;
import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.lessThanOrEqualTo;
+import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.sameInstance;
import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Iterator;
import java.util.List;
+import java.util.Map;
+import java.util.Random;
import org.apache.beam.sdk.util.common.Reiterable;
import org.apache.beam.sdk.util.common.Reiterator;
+import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
+import org.apache.commons.compress.utils.Lists;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
/** Tests the CoGbkResult. */
@RunWith(JUnit4.class)
public class CoGbkResultTest {
+ private static final Logger LOG =
LoggerFactory.getLogger(CoGbkResultTest.class);
+
+ private static final int TEST_CACHE_SIZE = 5;
+
@Test
- public void testLazyResults() {
- runLazyResult(0);
- runLazyResult(1);
- runLazyResult(3);
- runLazyResult(10);
+ public void testExpectedResults() {
+ runExpectedResult(0);
+ runExpectedResult(1);
+ runExpectedResult(3);
+ runExpectedResult(10);
}
- public void runLazyResult(int cacheSize) {
+ public void runExpectedResult(int cacheSize) {
int valueLen = 7;
TestUnionValues values = new TestUnionValues(0, 1, 0, 3, 0, 3, 3);
- CoGbkResult result = new CoGbkResult(createSchema(5), values, cacheSize);
+ CoGbkResult result = new CoGbkResult(createSchema(5), values, cacheSize,
0);
assertThat(values.maxPos(), equalTo(Math.min(cacheSize, valueLen)));
assertThat(result.getAll(new TupleTag<>("tag0")), contains(0, 2, 4));
assertThat(values.maxPos(), equalTo(valueLen));
@@ -57,6 +74,176 @@ public class CoGbkResultTest {
assertThat(result.getAll(new TupleTag<>("tag0")), contains(0, 2, 4));
}
+ @Test
+ public void testLazyResults() {
+ TestUnionValues values = new TestUnionValues(0, 0, 1, 1, 0, 1, 1);
+ CoGbkResult result = new CoGbkResult(createSchema(5), values, 0, 2);
+ // Nothing is read until we try to iterate.
+ assertThat(values.maxPos(), equalTo(0));
+ Iterable<?> tag0iterable = result.getAll("tag0");
+ assertThat(values.maxPos(), equalTo(0));
+ tag0iterable.iterator();
+ assertThat(values.maxPos(), equalTo(0));
+
+ // Iterating reads (nearly) the minimal number of values.
+ Iterator<?> tag0 = tag0iterable.iterator();
+ tag0.next();
+ assertThat(values.maxPos(), lessThanOrEqualTo(2));
+ tag0.next();
+ assertThat(values.maxPos(), equalTo(2));
+ // Note that we're skipping over tag 1.
+ tag0.next();
+ assertThat(values.maxPos(), equalTo(5));
+
+ // Iterating again does not cause more reads.
+ Iterator<?> tag0iterAgain = tag0iterable.iterator();
+ tag0iterAgain.next();
+ tag0iterAgain.next();
+ tag0iterAgain.next();
+ assertThat(values.maxPos(), equalTo(5));
+
+ // Iterating over other tags does not cause more reads for values we have
seen.
+ Iterator<?> tag1 = result.getAll("tag1").iterator();
+ tag1.next();
+ tag1.next();
+ assertThat(values.maxPos(), equalTo(5));
+ // However, finding the next tag1 value does require more reads.
+ tag1.next();
+ assertThat(values.maxPos(), equalTo(6));
+ }
+
+ @Test
+ @SuppressWarnings("BoxedPrimitiveEquality")
+ public void testCachedResults() {
+ // Ensure we don't fail below due to a non-default
java.lang.Integer.IntegerCache.high setting,
+ // as we want to test our cache is working as expected, unimpeded by a
higher-level cache.
+ int integerCacheLimit = 128;
+ assertThat(
+ Integer.valueOf(integerCacheLimit),
not(sameInstance(Integer.valueOf(integerCacheLimit))));
+
+ int perTagCache = 10;
+ int crossTagCache = 2 * integerCacheLimit;
+ int[] tags = new int[crossTagCache + 8 * perTagCache];
+ for (int i = 0; i < 2 * perTagCache; i++) {
+ tags[crossTagCache + 4 * i] = 1;
+ tags[crossTagCache + 4 * i + 1] = 2;
+ }
+
+ TestUnionValues values = new TestUnionValues(tags);
+ CoGbkResult result = new CoGbkResult(createSchema(5), values,
crossTagCache, perTagCache);
+
+ // More than perTagCache values should be cached for the first tag, as
they came first.
+ List<Object> tag0 = Lists.newArrayList(result.getAll("tag0").iterator());
+ List<Object> tag0again =
Lists.newArrayList(result.getAll("tag0").iterator());
+ assertThat(tag0.get(0), sameInstance(tag0again.get(0)));
+ assertThat(tag0.get(integerCacheLimit),
sameInstance(tag0again.get(integerCacheLimit)));
+ assertThat(tag0.get(crossTagCache - 1),
sameInstance(tag0again.get(crossTagCache - 1)));
+ // However, not all elements are cached.
+ assertThat(tag0.get(tag0.size() - 1),
not(sameInstance(tag0again.get(tag0.size() - 1))));
+
+ // For tag 1 and tag 2, we cache perTagCache elements, plus possibly one
more due to peeking
+ // iterators.
+ List<Object> tag1 = Lists.newArrayList(result.getAll("tag1").iterator());
+ List<Object> tag1again =
Lists.newArrayList(result.getAll("tag1").iterator());
+ assertThat(tag1.get(0), sameInstance(tag1again.get(0)));
+ assertThat(tag1.get(perTagCache - 1),
sameInstance(tag1again.get(perTagCache - 1)));
+ assertThat(tag1.get(perTagCache + 1),
not(sameInstance(tag1again.get(perTagCache + 1))));
+
+ List<Object> tag2 = Lists.newArrayList(result.getAll("tag1").iterator());
+ List<Object> tag2again =
Lists.newArrayList(result.getAll("tag1").iterator());
+ assertThat(tag2.get(0), sameInstance(tag2again.get(0)));
+ assertThat(tag2.get(perTagCache - 1),
sameInstance(tag2again.get(perTagCache - 1)));
+ assertThat(tag2.get(perTagCache + 1),
not(sameInstance(tag2again.get(perTagCache + 1))));
+ }
+
+ @Test
+ public void testSingleTag() {
+ runCrazyIteration(1, TEST_CACHE_SIZE / 2);
+ runCrazyIteration(1, TEST_CACHE_SIZE * 2);
+ runCrazyIteration(2, TEST_CACHE_SIZE / 2);
+ runCrazyIteration(2, TEST_CACHE_SIZE * 2);
+ runCrazyIteration(10, TEST_CACHE_SIZE * 10);
+ }
+
+ @Test
+ public void testTwoTags() {
+ runCrazyIteration(1, TEST_CACHE_SIZE / 2, TEST_CACHE_SIZE / 2);
+ runCrazyIteration(1, TEST_CACHE_SIZE * 2, TEST_CACHE_SIZE * 2);
+ runCrazyIteration(2, TEST_CACHE_SIZE / 2, TEST_CACHE_SIZE / 2);
+ runCrazyIteration(2, TEST_CACHE_SIZE * 2, TEST_CACHE_SIZE * 2);
+ runCrazyIteration(10, TEST_CACHE_SIZE * 10);
+ }
+
+ @Test
+ public void testLargeSmall() {
+ runCrazyIteration(1, TEST_CACHE_SIZE / 2, TEST_CACHE_SIZE * 2);
+ runCrazyIteration(1, TEST_CACHE_SIZE / 2, TEST_CACHE_SIZE * 20);
+ runCrazyIteration(2, TEST_CACHE_SIZE / 2, TEST_CACHE_SIZE * 20);
+ runCrazyIteration(10, TEST_CACHE_SIZE / 2, TEST_CACHE_SIZE * 20);
+ }
+
+ @Test
+ public void testManyTags() {
+ runCrazyIteration(1, 2, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100);
+ runCrazyIteration(2, 2, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100);
+ runCrazyIteration(10, 2, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100);
+ }
+
+ public void runCrazyIteration(int numIterations, int... tagSizes) {
+ for (int trial = 0; trial < 10; trial++) {
+ // Populate this with a constant to reproduce failures.
+ int seed = (int) (Integer.MAX_VALUE * Math.random());
+ LOG.info("Running " + Arrays.toString(tagSizes) + " with seed " + seed);
+ Random random = new Random(seed);
+ List<Integer> tags = new ArrayList<>();
+ for (int tagNum = 0; tagNum < tagSizes.length; tagNum++) {
+ for (int i = 0; i < tagSizes[tagNum]; i++) {
+ tags.add(tagNum);
+ }
+ }
+ Collections.shuffle(tags, random);
+
+ Map<TupleTag<Integer>, List<Integer>> expected = new HashMap<>();
+ for (int tagNum = 0; tagNum < tagSizes.length; tagNum++) {
+ expected.put(new TupleTag<>("tag" + tagNum), new ArrayList<>());
+ }
+ for (int i = 0; i < tags.size(); i++) {
+ expected.get(new TupleTag<>("tag" + tags.get(i))).add(i);
+ }
+
+ List<KV<Integer, TupleTag<Integer>>> callOrder = new ArrayList<>();
+ for (int i = 0; i < numIterations; i++) {
+ for (int tagNum = 0; tagNum < tagSizes.length; tagNum++) {
+ for (int k = 0; k < tagSizes[tagNum]; k++) {
+ callOrder.add(KV.of(i, new TupleTag<>("tag" + tagNum)));
+ }
+ }
+ }
+ Collections.shuffle(callOrder, random);
+
+ Map<KV<Integer, TupleTag<Integer>>, Iterator<Integer>> iters = new
HashMap<>();
+ Map<KV<Integer, TupleTag<Integer>>, List<Integer>> actual = new
HashMap<>();
+ TestUnionValues values = new TestUnionValues(tags.stream().mapToInt(i ->
i).toArray());
+ CoGbkResult coGbkResult =
+ new CoGbkResult(createSchema(tagSizes.length), values, 0,
TEST_CACHE_SIZE);
+
+ for (KV<Integer, TupleTag<Integer>> call : callOrder) {
+ if (!iters.containsKey(call)) {
+ iters.put(call, coGbkResult.getAll(call.getValue()).iterator());
+ actual.put(call, new ArrayList<>());
+ }
+ actual.get(call).add(iters.get(call).next());
+ if (random.nextDouble() < 0.5 / numIterations) {
+ iters.get(call).hasNext();
+ }
+ }
+
+ for (Map.Entry<KV<Integer, TupleTag<Integer>>, List<Integer>> result :
actual.entrySet()) {
+ assertThat(result.getValue(),
contains(expected.get(result.getKey().getValue()).toArray()));
+ }
+ }
+ }
+
private CoGbkResultSchema createSchema(int size) {
List<TupleTag<?>> tags = new ArrayList<>();
for (int i = 0; i < size; i++) {