This is an automated email from the ASF dual-hosted git repository.

jark pushed a commit to branch release-1.10
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/release-1.10 by this push:
     new 81b1895  [FLINK-15465][FLINK-11964][table-runtime-blink] Fix required 
memory calculation not accurate and hash collision bugs in hash table (#10756)
81b1895 is described below

commit 81b18957da8e35b414b6c6017d13720157340d59
Author: Jingsong Lee <[email protected]>
AuthorDate: Tue Jan 7 18:41:25 2020 +0800

    [FLINK-15465][FLINK-11964][table-runtime-blink] Fix required memory 
calculation not accurate and hash collision bugs in hash table (#10756)
---
 .../runtime/hashtable/BaseHybridHashTable.java     |  11 ++-
 .../runtime/hashtable/BinaryHashBucketArea.java    | 107 +++++++++++++--------
 .../table/runtime/hashtable/BinaryHashTable.java   |   6 +-
 .../table/runtime/hashtable/LongHashPartition.java |  32 +++---
 .../runtime/hashtable/LongHybridHashTable.java     |   2 +-
 .../runtime/hashtable/BinaryHashTableTest.java     |  16 +++
 6 files changed, 106 insertions(+), 68 deletions(-)

diff --git 
a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/hashtable/BaseHybridHashTable.java
 
b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/hashtable/BaseHybridHashTable.java
index 24dca46..456b18d 100644
--- 
a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/hashtable/BaseHybridHashTable.java
+++ 
b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/hashtable/BaseHybridHashTable.java
@@ -73,10 +73,6 @@ public abstract class BaseHybridHashTable implements 
MemorySegmentPool {
        private static final int MIN_NUM_MEMORY_SEGMENTS = 33;
        protected final int initPartitionFanOut;
 
-       /**
-        * The owner to associate with the memory segment.
-        */
-       private Object owner;
        private final int avgRecordLen;
        protected final long buildRowCount;
 
@@ -177,7 +173,6 @@ public abstract class BaseHybridHashTable implements 
MemorySegmentPool {
                this.compressionBlockSize = (int) MemorySize.parse(
                        
conf.getString(ExecutionConfigOptions.TABLE_EXEC_SPILL_COMPRESSION_BLOCK_SIZE)).getBytes();
 
-               this.owner = owner;
                this.avgRecordLen = avgRecordLen;
                this.buildRowCount = buildRowCount;
                this.tryDistinctBuildRow = tryDistinctBuildRow;
@@ -516,4 +511,10 @@ public abstract class BaseHybridHashTable implements 
MemorySegmentPool {
                return code >= 0 ? code : -(code + 1);
        }
 
+       /**
+        * Hashes again at the partition level, to avoid a two-layer hash conflict.
+        */
+       static int partitionLevelHash(int hash) {
+               return hash ^ (hash >>> 16);
+       }
 }
diff --git 
a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/hashtable/BinaryHashBucketArea.java
 
b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/hashtable/BinaryHashBucketArea.java
index e44cc81..27de439 100644
--- 
a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/hashtable/BinaryHashBucketArea.java
+++ 
b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/hashtable/BinaryHashBucketArea.java
@@ -19,6 +19,7 @@
 package org.apache.flink.table.runtime.hashtable;
 
 import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.core.memory.MemorySegmentFactory;
 import org.apache.flink.runtime.io.disk.RandomAccessInputView;
 import org.apache.flink.table.dataformat.BinaryRow;
 import org.apache.flink.util.MathUtils;
@@ -28,9 +29,10 @@ import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
 import java.nio.ByteOrder;
-import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 
+import static 
org.apache.flink.table.runtime.hashtable.BaseHybridHashTable.partitionLevelHash;
 import static org.apache.flink.util.Preconditions.checkArgument;
 
 /**
@@ -139,6 +141,8 @@ public class BinaryHashBucketArea {
        final BinaryHashTable table;
        private final double estimatedRowCount;
        private final double loadFactor;
+       private final boolean spillingAllowed;
+
        BinaryHashPartition partition;
        private int size;
 
@@ -154,13 +158,23 @@ public class BinaryHashBucketArea {
        private boolean inReHash = false;
 
        BinaryHashBucketArea(BinaryHashTable table, double estimatedRowCount, 
int maxSegs) {
-               this(table, estimatedRowCount, maxSegs, DEFAULT_LOAD_FACTOR);
+               this(table, estimatedRowCount, maxSegs, DEFAULT_LOAD_FACTOR, 
true);
+       }
+
+       BinaryHashBucketArea(BinaryHashTable table, double estimatedRowCount, 
int maxSegs, boolean spillingAllowed) {
+               this(table, estimatedRowCount, maxSegs, DEFAULT_LOAD_FACTOR, 
spillingAllowed);
        }
 
-       private BinaryHashBucketArea(BinaryHashTable table, double 
estimatedRowCount, int maxSegs, double loadFactor) {
+       private BinaryHashBucketArea(
+                       BinaryHashTable table,
+                       double estimatedRowCount,
+                       int maxSegs,
+                       double loadFactor,
+                       boolean spillingAllowed) {
                this.table = table;
                this.estimatedRowCount = estimatedRowCount;
                this.loadFactor = loadFactor;
+               this.spillingAllowed = spillingAllowed;
                this.size = 0;
 
                int minNumBuckets = (int) Math.ceil((estimatedRowCount / 
loadFactor / NUM_ENTRIES_PER_BUCKET));
@@ -198,7 +212,7 @@ public class BinaryHashBucketArea {
                this.partition = partition;
        }
 
-       private void resize(boolean spillingAllowed) throws IOException {
+       private void resize() throws IOException {
                MemorySegment[] oldBuckets = this.buckets;
                int oldNumBuckets = numBuckets;
                MemorySegment[] oldOverflowSegments = overflowSegments;
@@ -265,7 +279,7 @@ public class BinaryHashBucketArea {
                                while (numInBucket < countInBucket) {
                                        int hashCode = 
bucketSeg.getInt(hashCodeOffset);
                                        int pointer = 
bucketSeg.getInt(pointerOffset);
-                                       if (!insertToBucket(hashCode, pointer, 
true, false)) {
+                                       if (!insertToBucket(hashCode, pointer, 
false)) {
                                                
buildBloomFilterAndFree(oldBuckets, oldNumBuckets, oldOverflowSegments);
                                                return;
                                        }
@@ -293,17 +307,6 @@ public class BinaryHashBucketArea {
                LOG.info("The rehash take {} ms for {} segments", 
(System.currentTimeMillis() - reHashStartTime), numBuckets);
        }
 
-       private void freeMemory(MemorySegment[] buckets, MemorySegment[] 
overflowSegments) {
-               for (MemorySegment segment : buckets) {
-                       table.free(segment);
-               }
-               for (MemorySegment segment : overflowSegments) {
-                       if (segment != null) {
-                               table.free(segment);
-                       }
-               }
-       }
-
        private void initMemorySegment(MemorySegment seg) {
                // go over all buckets in the segment
                for (int k = 0; k < table.bucketsPerSegment; k++) {
@@ -315,8 +318,11 @@ public class BinaryHashBucketArea {
        }
 
        private boolean insertToBucket(
-                       MemorySegment bucket, int bucketInSegmentPos,
-                       int hashCode, int pointer, boolean spillingAllowed, 
boolean sizeAddAndCheckResize) throws IOException {
+                       MemorySegment bucket,
+                       int bucketInSegmentPos,
+                       int hashCode,
+                       int pointer,
+                       boolean sizeAddAndCheckResize) throws IOException {
                final int count = bucket.getShort(bucketInSegmentPos + 
HEADER_COUNT_OFFSET);
                if (count < NUM_ENTRIES_PER_BUCKET) {
                        // we are good in our current bucket, put the values
@@ -364,19 +370,23 @@ public class BinaryHashBucketArea {
                                // no space left in last bucket, or no bucket 
yet, so create an overflow segment
                                overflowSeg = table.getNextBuffer();
                                if (overflowSeg == null) {
-                                       // no memory available to create 
overflow bucket. we need to spill a partition
                                        if (!spillingAllowed) {
-                                               throw new 
IOException("Hashtable memory ran out in a non-spillable situation. " +
-                                                               "This is 
probably related to wrong size calculations.");
-                                       }
-                                       final int spilledPart = 
table.spillPartition();
-                                       if (spilledPart == 
partition.partitionNumber) {
-                                               // this bucket is no longer 
in-memory
-                                               return false;
-                                       }
-                                       overflowSeg = table.getNextBuffer();
-                                       if (overflowSeg == null) {
-                                               throw new RuntimeException("Bug 
in HybridHashJoin: No memory became available after spilling a partition.");
+                                               // In this corner case, we 
steal memory from the heap.
+                                               // Because of the linked 
hash-conflict resolution, the required memory
+                                               // calculation is not 
accurate; in this case, we allocate the missing
+                                               // memory from the heap.
+                                               // NOTE: be careful, the 
stolen memory must not be returned to the table.
+                                               overflowSeg = 
MemorySegmentFactory.allocateUnpooledSegment(table.segmentSize, this);
+                                       } else {
+                                               final int spilledPart = 
table.spillPartition();
+                                               if (spilledPart == 
partition.partitionNumber) {
+                                                       // this bucket is no 
longer in-memory
+                                                       return false;
+                                               }
+                                               overflowSeg = 
table.getNextBuffer();
+                                               if (overflowSeg == null) {
+                                                       throw new 
RuntimeException("Bug in HybridHashJoin: No memory became available after 
spilling a partition.");
+                                               }
                                        }
                                }
                                overflowBucketOffset = 0;
@@ -420,26 +430,27 @@ public class BinaryHashBucketArea {
                }
 
                if (sizeAddAndCheckResize && ++size > threshold) {
-                       resize(spillingAllowed);
+                       resize();
                }
                return true;
        }
 
-       private int findBucket(int hashCode) {
-               return hashCode & this.numBucketsMask;
+       private int findBucket(int hash) {
+               // Avoid two layer hash conflict
+               return partitionLevelHash(hash) & this.numBucketsMask;
        }
 
        /**
         * Insert into bucket by hashCode and pointer.
         * @return return false when spill own partition.
         */
-       boolean insertToBucket(int hashCode, int pointer, boolean 
spillingAllowed, boolean sizeAddAndCheckResize) throws IOException {
+       boolean insertToBucket(int hashCode, int pointer, boolean 
sizeAddAndCheckResize) throws IOException {
                final int posHashCode = findBucket(hashCode);
                // get the bucket for the given hash code
                final int bucketArrayPos = posHashCode >> 
table.bucketsPerSegmentBits;
                final int bucketInSegmentPos = (posHashCode & 
table.bucketsPerSegmentMask) << BUCKET_SIZE_BITS;
                final MemorySegment bucket = this.buckets[bucketArrayPos];
-               return insertToBucket(bucket, bucketInSegmentPos, hashCode, 
pointer, spillingAllowed, sizeAddAndCheckResize);
+               return insertToBucket(bucket, bucketInSegmentPos, hashCode, 
pointer, sizeAddAndCheckResize);
        }
 
        /**
@@ -458,7 +469,7 @@ public class BinaryHashBucketArea {
                        int pointer = partition.insertIntoBuildBuffer(record);
                        if (pointer != -1) {
                                // record was inserted into an in-memory 
partition. a pointer must be inserted into the buckets
-                               insertToBucket(bucket, bucketInSegmentPos, 
hashCode, pointer, true, true);
+                               insertToBucket(bucket, bucketInSegmentPos, 
hashCode, pointer, true);
                                return true;
                        } else {
                                return false;
@@ -536,13 +547,27 @@ public class BinaryHashBucketArea {
        }
 
        void returnMemory(List<MemorySegment> target) {
-               target.addAll(Arrays.asList(overflowSegments).subList(0, 
numOverflowSegments));
-               target.addAll(Arrays.asList(buckets));
+               returnMemory(target, buckets, overflowSegments);
        }
 
-       private void freeMemory() {
-               
table.availableMemory.addAll(Arrays.asList(overflowSegments).subList(0, 
numOverflowSegments));
-               table.availableMemory.addAll(Arrays.asList(buckets));
+       void freeMemory() {
+               returnMemory(table.availableMemory, buckets, overflowSegments);
+       }
+
+       private void freeMemory(MemorySegment[] buckets, MemorySegment[] 
overflowSegments) {
+               returnMemory(table.availableMemory, buckets, overflowSegments);
+       }
+
+       private void returnMemory(
+                       List<MemorySegment> target, MemorySegment[] buckets, 
MemorySegment[] overflowSegments) {
+               Collections.addAll(target, buckets);
+               for (MemorySegment segment : overflowSegments) {
+                       if (segment != null &&
+                                       // except stealing from heap.
+                                       segment.getOwner() != this) {
+                               target.add(segment);
+                       }
+               }
        }
 
        /**
diff --git 
a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/hashtable/BinaryHashTable.java
 
b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/hashtable/BinaryHashTable.java
index 153323e..244c596 100644
--- 
a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/hashtable/BinaryHashTable.java
+++ 
b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/hashtable/BinaryHashTable.java
@@ -477,7 +477,7 @@ public class BinaryHashTable extends BaseHybridHashTable {
 
                        // first read the partition in
                        final List<MemorySegment> partitionBuffers = 
readAllBuffers(p.getBuildSideChannel().getChannelID(), 
p.getBuildSideBlockCount());
-                       BinaryHashBucketArea area = new 
BinaryHashBucketArea(this, (int) p.getBuildSideRecordCount(), 
maxBucketAreaBuffers);
+                       BinaryHashBucketArea area = new 
BinaryHashBucketArea(this, (int) p.getBuildSideRecordCount(), 
maxBucketAreaBuffers, false);
                        final BinaryHashPartition newPart = new 
BinaryHashPartition(area, this.binaryBuildSideSerializer, 
this.binaryProbeSideSerializer,
                                        0, nextRecursionLevel, 
partitionBuffers, p.getBuildSideRecordCount(), this.segmentSize, 
p.getLastSegmentLimit());
                        area.setPartition(newPart);
@@ -490,7 +490,7 @@ public class BinaryHashTable extends BaseHybridHashTable {
                        while (pIter.advanceNext()) {
                                final int hashCode = 
hash(buildSideProjection.apply(pIter.getRow()).hashCode(), nextRecursionLevel);
                                final int pointer = (int) pIter.getPointer();
-                               area.insertToBucket(hashCode, pointer, false, 
true);
+                               area.insertToBucket(hashCode, pointer, true);
                        }
                } else {
                        // go over the complete input and insert every element 
into the hash table
@@ -624,7 +624,7 @@ public class BinaryHashTable extends BaseHybridHashTable {
                return largestPartNum;
        }
 
-       boolean applyCondition(BinaryRow candidate) throws Exception {
+       boolean applyCondition(BinaryRow candidate) {
                BinaryRow buildKey = buildSideProjection.apply(candidate);
                // They come from Projection, so we can make sure it is in 
byte[].
                boolean equal = buildKey.getSizeInBytes() == 
probeKey.getSizeInBytes()
diff --git 
a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/hashtable/LongHashPartition.java
 
b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/hashtable/LongHashPartition.java
index 17e8902..847d8f6e 100644
--- 
a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/hashtable/LongHashPartition.java
+++ 
b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/hashtable/LongHashPartition.java
@@ -47,6 +47,7 @@ import java.util.Collections;
 import java.util.List;
 import java.util.concurrent.LinkedBlockingQueue;
 
+import static 
org.apache.flink.table.runtime.hashtable.BaseHybridHashTable.partitionLevelHash;
 import static 
org.apache.flink.table.runtime.hashtable.LongHybridHashTable.hashLong;
 import static org.apache.flink.util.Preconditions.checkArgument;
 
@@ -198,7 +199,7 @@ public class LongHashPartition extends 
AbstractPagedInputView implements Seekabl
 
        private static MemorySegment[] listToArray(List<MemorySegment> list) {
                if (list != null) {
-                       return list.toArray(new MemorySegment[list.size()]);
+                       return list.toArray(new MemorySegment[0]);
                }
                return null;
        }
@@ -225,15 +226,15 @@ public class LongHashPartition extends 
AbstractPagedInputView implements Seekabl
                this.numKeys = 0;
        }
 
-       static long toAddrAndLen(long address, int size) {
+       private static long toAddrAndLen(long address, int size) {
                return (address << SIZE_BITS) | size;
        }
 
-       static long toAddress(long addrAndLen) {
+       private static long toAddress(long addrAndLen) {
                return addrAndLen >>> SIZE_BITS;
        }
 
-       static int toLength(long addrAndLen) {
+       private static int toLength(long addrAndLen) {
                return (int) (addrAndLen & SIZE_MASK);
        }
 
@@ -245,15 +246,11 @@ public class LongHashPartition extends 
AbstractPagedInputView implements Seekabl
                return iterator;
        }
 
-//     public MatchIterator get(long key) {
-//             return get(key, hashLong(key, recursionLevel));
-//     }
-
        /**
         * Returns an iterator for all the values for the given key, or null if 
no value found.
         */
        public MatchIterator get(long key, int hashCode) {
-               int bucket = hashCode & numBucketsMask;
+               int bucket = findBucket(hashCode);
 
                int bucketOffset = bucket << 4;
                MemorySegment segment = buckets[bucketOffset >>> 
segmentSizeBits];
@@ -291,7 +288,7 @@ public class LongHashPartition extends 
AbstractPagedInputView implements Seekabl
                        MemorySegment dataSegment,
                        int currentPositionInSegment) throws IOException {
                assert (numKeys <= numBuckets / 2);
-               int bucketId = hashCode & numBucketsMask;
+               int bucketId = findBucket(hashCode);
 
                // each bucket occupied 16 bytes (long key + long pointer to 
data address)
                int bucketOffset = bucketId * 
SPARSE_BUCKET_ELEMENT_SIZE_IN_BYTES;
@@ -342,6 +339,10 @@ public class LongHashPartition extends 
AbstractPagedInputView implements Seekabl
                }
        }
 
+       private int findBucket(int hash) {
+               return partitionLevelHash(hash) & this.numBucketsMask;
+       }
+
        private void resize() throws IOException {
                MemorySegment[] oldBuckets = this.buckets;
                int oldNumBuckets = numBuckets;
@@ -418,10 +419,6 @@ public class LongHashPartition extends 
AbstractPagedInputView implements Seekabl
                return this.buildSideChannel;
        }
 
-       FileIOChannel.ID getProbeSideChannelID() {
-               return probeSideBuffer.getChannel().getChannelID();
-       }
-
        int getPartitionNumber() {
                return this.partitionNum;
        }
@@ -622,7 +619,7 @@ public class LongHashPartition extends 
AbstractPagedInputView implements Seekabl
                        if (row.getSegments().length == 1) {
                                
buildSideWriteBuffer.write(row.getSegments()[0], row.getOffset(), sizeInBytes);
                        } else {
-                               
buildSideSerializer.serializeWithoutLengthSlow(row, buildSideWriteBuffer);
+                               
BinaryRowSerializer.serializeWithoutLengthSlow(row, buildSideWriteBuffer);
                        }
                } else {
                        serializeToPages(row);
@@ -642,7 +639,7 @@ public class LongHashPartition extends 
AbstractPagedInputView implements Seekabl
                if (row.getSegments().length == 1) {
                        buildSideWriteBuffer.write(row.getSegments()[0], 
row.getOffset(), sizeInBytes);
                } else {
-                       buildSideSerializer.serializeWithoutLengthSlow(row, 
buildSideWriteBuffer);
+                       BinaryRowSerializer.serializeWithoutLengthSlow(row, 
buildSideWriteBuffer);
                }
        }
 
@@ -729,8 +726,7 @@ public class LongHashPartition extends 
AbstractPagedInputView implements Seekabl
 
                        if (this.writer == null) {
                                this.targetList.add(current);
-                               MemorySegment[] buffers =
-                                               this.targetList.toArray(new 
MemorySegment[this.targetList.size()]);
+                               MemorySegment[] buffers = 
this.targetList.toArray(new MemorySegment[0]);
                                this.targetList.clear();
                                return buffers;
                        } else {
diff --git 
a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/hashtable/LongHybridHashTable.java
 
b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/hashtable/LongHybridHashTable.java
index f0d7ebc..3335479 100644
--- 
a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/hashtable/LongHybridHashTable.java
+++ 
b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/hashtable/LongHybridHashTable.java
@@ -247,7 +247,7 @@ public abstract class LongHybridHashTable extends 
BaseHybridHashTable {
 
                        this.denseBuckets = denseBuckets;
                        this.densePartition = new LongHashPartition(this, 
buildSideSerializer,
-                                       dataBuffers.toArray(new 
MemorySegment[dataBuffers.size()]));
+                                       dataBuffers.toArray(new 
MemorySegment[0]));
                        freeCurrent();
                }
        }
diff --git 
a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/hashtable/BinaryHashTableTest.java
 
b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/hashtable/BinaryHashTableTest.java
index dd39c95..f42c3db 100644
--- 
a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/hashtable/BinaryHashTableTest.java
+++ 
b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/hashtable/BinaryHashTableTest.java
@@ -848,6 +848,22 @@ public class BinaryHashTableTest {
                table.free();
        }
 
+       @Test
+       public void testBinaryHashBucketAreaNotEnoughMem() throws IOException {
+               MemoryManager memManager = 
MemoryManagerBuilder.newBuilder().setMemorySize(35 * PAGE_SIZE).build();
+               BinaryHashTable table = newBinaryHashTable(
+                               this.buildSideSerializer, 
this.probeSideSerializer,
+                               new MyProjection(), new MyProjection(), 
memManager,
+                               35 * PAGE_SIZE, ioManager);
+               BinaryHashBucketArea area = new BinaryHashBucketArea(table, 
100, 1, false);
+               for (int i = 0; i < 100000; i++) {
+                       area.insertToBucket(i, i, true);
+               }
+               area.freeMemory();
+               table.close();
+               Assert.assertEquals(35, table.getFreedMemory().size());
+       }
+
        // 
============================================================================================
 
        /**

Reply via email to