http://git-wip-us.apache.org/repos/asf/flink/blob/05d2138f/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/operators/windows/KeyMap.java ---------------------------------------------------------------------- diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/operators/windows/KeyMap.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/operators/windows/KeyMap.java deleted file mode 100644 index 6e2d75e..0000000 --- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/operators/windows/KeyMap.java +++ /dev/null @@ -1,651 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.streaming.runtime.operators.windows; - -import org.apache.flink.api.common.functions.ReduceFunction; -import org.apache.flink.runtime.util.MathUtils; - -import java.util.Arrays; -import java.util.Comparator; -import java.util.Iterator; -import java.util.NoSuchElementException; - -/** - * A special Hash Map implementation that can be traversed efficiently in sync with other - * hash maps. - * <p> - * The differences between this hash map and Java's "java.util.HashMap" are: - * <ul> - * <li>A different hashing scheme. This implementation uses extensible hashing, meaning that - * each hash table growth takes one more lower hash code bit into account, and values that where - * formerly in the same bucket will afterwards be in the two adjacent buckets.</li> - * <li>This allows an efficient traversal of multiple hash maps together, even though the maps are - * of different sizes.</li> - * <li>The map offers functions such as "putIfAbsent()" and "putOrAggregate()"</li> - * <li>The map supports no removal/shrinking.</li> - * </ul> - */ -public class KeyMap<K, V> implements Iterable<KeyMap.Entry<K, V>> { - - /** The minimum table capacity, 64 entries */ - private static final int MIN_CAPACITY = 0x40; - - /** The maximum possible table capacity, the largest positive power of - * two in the 32bit signed integer value range */ - private static final int MAX_CAPACITY = 0x40000000; - - /** The number of bits used for table addressing when table is at max capacity */ - private static final int FULL_BIT_RANGE = MathUtils.log2strict(MAX_CAPACITY); - - // ------------------------------------------------------------------------ - - /** The hash index, as an array of entries */ - private Entry<K, V>[] table; - - /** The number of bits by which the hash code is shifted right, to find the bucket */ - private int shift; - - /** The number of elements in the hash table */ - private int numElements; - - /** The number of elements above which the hash table needs to grow */ - private int rehashThreshold; - - /** The base-2 logarithm of the table capacity */ - private int log2size; - - // ------------------------------------------------------------------------ - - /** - * Creates a new hash table with the default initial capacity. - */ - public KeyMap() { - this(0); - } - - /** - * Creates a new table with a capacity tailored to the given expected number of elements. - * - * @param expectedNumberOfElements The number of elements to tailor the capacity to. - */ - public KeyMap(int expectedNumberOfElements) { - if (expectedNumberOfElements < 0) { - throw new IllegalArgumentException("Invalid capacity: " + expectedNumberOfElements); - } - - // round up to the next power or two - // guard against too small capacity and integer overflows - int capacity = Integer.highestOneBit(expectedNumberOfElements) << 1; - capacity = capacity >= 0 ? Math.max(MIN_CAPACITY, capacity) : MAX_CAPACITY; - - // this also acts as a sanity check - log2size = MathUtils.log2strict(capacity); - shift = FULL_BIT_RANGE - log2size; - table = allocateTable(capacity); - rehashThreshold = getRehashThreshold(capacity); - } - - // ------------------------------------------------------------------------ - // Gets and Puts - // ------------------------------------------------------------------------ - - /** - * Inserts the given value, mapped under the given key. If the table already contains a value for - * the key, the value is replaced and returned. If no value is contained, yet, the function - * returns null. - * - * @param key The key to insert. - * @param value The value to insert. - * @return The previously mapped value for the key, or null, if no value was mapped for the key. - * - * @throws java.lang.NullPointerException Thrown, if the key is null. - */ - public final V put(K key, V value) { - final int hash = hash(key); - final int slot = indexOf (hash); - - // search the chain from the slot - for (Entry<K, V> e = table[slot]; e != null; e = e.next) { - Object k; - if (e.hashCode == hash && ((k = e.key) == key || key.equals(k))) { - // found match - V old = e.value; - e.value = value; - return old; - } - } - - // no match, insert a new value - insertNewEntry(hash, key, value, slot); - return null; - } - - /** - * Inserts a value for the given key, if no value is yet contained for that key. Otherwise, - * returns the value currently contained for the key. - * <p> - * The value that is inserted in case that the key is not contained, yet, is lazily created - * using the given factory. - * - * @param key The key to insert. - * @param factory The factory that produces the value, if no value is contained, yet, for the key. - * @return The value in the map after this operation (either the previously contained value, or the - * newly created value). - * - * @throws java.lang.NullPointerException Thrown, if the key is null. - */ - public final V putIfAbsent(K key, LazyFactory<V> factory) { - final int hash = hash(key); - final int slot = indexOf(hash); - - // search the chain from the slot - for (Entry<K, V> entry = table[slot]; entry != null; entry = entry.next) { - if (entry.hashCode == hash && entry.key.equals(key)) { - // found match - return entry.value; - } - } - - // no match, insert a new value - V value = factory.create(); - insertNewEntry(hash, key, value, slot); - - // return the created value - return value; - } - - /** - * Inserts or aggregates a value into the hash map. If the hash map does not yet contain the key, - * this method inserts the value. If the table already contains the key (and a value) this - * method will use the given ReduceFunction function to combine the existing value and the - * given value to a new value, and store that value for the key. - * - * @param key The key to map the value. - * @param value The new value to insert, or aggregate with the existing value. - * @param aggregator The aggregator to use if a value is already contained. - * - * @return The value in the map after this operation: Either the given value, or the aggregated value. - * - * @throws java.lang.NullPointerException Thrown, if the key is null. - * @throws Exception The method forwards exceptions from the aggregation function. - */ - public final V putOrAggregate(K key, V value, ReduceFunction<V> aggregator) throws Exception { - final int hash = hash(key); - final int slot = indexOf(hash); - - // search the chain from the slot - for (Entry<K, V> entry = table[slot]; entry != null; entry = entry.next) { - if (entry.hashCode == hash && entry.key.equals(key)) { - // found match - entry.value = aggregator.reduce(entry.value, value); - return entry.value; - } - } - - // no match, insert a new value - insertNewEntry(hash, key, value, slot); - // return the original value - return value; - } - - /** - * Looks up the value mapped under the given key. Returns null if no value is mapped under this key. - * - * @param key The key to look up. - * @return The value associated with the key, or null, if no value is found for the key. - * - * @throws java.lang.NullPointerException Thrown, if the key is null. - */ - public V get(K key) { - final int hash = hash(key); - final int slot = indexOf(hash); - - // search the chain from the slot - for (Entry<K, V> entry = table[slot]; entry != null; entry = entry.next) { - if (entry.hashCode == hash && entry.key.equals(key)) { - return entry.value; - } - } - - // not found - return null; - } - - private void insertNewEntry(int hashCode, K key, V value, int position) { - Entry<K,V> e = table[position]; - table[position] = new Entry<>(key, value, hashCode, e); - numElements++; - - // rehash if necessary - if (numElements > rehashThreshold) { - growTable(); - } - } - - private int indexOf(int hashCode) { - return (hashCode >> shift) & (table.length - 1); - } - - /** - * Creates an iterator over the entries of this map. - * - * @return An iterator over the entries of this map. - */ - @Override - public Iterator<Entry<K, V>> iterator() { - return new Iterator<Entry<K, V>>() { - - private final Entry<K, V>[] tab = KeyMap.this.table; - - private Entry<K, V> nextEntry; - - private int nextPos = 0; - - @Override - public boolean hasNext() { - if (nextEntry != null) { - return true; - } - else { - while (nextPos < tab.length) { - Entry<K, V> e = tab[nextPos++]; - if (e != null) { - nextEntry = e; - return true; - } - } - return false; - } - } - - @Override - public Entry<K, V> next() { - if (nextEntry != null || hasNext()) { - Entry<K, V> e = nextEntry; - nextEntry = nextEntry.next; - return e; - } - else { - throw new NoSuchElementException(); - } - } - - @Override - public void remove() { - throw new UnsupportedOperationException(); - } - }; - } - - // ------------------------------------------------------------------------ - // Properties - // ------------------------------------------------------------------------ - - /** - * Gets the number of elements currently in the map. - * @return The number of elements currently in the map. - */ - public int size() { - return numElements; - } - - /** - * Checks whether the map is empty. - * @return True, if the map is empty, false otherwise. - */ - public boolean isEmpty() { - return numElements == 0; - } - - /** - * Gets the current table capacity, i.e., the number of slots in the hash table, without - * and overflow chaining. - * @return The number of slots in the hash table. - */ - public int getCurrentTableCapacity() { - return table.length; - } - - /** - * Gets the base-2 logarithm of the hash table capacity, as returned by - * {@link #getCurrentTableCapacity()}. - * - * @return The base-2 logarithm of the hash table capacity. - */ - public int getLog2TableCapacity() { - return log2size; - } - - public int getRehashThreshold() { - return rehashThreshold; - } - - public int getShift() { - return shift; - } - - // ------------------------------------------------------------------------ - // Utilities - // ------------------------------------------------------------------------ - - @SuppressWarnings("unchecked") - private Entry<K, V>[] allocateTable(int numElements) { - return (Entry<K, V>[]) new Entry<?, ?>[numElements]; - } - - private void growTable() { - final int newSize = table.length << 1; - - // only grow if there is still space to grow the table - if (newSize > 0) { - final Entry<K, V>[] oldTable = this.table; - final Entry<K, V>[] newTable = allocateTable(newSize); - - final int newShift = shift - 1; - final int newMask = newSize - 1; - - // go over all slots from the table. since we hash to adjacent positions in - // the new hash table, this is actually cache efficient - for (Entry<K, V> entry : oldTable) { - // traverse the chain for each slot - while (entry != null) { - final int newPos = (entry.hashCode >> newShift) & newMask; - Entry<K, V> nextEntry = entry.next; - entry.next = newTable[newPos]; - newTable[newPos] = entry; - entry = nextEntry; - } - } - - this.table = newTable; - this.shift = newShift; - this.rehashThreshold = getRehashThreshold(newSize); - this.log2size += 1; - } - } - - private static int hash(Object key) { - int code = key.hashCode(); - - // we need a strong hash function that generates diverse upper bits - // this hash function is more expensive than the "scramble" used by "java.util.HashMap", - // but required for this sort of hash table - code = (code + 0x7ed55d16) + (code << 12); - code = (code ^ 0xc761c23c) ^ (code >>> 19); - code = (code + 0x165667b1) + (code << 5); - code = (code + 0xd3a2646c) ^ (code << 9); - code = (code + 0xfd7046c5) + (code << 3); - return (code ^ 0xb55a4f09) ^ (code >>> 16); - } - - private static int getRehashThreshold(int capacity) { - // divide before multiply, to avoid overflow - return capacity / 4 * 3; - } - - // ------------------------------------------------------------------------ - // Testing Utilities - // ------------------------------------------------------------------------ - - /** - * For testing only: Actively counts the number of entries, rather than using the - * counter variable. This method has linear complexity, rather than constant. - * - * @return The counted number of entries. - */ - int traverseAndCountElements() { - int num = 0; - - for (Entry<?, ?> entry : table) { - while (entry != null) { - num++; - entry = entry.next; - } - } - - return num; - } - - /** - * For testing only: Gets the length of the longest overflow chain. - * This method has linear complexity. - * - * @return The length of the longest overflow chain. - */ - int getLongestChainLength() { - int maxLen = 0; - - for (Entry<?, ?> entry : table) { - int thisLen = 0; - while (entry != null) { - thisLen++; - entry = entry.next; - } - maxLen = Math.max(maxLen, thisLen); - } - - return maxLen; - } - - // ------------------------------------------------------------------------ - - /** - * An entry in the hash table. - * - * @param <K> Type of the key. - * @param <V> Type of the value. - */ - public static final class Entry<K, V> { - - final K key; - final int hashCode; - - V value; - Entry<K, V> next; - long touchedTag; - - Entry(K key, V value, int hashCode, Entry<K, V> next) { - this.key = key; - this.value = value; - this.next = next; - this.hashCode = hashCode; - } - - public K getKey() { - return key; - } - - public V getValue() { - return value; - } - } - - // ------------------------------------------------------------------------ - - /** - * Performs a traversal about logical the multi-map that results from the union of the - * given maps. This method does not actually build a union of the map, but traverses the hash maps - * together. - * - * @param maps The array uf maps whose union should be traversed. - * @param visitor The visitor that is called for each key and all values. - * @param touchedTag A tag that is used to mark elements that have been touched in this specific - * traversal. Each successive traversal should supply a larger value for this - * tag than the previous one. - * - * @param <K> The type of the map's key. - * @param <V> The type of the map's value. - */ - public static <K, V> void traverseMaps( - final KeyMap<K, V>[] maps, - final TraversalEvaluator<K, V> visitor, - final long touchedTag) - throws Exception - { - // we need to work on the maps in descending size - Arrays.sort(maps, CapacityDescendingComparator.INSTANCE); - - final int[] shifts = new int[maps.length]; - final int[] lowBitsMask = new int[maps.length]; - final int numSlots = maps[0].table.length; - final int numTables = maps.length; - - // figure out how much each hash table collapses the entries - for (int i = 0; i < numTables; i++) { - shifts[i] = maps[0].log2size - maps[i].log2size; - lowBitsMask[i] = (1 << shifts[i]) - 1; - } - - // go over all slots (based on the largest hash table) - for (int pos = 0; pos < numSlots; pos++) { - - // for each slot, go over all tables, until the table does not have that slot any more - // for tables where multiple slots collapse into one, we visit that one when we process the - // latest of all slots that collapse to that one - int mask; - for (int rootTable = 0; - rootTable < numTables && ((mask = lowBitsMask[rootTable]) & pos) == mask; - rootTable++) - { - // use that table to gather keys and start collecting keys from the following tables - // go over all entries of that slot in the table - Entry<K, V> entry = maps[rootTable].table[pos >> shifts[rootTable]]; - while (entry != null) { - // take only entries that have not been collected as part of other tables - if (entry.touchedTag < touchedTag) { - entry.touchedTag = touchedTag; - - final K key = entry.key; - final int hashCode = entry.hashCode; - visitor.startNewKey(key); - visitor.nextValue(entry.value); - - addEntriesFromChain(entry.next, visitor, key, touchedTag, hashCode); - - // go over the other hash tables and collect their entries for the key - for (int followupTable = rootTable + 1; followupTable < numTables; followupTable++) { - Entry<K, V> followupEntry = maps[followupTable].table[pos >> shifts[followupTable]]; - if (followupEntry != null) { - addEntriesFromChain(followupEntry, visitor, key, touchedTag, hashCode); - } - } - - visitor.keyDone(); - } - - entry = entry.next; - } - } - } - } - - private static <K, V> void addEntriesFromChain( - Entry<K, V> entry, - TraversalEvaluator<K, V> visitor, - K key, - long touchedTag, - int hashCode) throws Exception - { - while (entry != null) { - if (entry.touchedTag < touchedTag && entry.hashCode == hashCode && entry.key.equals(key)) { - entry.touchedTag = touchedTag; - visitor.nextValue(entry.value); - } - entry = entry.next; - } - } - - // ------------------------------------------------------------------------ - - /** - * Comparator that defines a descending order on maps depending on their table capacity - * and number of elements. - */ - static final class CapacityDescendingComparator implements Comparator<KeyMap<?, ?>> { - - static final CapacityDescendingComparator INSTANCE = new CapacityDescendingComparator(); - - private CapacityDescendingComparator() {} - - - @Override - public int compare(KeyMap<?, ?> o1, KeyMap<?, ?> o2) { - // this sorts descending - int cmp = o2.getLog2TableCapacity() - o1.getLog2TableCapacity(); - if (cmp != 0) { - return cmp; - } - else { - return o2.size() - o1.size(); - } - } - } - - // ------------------------------------------------------------------------ - - /** - * A factory for lazy/on-demand instantiation of values. - * - * @param <V> The type created by the factory. - */ - public static interface LazyFactory<V> { - - /** - * The factory method; creates the value. - * @return The value. - */ - V create(); - } - - // ------------------------------------------------------------------------ - - /** - * A visitor for a traversal over the union of multiple hash maps. The visitor is - * called for each key in the union of the maps and all values associated with that key - * (one per map, but multiple across maps). - * - * @param <K> The type of the key. - * @param <V> The type of the value. - */ - public static interface TraversalEvaluator<K, V> { - - /** - * Called whenever the traversal starts with a new key. - * - * @param key The key traversed. - * @throws Exception Method forwards all exceptions. - */ - void startNewKey(K key) throws Exception; - - /** - * Called for each value found for the current key. - * - * @param value The next value. - * @throws Exception Method forwards all exceptions. - */ - void nextValue(V value) throws Exception; - - /** - * Called when the traversal for the current key is complete. - * - * @throws Exception Method forwards all exceptions. - */ - void keyDone() throws Exception; - } -}
http://git-wip-us.apache.org/repos/asf/flink/blob/05d2138f/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/operators/windows/PolicyToOperator.java ---------------------------------------------------------------------- diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/operators/windows/PolicyToOperator.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/operators/windows/PolicyToOperator.java deleted file mode 100644 index 9d06ef5..0000000 --- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/operators/windows/PolicyToOperator.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.streaming.runtime.operators.windows; - -import org.apache.flink.api.common.functions.Function; -import org.apache.flink.api.common.functions.ReduceFunction; -import org.apache.flink.api.java.functions.KeySelector; -import org.apache.flink.streaming.api.functions.windows.KeyedWindowFunction; -import org.apache.flink.streaming.api.operators.OneInputStreamOperator; -import org.apache.flink.streaming.api.windowing.windowpolicy.EventTime; -import org.apache.flink.streaming.api.windowing.windowpolicy.ProcessingTime; -import org.apache.flink.streaming.api.windowing.windowpolicy.WindowPolicy; - -/** - * This class implements the conversion from window policies to concrete operator - * implementations. - */ -public class PolicyToOperator { - - /** - * Entry point to create an operator for the given window policies and the window function. - */ - public static <IN, OUT, KEY> OneInputStreamOperator<IN, OUT> createOperatorForPolicies( - WindowPolicy window, WindowPolicy slide, Function function, KeySelector<IN, KEY> keySelector) - { - if (window == null || function == null) { - throw new NullPointerException(); - } - - // -- case 1: both policies are processing time policies - if (window instanceof ProcessingTime && (slide == null || slide instanceof ProcessingTime)) { - final long windowLength = ((ProcessingTime) window).toMilliseconds(); - final long windowSlide = slide == null ? windowLength : ((ProcessingTime) slide).toMilliseconds(); - - if (function instanceof ReduceFunction) { - @SuppressWarnings("unchecked") - ReduceFunction<IN> reducer = (ReduceFunction<IN>) function; - - @SuppressWarnings("unchecked") - OneInputStreamOperator<IN, OUT> op = (OneInputStreamOperator<IN, OUT>) - new AggregatingProcessingTimeWindowOperator<KEY, IN>( - reducer, keySelector, windowLength, windowSlide); - return op; - } - else if (function instanceof KeyedWindowFunction) { - @SuppressWarnings("unchecked") - KeyedWindowFunction<IN, OUT, KEY> wf = (KeyedWindowFunction<IN, OUT, KEY>) function; - - return new AccumulatingProcessingTimeWindowOperator<KEY, IN, OUT>( - wf, keySelector, windowLength, windowSlide); - } - } - - // -- case 2: both policies are event time policies - if (window instanceof EventTime && (slide == null || slide instanceof EventTime)) { - // add event time implementation - } - - throw new UnsupportedOperationException("The windowing mechanism does not yet support " + window.toString(slide)); - } - - // ------------------------------------------------------------------------ - - /** Don't instantiate */ - private PolicyToOperator() {} -} http://git-wip-us.apache.org/repos/asf/flink/blob/05d2138f/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/operators/windows/package-info.java ---------------------------------------------------------------------- diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/operators/windows/package-info.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/operators/windows/package-info.java deleted file mode 100644 index 63ed470..0000000 --- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/operators/windows/package-info.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * This package contains the operators that implement the various window operations - * on data streams. - */ -package org.apache.flink.streaming.runtime.operators.windows; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/flink/blob/05d2138f/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java ---------------------------------------------------------------------- diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java new file mode 100644 index 0000000..bcf02c5 --- /dev/null +++ b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java @@ -0,0 +1,547 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.runtime.operators.windowing; + +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.streaming.api.functions.windows.KeyedWindowFunction; +import org.apache.flink.streaming.api.operators.Output; +import org.apache.flink.streaming.runtime.operators.Triggerable; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.runtime.tasks.StreamTask; +import org.apache.flink.streaming.runtime.tasks.StreamingRuntimeContext; + +import org.apache.flink.util.Collector; +import org.junit.After; +import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; + +import static org.mockito.Mockito.*; +import static org.junit.Assert.*; + +@SuppressWarnings("serial") +public class AccumulatingAlignedProcessingTimeWindowOperatorTest { + + @SuppressWarnings("unchecked") + private final KeyedWindowFunction<String, String, String> mockFunction = mock(KeyedWindowFunction.class); + + @SuppressWarnings("unchecked") + private final KeySelector<String, String> mockKeySelector = mock(KeySelector.class); + + private final KeySelector<Integer, Integer> identitySelector = new KeySelector<Integer, Integer>() { + @Override + public Integer getKey(Integer value) { + return value; + } + }; + + private final KeyedWindowFunction<Integer, Integer, Integer> validatingIdentityFunction = + new KeyedWindowFunction<Integer, Integer, Integer>() + { + @Override + public void evaluate(Integer key, Iterable<Integer> values, Collector<Integer> out) { + for (Integer val : values) { + assertEquals(key, val); + out.collect(val); + } + } + }; + + // ------------------------------------------------------------------------ + + @After + public void checkNoTriggerThreadsRunning() { + // make sure that all the threads we trigger are shut down + long deadline = System.currentTimeMillis() + 5000; + while (StreamTask.TRIGGER_THREAD_GROUP.activeCount() > 0 && System.currentTimeMillis() < deadline) { + try { + Thread.sleep(10); + } + catch (InterruptedException ignored) {} + } + + assertTrue("Not all trigger threads where properly shut down", + StreamTask.TRIGGER_THREAD_GROUP.activeCount() == 0); + } + + // ------------------------------------------------------------------------ + + @Test + public void testInvalidParameters() { + try { + assertInvalidParameter(-1L, -1L); + assertInvalidParameter(10000L, -1L); + assertInvalidParameter(-1L, 1000L); + assertInvalidParameter(1000L, 2000L); + + // actual internal slide is too low here: + assertInvalidParameter(1000L, 999L); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testWindowSizeAndSlide() { + try { + AbstractAlignedProcessingTimeWindowOperator<String, String, String> op; + + op = new AccumulatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, 5000, 1000); + assertEquals(5000, op.getWindowSize()); + assertEquals(1000, op.getWindowSlide()); + assertEquals(1000, op.getPaneSize()); + assertEquals(5, op.getNumPanesPerWindow()); + + op = new AccumulatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, 1000, 1000); + assertEquals(1000, op.getWindowSize()); + assertEquals(1000, op.getWindowSlide()); + assertEquals(1000, op.getPaneSize()); + assertEquals(1, op.getNumPanesPerWindow()); + + op = new AccumulatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, 1500, 1000); + assertEquals(1500, op.getWindowSize()); + assertEquals(1000, op.getWindowSlide()); + assertEquals(500, op.getPaneSize()); + assertEquals(3, op.getNumPanesPerWindow()); + + op = new AccumulatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, 1200, 1100); + assertEquals(1200, op.getWindowSize()); + assertEquals(1100, op.getWindowSlide()); + assertEquals(100, op.getPaneSize()); + assertEquals(12, op.getNumPanesPerWindow()); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testWindowTriggerTimeAlignment() { + try { + @SuppressWarnings("unchecked") + final Output<StreamRecord<String>> mockOut = mock(Output.class); + + final StreamingRuntimeContext mockContext = mock(StreamingRuntimeContext.class); + when(mockContext.getTaskName()).thenReturn("Test task name"); + + AbstractAlignedProcessingTimeWindowOperator<String, String, String> op; + + op = new AccumulatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, 5000, 1000); + op.setup(mockOut, mockContext); + op.open(new Configuration()); + assertTrue(op.getNextSlideTime() % 1000 == 0); + assertTrue(op.getNextEvaluationTime() % 1000 == 0); + op.dispose(); + + op = new AccumulatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, 1000, 1000); + op.setup(mockOut, mockContext); + op.open(new Configuration()); + assertTrue(op.getNextSlideTime() % 1000 == 0); + assertTrue(op.getNextEvaluationTime() % 1000 == 0); + op.dispose(); + + op = new AccumulatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, 1500, 1000); + op.setup(mockOut, mockContext); + op.open(new Configuration()); + assertTrue(op.getNextSlideTime() % 500 == 0); + assertTrue(op.getNextEvaluationTime() % 1000 == 0); + op.dispose(); + + op = new AccumulatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, 1200, 1100); + op.setup(mockOut, mockContext); + op.open(new Configuration()); + assertTrue(op.getNextSlideTime() % 100 == 0); + assertTrue(op.getNextEvaluationTime() % 1100 == 0); + op.dispose(); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testTumblingWindow() { + try { + final int windowSize = 50; + final CollectingOutput<Integer> out = new CollectingOutput<>(windowSize); + + final StreamingRuntimeContext mockContext = mock(StreamingRuntimeContext.class); + when(mockContext.getTaskName()).thenReturn("Test task name"); + + // tumbling window that triggers every 20 milliseconds + AbstractAlignedProcessingTimeWindowOperator<Integer, Integer, Integer> op = + new AccumulatingProcessingTimeWindowOperator<>( + validatingIdentityFunction, identitySelector, windowSize, windowSize); + + op.setup(out, mockContext); + op.open(new Configuration()); + + final int numElements = 1000; + + for (int i = 0; i < numElements; i++) { + op.processElement(new StreamRecord<Integer>(i)); + Thread.sleep(1); + } + + op.close(); + op.dispose(); + + // get and verify the result + List<Integer> result = out.getElements(); + assertEquals(numElements, result.size()); + + Collections.sort(result); + for (int i = 0; i < numElements; i++) { + assertEquals(i, result.get(i).intValue()); + } + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testSlidingWindow() { + try { + final CollectingOutput<Integer> out = new CollectingOutput<>(50); + + final StreamingRuntimeContext mockContext = mock(StreamingRuntimeContext.class); + when(mockContext.getTaskName()).thenReturn("Test task name"); + + // tumbling window that triggers every 20 milliseconds + AbstractAlignedProcessingTimeWindowOperator<Integer, Integer, Integer> op = + new AccumulatingProcessingTimeWindowOperator<>(validatingIdentityFunction, identitySelector, 150, 50); + + op.setup(out, mockContext); + op.open(new Configuration()); + + final int numElements = 1000; + + for (int i = 0; i < numElements; i++) { + op.processElement(new StreamRecord<Integer>(i)); + Thread.sleep(1); + } + + op.close(); + op.dispose(); + + // get and verify the result + List<Integer> result = out.getElements(); + + // if we kept this running, each element would be in the result three times (for each slide). + // we are closing the window before the final panes are through three times, so we may have less + // elements. + if (result.size() < numElements || result.size() > 3 * numElements) { + fail("Wrong number of results: " + result.size()); + } + + Collections.sort(result); + int lastNum = -1; + int lastCount = -1; + + for (int num : result) { + if (num == lastNum) { + lastCount++; + assertTrue(lastCount <= 3); + } + else { + lastNum = num; + lastCount = 1; + } + } + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testTumblingWindowSingleElements() { + final ScheduledExecutorService timerService = Executors.newSingleThreadScheduledExecutor(); + + try { + final CollectingOutput<Integer> out = new CollectingOutput<>(50); + + final StreamingRuntimeContext mockContext = mock(StreamingRuntimeContext.class); + when(mockContext.getTaskName()).thenReturn("Test task name"); + + final Object lock = new Object(); + + doAnswer(new Answer() { + @Override + public Object answer(InvocationOnMock invocationOnMock) throws Throwable { + final Long timestamp = (Long) invocationOnMock.getArguments()[0]; + final Triggerable target = (Triggerable) invocationOnMock.getArguments()[1]; + timerService.schedule( + new Callable<Object>() { + @Override + public Object call() throws Exception { + synchronized (lock) { + target.trigger(timestamp); + } + return null; + } + }, + timestamp - System.currentTimeMillis(), + TimeUnit.MILLISECONDS); + return null; + } + }).when(mockContext).registerTimer(anyLong(), any(Triggerable.class)); + + // tumbling window that triggers every 20 milliseconds + AbstractAlignedProcessingTimeWindowOperator<Integer, Integer, Integer> op = + new AccumulatingProcessingTimeWindowOperator<>(validatingIdentityFunction, identitySelector, 50, 50); + + op.setup(out, mockContext); + op.open(new Configuration()); + + synchronized (lock) { + op.processElement(new StreamRecord<Integer>(1)); + op.processElement(new StreamRecord<Integer>(2)); + } + out.waitForNElements(2, 60000); + + synchronized (lock) { + op.processElement(new StreamRecord<Integer>(3)); + op.processElement(new StreamRecord<Integer>(4)); + op.processElement(new StreamRecord<Integer>(5)); + } + out.waitForNElements(5, 60000); + + synchronized (lock) { + op.processElement(new StreamRecord<Integer>(6)); + } + out.waitForNElements(6, 60000); + + List<Integer> result = out.getElements(); + assertEquals(6, result.size()); + + Collections.sort(result); + assertEquals(Arrays.asList(1, 2, 3, 4, 5, 6), result); + + op.close(); + op.dispose(); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } finally { + timerService.shutdown(); + } + } + + @Test + public void testSlidingWindowSingleElements() { + final ScheduledExecutorService timerService = Executors.newSingleThreadScheduledExecutor(); + + try { + final CollectingOutput<Integer> out = new CollectingOutput<>(50); + + final StreamingRuntimeContext mockContext = mock(StreamingRuntimeContext.class); + when(mockContext.getTaskName()).thenReturn("Test task name"); + + final Object lock = new Object(); + + doAnswer(new Answer() { + @Override + public Object answer(InvocationOnMock invocationOnMock) throws Throwable { + final Long timestamp = (Long) invocationOnMock.getArguments()[0]; + final Triggerable target = (Triggerable) invocationOnMock.getArguments()[1]; + timerService.schedule( + new Callable<Object>() { + @Override + public Object call() throws Exception { + synchronized (lock) { + target.trigger(timestamp); + } + return null; + } + }, + timestamp - System.currentTimeMillis(), + TimeUnit.MILLISECONDS); + return null; + } + }).when(mockContext).registerTimer(anyLong(), any(Triggerable.class)); + + // tumbling window that triggers every 20 milliseconds + AbstractAlignedProcessingTimeWindowOperator<Integer, Integer, Integer> op = + new AccumulatingProcessingTimeWindowOperator<>(validatingIdentityFunction, identitySelector, 150, 50); + + op.setup(out, mockContext); + op.open(new Configuration()); + + synchronized (lock) { + op.processElement(new StreamRecord<Integer>(1)); + op.processElement(new StreamRecord<Integer>(2)); + } + + // each element should end up in the output three times + // wait until the elements have arrived 6 times in the output + out.waitForNElements(6, 120000); + + List<Integer> result = out.getElements(); + assertEquals(6, result.size()); + + Collections.sort(result); + assertEquals(Arrays.asList(1, 1, 1, 2, 2, 2), result); + + op.close(); + op.dispose(); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } finally { + timerService.shutdown(); + } + } + + @Test + public void testEmitTrailingDataOnClose() { + try { + final CollectingOutput<Integer> out = new CollectingOutput<>(); + + final StreamingRuntimeContext mockContext = mock(StreamingRuntimeContext.class); + when(mockContext.getTaskName()).thenReturn("Test task name"); + + // the operator has a window time that is so long that it will not fire in this test + final long oneYear = 365L * 24 * 60 * 60 * 1000; + AbstractAlignedProcessingTimeWindowOperator<Integer, Integer, Integer> op = + new AccumulatingProcessingTimeWindowOperator<>(validatingIdentityFunction, identitySelector, + oneYear, oneYear); + + op.setup(out, mockContext); + op.open(new Configuration()); + + List<Integer> data = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); + for (Integer i : data) { + op.processElement(new StreamRecord<Integer>(i)); + } + + op.close(); + op.dispose(); + + // get and verify the result + List<Integer> result = out.getElements(); + Collections.sort(result); + assertEquals(data, result); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testPropagateExceptionsFromClose() { + try { + final CollectingOutput<Integer> out = new CollectingOutput<>(); + + final StreamingRuntimeContext mockContext = mock(StreamingRuntimeContext.class); + when(mockContext.getTaskName()).thenReturn("Test task name"); + + KeyedWindowFunction<Integer, Integer, Integer> failingFunction = new FailingFunction(100); + + // the operator has a window time that is so long that it will not fire in this test + final long hundredYears = 100L * 365 * 24 * 60 * 60 * 1000; + AbstractAlignedProcessingTimeWindowOperator<Integer, Integer, Integer> op = + new AccumulatingProcessingTimeWindowOperator<>( + failingFunction, identitySelector, hundredYears, hundredYears); + + op.setup(out, mockContext); + op.open(new Configuration()); + + for (int i = 0; i < 150; i++) { + op.processElement(new StreamRecord<Integer>(i)); + } + + try { + op.close(); + fail("This should fail with an exception"); + } + catch (Exception e) { + assertTrue( + e.getMessage().contains("Artificial Test Exception") || + (e.getCause() != null && e.getCause().getMessage().contains("Artificial Test Exception"))); + } + + op.dispose(); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + // ------------------------------------------------------------------------ + + private void assertInvalidParameter(long windowSize, long windowSlide) { + try { + new AccumulatingProcessingTimeWindowOperator<String, String, String>( + mockFunction, mockKeySelector, windowSize, windowSlide); + fail("This should fail with an IllegalArgumentException"); + } + catch (IllegalArgumentException e) { + // expected + } + catch (Exception e) { + fail("Wrong exception. Expected IllegalArgumentException but found " + e.getClass().getSimpleName()); + } + } + + // ------------------------------------------------------------------------ + + private static class FailingFunction implements KeyedWindowFunction<Integer, Integer, Integer> { + + private final int failAfterElements; + + private int numElements; + + FailingFunction(int failAfterElements) { + this.failAfterElements = failAfterElements; + } + + @Override + public void evaluate(Integer integer, Iterable<Integer> values, Collector<Integer> out) throws Exception { + for (Integer i : values) { + out.collect(i); + numElements++; + + if (numElements >= failAfterElements) { + throw new Exception("Artificial Test Exception"); + } + } + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/05d2138f/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java ---------------------------------------------------------------------- diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java new file mode 100644 index 0000000..7ad9dd4 --- /dev/null +++ b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java @@ -0,0 +1,550 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.runtime.operators.windowing; + +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.streaming.api.operators.Output; +import org.apache.flink.streaming.runtime.operators.Triggerable; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.runtime.tasks.StreamTask; +import org.apache.flink.streaming.runtime.tasks.StreamingRuntimeContext; + +import org.junit.After; +import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyLong; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +@SuppressWarnings("serial") +public class AggregatingAlignedProcessingTimeWindowOperatorTest { + + @SuppressWarnings("unchecked") + private final ReduceFunction<String> mockFunction = mock(ReduceFunction.class); + + @SuppressWarnings("unchecked") + private final KeySelector<String, String> mockKeySelector = mock(KeySelector.class); + + private final KeySelector<Integer, Integer> identitySelector = new KeySelector<Integer, Integer>() { + @Override + public Integer getKey(Integer value) { + return value; + } + }; + + private final ReduceFunction<Integer> sumFunction = new ReduceFunction<Integer>() { + @Override + public Integer reduce(Integer value1, Integer value2) { + return value1 + value2; + } + }; + + // ------------------------------------------------------------------------ + + @After + public void checkNoTriggerThreadsRunning() { + // make sure that all the threads we trigger are shut down + long deadline = System.currentTimeMillis() + 5000; + while (StreamTask.TRIGGER_THREAD_GROUP.activeCount() > 0 && System.currentTimeMillis() < deadline) { + try { + Thread.sleep(10); + } + catch (InterruptedException ignored) {} + } + + assertTrue("Not all trigger threads where properly shut down", + StreamTask.TRIGGER_THREAD_GROUP.activeCount() == 0); + } + + // ------------------------------------------------------------------------ + + @Test + public void testInvalidParameters() { + try { + assertInvalidParameter(-1L, -1L); + assertInvalidParameter(10000L, -1L); + assertInvalidParameter(-1L, 1000L); + assertInvalidParameter(1000L, 2000L); + + // actual internal slide is too low here: + assertInvalidParameter(1000L, 999L); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testWindowSizeAndSlide() { + try { + AbstractAlignedProcessingTimeWindowOperator<String, String, String> op; + + op = new AggregatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, 5000, 1000); + assertEquals(5000, op.getWindowSize()); + assertEquals(1000, op.getWindowSlide()); + assertEquals(1000, op.getPaneSize()); + assertEquals(5, op.getNumPanesPerWindow()); + + op = new AggregatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, 1000, 1000); + assertEquals(1000, op.getWindowSize()); + assertEquals(1000, op.getWindowSlide()); + assertEquals(1000, op.getPaneSize()); + assertEquals(1, op.getNumPanesPerWindow()); + + op = new AggregatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, 1500, 1000); + assertEquals(1500, op.getWindowSize()); + assertEquals(1000, op.getWindowSlide()); + assertEquals(500, op.getPaneSize()); + assertEquals(3, op.getNumPanesPerWindow()); + + op = new AggregatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, 1200, 1100); + assertEquals(1200, op.getWindowSize()); + assertEquals(1100, op.getWindowSlide()); + assertEquals(100, op.getPaneSize()); + assertEquals(12, op.getNumPanesPerWindow()); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testWindowTriggerTimeAlignment() { + try { + @SuppressWarnings("unchecked") + final Output<StreamRecord<String>> mockOut = mock(Output.class); + + final StreamingRuntimeContext mockContext = mock(StreamingRuntimeContext.class); + when(mockContext.getTaskName()).thenReturn("Test task name"); + + AbstractAlignedProcessingTimeWindowOperator<String, String, String> op; + + op = new AggregatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, 5000, 1000); + op.setup(mockOut, mockContext); + op.open(new Configuration()); + assertTrue(op.getNextSlideTime() % 1000 == 0); + assertTrue(op.getNextEvaluationTime() % 1000 == 0); + op.dispose(); + + op = new AggregatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, 1000, 1000); + op.setup(mockOut, mockContext); + op.open(new Configuration()); + assertTrue(op.getNextSlideTime() % 1000 == 0); + assertTrue(op.getNextEvaluationTime() % 1000 == 0); + op.dispose(); + + op = new AggregatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, 1500, 1000); + op.setup(mockOut, mockContext); + op.open(new Configuration()); + assertTrue(op.getNextSlideTime() % 500 == 0); + assertTrue(op.getNextEvaluationTime() % 1000 == 0); + op.dispose(); + + op = new AggregatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, 1200, 1100); + op.setup(mockOut, mockContext); + op.open(new Configuration()); + assertTrue(op.getNextSlideTime() % 100 == 0); + assertTrue(op.getNextEvaluationTime() % 1100 == 0); + op.dispose(); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testTumblingWindowUniqueElements() { + try { + final int windowSize = 50; + final CollectingOutput<Integer> out = new CollectingOutput<>(windowSize); + + final StreamingRuntimeContext mockContext = mock(StreamingRuntimeContext.class); + when(mockContext.getTaskName()).thenReturn("Test task name"); + + AggregatingProcessingTimeWindowOperator<Integer, Integer> op = + new AggregatingProcessingTimeWindowOperator<>( + sumFunction, identitySelector, windowSize, windowSize); + + op.setup(out, mockContext); + op.open(new Configuration()); + + final int numElements = 1000; + + for (int i = 0; i < numElements; i++) { + op.processElement(new StreamRecord<Integer>(i)); + Thread.sleep(1); + } + + op.close(); + op.dispose(); + + // get and verify the result + List<Integer> result = out.getElements(); + assertEquals(numElements, result.size()); + + Collections.sort(result); + for (int i = 0; i < numElements; i++) { + assertEquals(i, result.get(i).intValue()); + } + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testTumblingWindowDuplicateElements() { + + final ScheduledExecutorService timerService = Executors.newSingleThreadScheduledExecutor(); + + try { + final int windowSize = 50; + final CollectingOutput<Integer> out = new CollectingOutput<>(windowSize); + + final StreamingRuntimeContext mockContext = mock(StreamingRuntimeContext.class); + when(mockContext.getTaskName()).thenReturn("Test task name"); + + final Object lock = new Object(); + doAnswer(new Answer() { + @Override + public Object answer(InvocationOnMock invocationOnMock) throws Throwable { + final Long timestamp = (Long) invocationOnMock.getArguments()[0]; + final Triggerable target = (Triggerable) invocationOnMock.getArguments()[1]; + timerService.schedule( + new Callable<Object>() { + @Override + public Object call() throws Exception { + synchronized (lock) { + target.trigger(timestamp); + } + return null; + } + }, + timestamp - System.currentTimeMillis(), + TimeUnit.MILLISECONDS); + return null; + } + }).when(mockContext).registerTimer(anyLong(), any(Triggerable.class)); + + AggregatingProcessingTimeWindowOperator<Integer, Integer> op = + new AggregatingProcessingTimeWindowOperator<>( + sumFunction, identitySelector, windowSize, windowSize); + + op.setup(out, mockContext); + op.open(new Configuration()); + + final int numWindows = 10; + + long previousNextTime = 0; + int window = 1; + + while (window <= numWindows) { + long nextTime = op.getNextEvaluationTime(); + int val = ((int) nextTime) ^ ((int) (nextTime >>> 32)); + + synchronized (lock) { + op.processElement(new StreamRecord<Integer>(val)); + } + + if (nextTime != previousNextTime) { + window++; + previousNextTime = nextTime; + } + + Thread.sleep(1); + } + + op.close(); + op.dispose(); + + List<Integer> result = out.getElements(); + + // we have ideally one element per window. we may have more, when we emitted a value into the + // successive window (corner case), so we can have twice the number of elements, in the worst case. + assertTrue(result.size() >= numWindows && result.size() <= 2 * numWindows); + + // deduplicate for more accurate checks + HashSet<Integer> set = new HashSet<>(result); + assertTrue(set.size() == 10 || set.size() == 11); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } finally { + timerService.shutdown(); + } + } + + @Test + public void testSlidingWindow() { + try { + final CollectingOutput<Integer> out = new CollectingOutput<>(50); + + final StreamingRuntimeContext mockContext = mock(StreamingRuntimeContext.class); + when(mockContext.getTaskName()).thenReturn("Test task name"); + + // tumbling window that triggers every 20 milliseconds + AggregatingProcessingTimeWindowOperator<Integer, Integer> op = + new AggregatingProcessingTimeWindowOperator<>(sumFunction, identitySelector, 150, 50); + + op.setup(out, mockContext); + op.open(new Configuration()); + + final int numElements = 1000; + + for (int i = 0; i < numElements; i++) { + op.processElement(new StreamRecord<Integer>(i)); + Thread.sleep(1); + } + + op.close(); + op.dispose(); + + // get and verify the result + List<Integer> result = out.getElements(); + + // every element can occur between one and three times + if (result.size() < numElements || result.size() > 3 * numElements) { + System.out.println(result); + fail("Wrong number of results: " + result.size()); + } + + Collections.sort(result); + int lastNum = -1; + int lastCount = -1; + + for (int num : result) { + if (num == lastNum) { + lastCount++; + assertTrue(lastCount <= 3); + } + else { + lastNum = num; + lastCount = 1; + } + } + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testSlidingWindowSingleElements() { + final ScheduledExecutorService timerService = Executors.newSingleThreadScheduledExecutor(); + + try { + final CollectingOutput<Integer> out = new CollectingOutput<>(50); + + final StreamingRuntimeContext mockContext = mock(StreamingRuntimeContext.class); + when(mockContext.getTaskName()).thenReturn("Test task name"); + + final Object lock = new Object(); + doAnswer(new Answer() { + @Override + public Object answer(InvocationOnMock invocationOnMock) throws Throwable { + final Long timestamp = (Long) invocationOnMock.getArguments()[0]; + final Triggerable target = (Triggerable) invocationOnMock.getArguments()[1]; + timerService.schedule( + new Callable<Object>() { + @Override + public Object call() throws Exception { + synchronized (lock) { + target.trigger(timestamp); + } + return null; + } + }, + timestamp - System.currentTimeMillis(), + TimeUnit.MILLISECONDS); + return null; + } + }).when(mockContext).registerTimer(anyLong(), any(Triggerable.class)); + + // tumbling window that triggers every 20 milliseconds + AggregatingProcessingTimeWindowOperator<Integer, Integer> op = + new AggregatingProcessingTimeWindowOperator<>(sumFunction, identitySelector, 150, 50); + + op.setup(out, mockContext); + op.open(new Configuration()); + + synchronized (lock) { + op.processElement(new StreamRecord<Integer>(1)); + op.processElement(new StreamRecord<Integer>(2)); + } + + // each element should end up in the output three times + // wait until the elements have arrived 6 times in the output + out.waitForNElements(6, 120000); + + List<Integer> result = out.getElements(); + assertEquals(6, result.size()); + + Collections.sort(result); + assertEquals(Arrays.asList(1, 1, 1, 2, 2, 2), result); + + op.close(); + op.dispose(); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } finally { + timerService.shutdown(); + } + } + + @Test + public void testEmitTrailingDataOnClose() { + try { + final CollectingOutput<Integer> out = new CollectingOutput<>(); + + final StreamingRuntimeContext mockContext = mock(StreamingRuntimeContext.class); + when(mockContext.getTaskName()).thenReturn("Test task name"); + + // the operator has a window time that is so long that it will not fire in this test + final long oneYear = 365L * 24 * 60 * 60 * 1000; + AggregatingProcessingTimeWindowOperator<Integer, Integer> op = + new AggregatingProcessingTimeWindowOperator<>(sumFunction, identitySelector, oneYear, oneYear); + + op.setup(out, mockContext); + op.open(new Configuration()); + + List<Integer> data = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); + for (Integer i : data) { + op.processElement(new StreamRecord<Integer>(i)); + } + + op.close(); + op.dispose(); + + // get and verify the result + List<Integer> result = out.getElements(); + Collections.sort(result); + assertEquals(data, result); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testPropagateExceptionsFromProcessElement() { + try { + final CollectingOutput<Integer> out = new CollectingOutput<>(); + + final StreamingRuntimeContext mockContext = mock(StreamingRuntimeContext.class); + when(mockContext.getTaskName()).thenReturn("Test task name"); + + ReduceFunction<Integer> failingFunction = new FailingFunction(100); + + // the operator has a window time that is so long that it will not fire in this test + final long hundredYears = 100L * 365 * 24 * 60 * 60 * 1000; + AggregatingProcessingTimeWindowOperator<Integer, Integer> op = + new AggregatingProcessingTimeWindowOperator<>( + failingFunction, identitySelector, hundredYears, hundredYears); + + op.setup(out, mockContext); + op.open(new Configuration()); + + for (int i = 0; i < 100; i++) { + op.processElement(new StreamRecord<Integer>(1)); + } + + try { + op.processElement(new StreamRecord<Integer>(1)); + fail("This fail with an exception"); + } + catch (Exception e) { + assertTrue(e.getMessage().contains("Artificial Test Exception")); + } + + op.dispose(); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + // ------------------------------------------------------------------------ + + private void assertInvalidParameter(long windowSize, long windowSlide) { + try { + new AggregatingProcessingTimeWindowOperator<String, String>( + mockFunction, mockKeySelector, windowSize, windowSlide); + fail("This should fail with an IllegalArgumentException"); + } + catch (IllegalArgumentException e) { + // expected + } + catch (Exception e) { + fail("Wrong exception. Expected IllegalArgumentException but found " + e.getClass().getSimpleName()); + } + } + + // ------------------------------------------------------------------------ + + private static class FailingFunction implements ReduceFunction<Integer> { + + private final int failAfterElements; + + private int numElements; + + FailingFunction(int failAfterElements) { + this.failAfterElements = failAfterElements; + } + + @Override + public Integer reduce(Integer value1, Integer value2) throws Exception { + numElements++; + + if (numElements >= failAfterElements) { + throw new Exception("Artificial Test Exception"); + } + + return value1 + value2; + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/05d2138f/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/CollectingOutput.java ---------------------------------------------------------------------- diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/CollectingOutput.java b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/CollectingOutput.java new file mode 100644 index 0000000..3c1c24b --- /dev/null +++ b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/CollectingOutput.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.runtime.operators.windowing; + +import org.apache.flink.streaming.api.operators.Output; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; + +import java.util.ArrayList; +import java.util.List; + +public class CollectingOutput<T> implements Output<StreamRecord<T>> { + + private final List<T> elements = new ArrayList<>(); + + private final int timeStampModulus; + + + public CollectingOutput() { + this.timeStampModulus = 0; + } + + public CollectingOutput(int timeStampModulus) { + this.timeStampModulus = timeStampModulus; + } + + // ------------------------------------------------------------------------ + + public List<T> getElements() { + return elements; + } + + public void waitForNElements(int n, long timeout) throws InterruptedException { + long deadline = System.currentTimeMillis() + timeout; + synchronized (elements) { + long now; + while (elements.size() < n && (now = System.currentTimeMillis()) < deadline) { + elements.wait(deadline - now); + } + } + } + + // ------------------------------------------------------------------------ + + @Override + public void emitWatermark(Watermark mark) { + throw new UnsupportedOperationException("the output should not emit watermarks"); + } + + @Override + public void collect(StreamRecord<T> record) { + elements.add(record.getValue()); + + if (timeStampModulus != 0 && record.getTimestamp() % timeStampModulus != 0) { + throw new IllegalArgumentException("Invalid timestamp"); + } + synchronized (elements) { + elements.notifyAll(); + } + } + + @Override + public void close() {} +} http://git-wip-us.apache.org/repos/asf/flink/blob/05d2138f/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/KeyMapPutIfAbsentTest.java ---------------------------------------------------------------------- diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/KeyMapPutIfAbsentTest.java b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/KeyMapPutIfAbsentTest.java new file mode 100644 index 0000000..c0b20a3 --- /dev/null +++ b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/KeyMapPutIfAbsentTest.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.runtime.operators.windowing; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class KeyMapPutIfAbsentTest { + + @Test + public void testPutIfAbsentUniqueKeysAndGrowth() { + try { + KeyMap<Integer, Integer> map = new KeyMap<>(); + IntegerFactory factory = new IntegerFactory(); + + final int numElements = 1000000; + + for (int i = 0; i < numElements; i++) { + factory.set(2 * i + 1); + map.putIfAbsent(i, factory); + + assertEquals(i+1, map.size()); + assertTrue(map.getCurrentTableCapacity() > map.size()); + assertTrue(map.getCurrentTableCapacity() > map.getRehashThreshold()); + assertTrue(map.size() <= map.getRehashThreshold()); + } + + assertEquals(numElements, map.size()); + assertEquals(numElements, map.traverseAndCountElements()); + assertEquals(1 << 21, map.getCurrentTableCapacity()); + + for (int i = 0; i < numElements; i++) { + assertEquals(2 * i + 1, map.get(i).intValue()); + } + + for (int i = numElements - 1; i >= 0; i--) { + assertEquals(2 * i + 1, map.get(i).intValue()); + } + + assertEquals(numElements, map.size()); + assertEquals(numElements, map.traverseAndCountElements()); + assertEquals(1 << 21, map.getCurrentTableCapacity()); + assertTrue(map.getLongestChainLength() <= 7); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testPutIfAbsentDuplicateKeysAndGrowth() { + try { + KeyMap<Integer, Integer> map = new KeyMap<>(); + IntegerFactory factory = new IntegerFactory(); + + final int numElements = 1000000; + + for (int i = 0; i < numElements; i++) { + int val = 2 * i + 1; + factory.set(val); + Integer put = map.putIfAbsent(i, factory); + assertEquals(val, put.intValue()); + } + + for (int i = 0; i < numElements; i += 3) { + factory.set(2 * i); + Integer put = map.putIfAbsent(i, factory); + assertEquals(2 * i + 1, put.intValue()); + } + + for (int i = 0; i < numElements; i++) { + assertEquals(2 * i + 1, map.get(i).intValue()); + } + + assertEquals(numElements, map.size()); + assertEquals(numElements, map.traverseAndCountElements()); + assertEquals(1 << 21, map.getCurrentTableCapacity()); + assertTrue(map.getLongestChainLength() <= 7); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + // ------------------------------------------------------------------------ + + private static class IntegerFactory implements KeyMap.LazyFactory<Integer> { + + private Integer toCreate; + + public void set(Integer toCreate) { + this.toCreate = toCreate; + } + + @Override + public Integer create() { + return toCreate; + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/05d2138f/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/KeyMapPutTest.java ---------------------------------------------------------------------- diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/KeyMapPutTest.java b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/KeyMapPutTest.java new file mode 100644 index 0000000..09c44fe --- /dev/null +++ b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/KeyMapPutTest.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.runtime.operators.windowing; + +import org.junit.Test; + +import java.util.BitSet; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class KeyMapPutTest { + + @Test + public void testPutUniqueKeysAndGrowth() { + try { + KeyMap<Integer, Integer> map = new KeyMap<>(); + + final int numElements = 1000000; + + for (int i = 0; i < numElements; i++) { + map.put(i, 2 * i + 1); + + assertEquals(i+1, map.size()); + assertTrue(map.getCurrentTableCapacity() > map.size()); + assertTrue(map.getCurrentTableCapacity() > map.getRehashThreshold()); + assertTrue(map.size() <= map.getRehashThreshold()); + } + + assertEquals(numElements, map.size()); + assertEquals(numElements, map.traverseAndCountElements()); + assertEquals(1 << 21, map.getCurrentTableCapacity()); + + for (int i = 0; i < numElements; i++) { + assertEquals(2 * i + 1, map.get(i).intValue()); + } + + for (int i = numElements - 1; i >= 0; i--) { + assertEquals(2 * i + 1, map.get(i).intValue()); + } + + BitSet bitset = new BitSet(); + int numContained = 0; + for (KeyMap.Entry<Integer, Integer> entry : map) { + numContained++; + + assertEquals(entry.getKey() * 2 + 1, entry.getValue().intValue()); + assertFalse(bitset.get(entry.getKey())); + bitset.set(entry.getKey()); + } + + assertEquals(numElements, numContained); + assertEquals(numElements, bitset.cardinality()); + + + assertEquals(numElements, map.size()); + assertEquals(numElements, map.traverseAndCountElements()); + assertEquals(1 << 21, map.getCurrentTableCapacity()); + assertTrue(map.getLongestChainLength() <= 7); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testPutDuplicateKeysAndGrowth() { + try { + final KeyMap<Integer, Integer> map = new KeyMap<>(); + final int numElements = 1000000; + + for (int i = 0; i < numElements; i++) { + Integer put = map.put(i, 2*i+1); + assertNull(put); + } + + for (int i = 0; i < numElements; i += 3) { + Integer put = map.put(i, 2*i); + assertNotNull(put); + assertEquals(2*i+1, put.intValue()); + } + + for (int i = 0; i < numElements; i++) { + int expected = (i % 3 == 0) ? (2*i) : (2*i+1); + assertEquals(expected, map.get(i).intValue()); + } + + assertEquals(numElements, map.size()); + assertEquals(numElements, map.traverseAndCountElements()); + assertEquals(1 << 21, map.getCurrentTableCapacity()); + assertTrue(map.getLongestChainLength() <= 7); + + + BitSet bitset = new BitSet(); + int numContained = 0; + for (KeyMap.Entry<Integer, Integer> entry : map) { + numContained++; + + int key = entry.getKey(); + int expected = key % 3 == 0 ? (2*key) : (2*key+1); + + assertEquals(expected, entry.getValue().intValue()); + assertFalse(bitset.get(key)); + bitset.set(key); + } + + assertEquals(numElements, numContained); + assertEquals(numElements, bitset.cardinality()); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/05d2138f/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/KeyMapTest.java ---------------------------------------------------------------------- diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/KeyMapTest.java b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/KeyMapTest.java new file mode 100644 index 0000000..49310df --- /dev/null +++ b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/KeyMapTest.java @@ -0,0 +1,344 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.runtime.operators.windowing; + +import org.junit.Test; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Random; + +import static org.junit.Assert.*; + +public class KeyMapTest { + + @Test + public void testInitialSizeComputation() { + try { + KeyMap<String, String> map; + + map = new KeyMap<>(); + assertEquals(64, map.getCurrentTableCapacity()); + assertEquals(6, map.getLog2TableCapacity()); + assertEquals(24, map.getShift()); + assertEquals(48, map.getRehashThreshold()); + + map = new KeyMap<>(0); + assertEquals(64, map.getCurrentTableCapacity()); + assertEquals(6, map.getLog2TableCapacity()); + assertEquals(24, map.getShift()); + assertEquals(48, map.getRehashThreshold()); + + map = new KeyMap<>(1); + assertEquals(64, map.getCurrentTableCapacity()); + assertEquals(6, map.getLog2TableCapacity()); + assertEquals(24, map.getShift()); + assertEquals(48, map.getRehashThreshold()); + + map = new KeyMap<>(9); + assertEquals(64, map.getCurrentTableCapacity()); + assertEquals(6, map.getLog2TableCapacity()); + assertEquals(24, map.getShift()); + assertEquals(48, map.getRehashThreshold()); + + map = new KeyMap<>(63); + assertEquals(64, map.getCurrentTableCapacity()); + assertEquals(6, map.getLog2TableCapacity()); + assertEquals(24, map.getShift()); + assertEquals(48, map.getRehashThreshold()); + + map = new KeyMap<>(64); + assertEquals(128, map.getCurrentTableCapacity()); + assertEquals(7, map.getLog2TableCapacity()); + assertEquals(23, map.getShift()); + assertEquals(96, map.getRehashThreshold()); + + map = new KeyMap<>(500); + assertEquals(512, map.getCurrentTableCapacity()); + assertEquals(9, map.getLog2TableCapacity()); + assertEquals(21, map.getShift()); + assertEquals(384, map.getRehashThreshold()); + + map = new KeyMap<>(127); + assertEquals(128, map.getCurrentTableCapacity()); + assertEquals(7, map.getLog2TableCapacity()); + assertEquals(23, map.getShift()); + assertEquals(96, map.getRehashThreshold()); + + // no negative number of elements + try { + new KeyMap<>(-1); + fail("should fail with an exception"); + } + catch (IllegalArgumentException e) { + // expected + } + + // check integer overflow + try { + map = new KeyMap<>(0x65715522); + + final int maxCap = Integer.highestOneBit(Integer.MAX_VALUE); + assertEquals(Integer.highestOneBit(Integer.MAX_VALUE), map.getCurrentTableCapacity()); + assertEquals(30, map.getLog2TableCapacity()); + assertEquals(0, map.getShift()); + assertEquals(maxCap / 4 * 3, map.getRehashThreshold()); + } + catch (OutOfMemoryError e) { + // this may indeed happen in small test setups. we tolerate this in this test + } + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testPutAndGetRandom() { + try { + final KeyMap<Integer, Integer> map = new KeyMap<>(); + final Random rnd = new Random(); + + final long seed = rnd.nextLong(); + final int numElements = 10000; + + final HashMap<Integer, Integer> groundTruth = new HashMap<>(); + + rnd.setSeed(seed); + for (int i = 0; i < numElements; i++) { + Integer key = rnd.nextInt(); + Integer value = rnd.nextInt(); + + if (rnd.nextBoolean()) { + groundTruth.put(key, value); + map.put(key, value); + } + } + + rnd.setSeed(seed); + for (int i = 0; i < numElements; i++) { + Integer key = rnd.nextInt(); + + // skip these, evaluating it is tricky due to duplicates + rnd.nextInt(); + rnd.nextBoolean(); + + Integer expected = groundTruth.get(key); + if (expected == null) { + assertNull(map.get(key)); + } + else { + Integer contained = map.get(key); + assertNotNull(contained); + assertEquals(expected, contained); + } + } + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testConjunctTraversal() { + try { + final Random rootRnd = new Random(654685486325439L); + + final int numMaps = 7; + final int numKeys = 1000000; + + // ------ create a set of maps ------ + @SuppressWarnings("unchecked") + final KeyMap<Integer, Integer>[] maps = (KeyMap<Integer, Integer>[]) new KeyMap<?, ?>[numMaps]; + for (int i = 0; i < numMaps; i++) { + maps[i] = new KeyMap<>(); + } + + // ------ prepare probabilities for maps ------ + final double[] probabilities = new double[numMaps]; + final double[] probabilitiesTemp = new double[numMaps]; + { + probabilities[0] = 0.5; + double remainingProb = 1.0 - probabilities[0]; + for (int i = 1; i < numMaps - 1; i++) { + remainingProb /= 2; + probabilities[i] = remainingProb; + } + + // compensate for rounding errors + probabilities[numMaps - 1] = remainingProb; + } + + // ------ generate random elements ------ + final long probSeed = rootRnd.nextLong(); + final long keySeed = rootRnd.nextLong(); + + final Random probRnd = new Random(probSeed); + final Random keyRnd = new Random(keySeed); + + final int maxStride = Integer.MAX_VALUE / numKeys; + + int totalNumElements = 0; + int nextKeyValue = 1; + + for (int i = 0; i < numKeys; i++) { + int numCopies = (nextKeyValue % 3) + 1; + System.arraycopy(probabilities, 0, probabilitiesTemp, 0, numMaps); + + double totalProb = 1.0; + for (int copy = 0; copy < numCopies; copy++) { + int pos = drawPosProportionally(probabilitiesTemp, totalProb, probRnd); + totalProb -= probabilitiesTemp[pos]; + probabilitiesTemp[pos] = 0.0; + + Integer boxed = nextKeyValue; + Integer previous = maps[pos].put(boxed, boxed); + assertNull("Test problem - test does not assign unique maps", previous); + } + + totalNumElements += numCopies; + nextKeyValue += keyRnd.nextInt(maxStride) + 1; + } + + + // check that all maps contain the total number of elements + int numContained = 0; + for (KeyMap<?, ?> map : maps) { + numContained += map.size(); + } + assertEquals(totalNumElements, numContained); + + // ------ check that all elements can be found in the maps ------ + keyRnd.setSeed(keySeed); + + numContained = 0; + nextKeyValue = 1; + for (int i = 0; i < numKeys; i++) { + int numCopiesExpected = (nextKeyValue % 3) + 1; + int numCopiesContained = 0; + + for (KeyMap<Integer, Integer> map : maps) { + Integer val = map.get(nextKeyValue); + if (val != null) { + assertEquals(nextKeyValue, val.intValue()); + numCopiesContained++; + } + } + + assertEquals(numCopiesExpected, numCopiesContained); + numContained += numCopiesContained; + + nextKeyValue += keyRnd.nextInt(maxStride) + 1; + } + assertEquals(totalNumElements, numContained); + + // ------ make a traversal over all keys and validate the keys in the traversal ------ + final int[] keysStartedAndFinished = { 0, 0 }; + KeyMap.TraversalEvaluator<Integer, Integer> traversal = new KeyMap.TraversalEvaluator<Integer, Integer>() { + + private int key; + private int valueCount; + + @Override + public void startNewKey(Integer key) { + this.key = key; + this.valueCount = 0; + + keysStartedAndFinished[0]++; + } + + @Override + public void nextValue(Integer value) { + assertEquals(this.key, value.intValue()); + this.valueCount++; + } + + @Override + public void keyDone() { + int expected = (key % 3) + 1; + if (expected != valueCount) { + fail("Wrong count for key " + key + " ; expected=" + expected + " , count=" + valueCount); + } + + keysStartedAndFinished[1]++; + } + }; + + KeyMap.traverseMaps(shuffleArray(maps, rootRnd), traversal, 17); + + assertEquals(numKeys, keysStartedAndFinished[0]); + assertEquals(numKeys, keysStartedAndFinished[1]); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testSizeComparator() { + try { + KeyMap<String, String> map1 = new KeyMap<>(5); + KeyMap<String, String> map2 = new KeyMap<>(80); + + assertTrue(map1.getCurrentTableCapacity() < map2.getCurrentTableCapacity()); + + assertTrue(KeyMap.CapacityDescendingComparator.INSTANCE.compare(map1, map1) == 0); + assertTrue(KeyMap.CapacityDescendingComparator.INSTANCE.compare(map2, map2) == 0); + assertTrue(KeyMap.CapacityDescendingComparator.INSTANCE.compare(map1, map2) > 0); + assertTrue(KeyMap.CapacityDescendingComparator.INSTANCE.compare(map2, map1) < 0); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + // ------------------------------------------------------------------------ + + private static int drawPosProportionally(double[] array, double totalProbability, Random rnd) { + double val = rnd.nextDouble() * totalProbability; + + double accum = 0; + for (int i = 0; i < array.length; i++) { + accum += array[i]; + if (val <= accum && array[i] > 0.0) { + return i; + } + } + + // in case of rounding errors + return array.length - 1; + } + + private static <E> E[] shuffleArray(E[] array, Random rnd) { + E[] target = Arrays.copyOf(array, array.length); + + for (int i = target.length - 1; i > 0; i--) { + int swapPos = rnd.nextInt(i + 1); + E temp = target[i]; + target[i] = target[swapPos]; + target[swapPos] = temp; + } + + return target; + } +}
