This is an automated email from the ASF dual-hosted git repository.
szita pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/hive.git
The following commit(s) were added to refs/heads/master by this push:
new a8bd1eb HIVE-21217: Optimize range calculation for PTF (Adam Szita, reviewed by Peter Vary)
a8bd1eb is described below
commit a8bd1eb09e3feb10d68075c6bc676ec333f498da
Author: Adam Szita <[email protected]>
AuthorDate: Wed Feb 20 10:26:57 2019 +0100
HIVE-21217: Optimize range calculation for PTF (Adam Szita, reviewed by Peter Vary)
---
.../java/org/apache/hadoop/hive/conf/HiveConf.java | 4 +
.../apache/hadoop/hive/ql/exec/BoundaryCache.java | 124 ++++++
.../apache/hadoop/hive/ql/exec/PTFPartition.java | 7 +
.../hive/ql/udf/ptf/BasePartitionEvaluator.java | 1 +
.../hive/ql/udf/ptf/ValueBoundaryScanner.java | 416 +++++++++++++++++----
.../hadoop/hive/ql/udf/ptf/TestBoundaryCache.java | 295 +++++++++++++++
6 files changed, 765 insertions(+), 82 deletions(-)
diff --git a/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java b/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
index 4a86b0a..11f165a 100644
--- a/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
+++ b/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
@@ -1634,6 +1634,10 @@ public class HiveConf extends Configuration {
+ "the evaluation of certain joins, since we will not be emitting rows
which are thrown away by "
+ "a Filter operator straight away. However, currently vectorization
does not support them, thus "
+ "enabling it is only recommended when vectorization is disabled."),
+ HIVE_PTF_RANGECACHE_SIZE("hive.ptf.rangecache.size", 10000,
+ "Size of the cache used on the reducer side, that stores boundaries of ranges within a PTF " +
+ "partition. Used if a query specifies a RANGE type window including an ORDER BY clause. " +
+ "Set this to 0 to disable this cache."),
// CBO related
HIVE_CBO_ENABLED("hive.cbo.enable", true, "Flag to control enabling Cost Based Optimizations using Calcite framework."),
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/BoundaryCache.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/BoundaryCache.java
new file mode 100644
index 0000000..7cf278c
--- /dev/null
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/BoundaryCache.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec;
+
+import java.util.LinkedList;
+import java.util.Map;
+import java.util.TreeMap;
+
+/**
+ * Cache for storing boundaries found within a partition - used for PTF functions.
+ * Stores key-value pairs where the key is the row index in the partition at which a range begins,
+ * and the value is the corresponding row value (based on what the user specified in the ORDER BY column).
+ */
+public class BoundaryCache extends TreeMap<Integer, Object> {
+
+ private boolean isComplete = false;
+ private final int maxSize;
+ private final LinkedList<Integer> queue = new LinkedList<>();
+
+ public BoundaryCache(int maxSize) {
+ if (maxSize <= 1) {
+ throw new IllegalArgumentException("Cache size of 1 or below doesn't make sense.");
+ }
+ this.maxSize = maxSize;
+ }
+
+ /**
+ * True if the last range(s) of the partition are loaded into the cache.
+ * @return true if complete, false otherwise.
+ */
+ public boolean isComplete() {
+ return isComplete;
+ }
+
+ public void setComplete(boolean complete) {
+ isComplete = complete;
+ }
+
+ @Override
+ public Object put(Integer key, Object value) {
+ Object result = super.put(key, value);
+ //Every new element is added to FIFO too.
+ if (result == null) {
+ queue.add(key);
+ }
+ //If FIFO size exceeds maxSize we evict the eldest entry.
+ if (queue.size() > maxSize) {
+ evictOne();
+ }
+ return result;
+ }
+
+ /**
+ * Puts new key-value pair in cache.
+ * @param key
+ * @param value
+ * @return false if the cache was full and the put failed, true otherwise.
+ */
+ public Boolean putIfNotFull(Integer key, Object value) {
+ if (isFull()) {
+ return false;
+ } else {
+ put(key, value);
+ return true;
+ }
+ }
+
+ /**
+ * Checks if cache is full.
+ * @return true if full, false otherwise.
+ */
+ public Boolean isFull() {
+ return queue.size() >= maxSize;
+ }
+
+ @Override
+ public void clear() {
+ this.isComplete = false;
+ this.queue.clear();
+ super.clear();
+ }
+
+ /**
+ * Returns entry corresponding to highest row index.
+ * @return max entry.
+ */
+ public Map.Entry<Integer, Object> getMaxEntry() {
+ return floorEntry(Integer.MAX_VALUE);
+ }
+
+ /**
+ * Removes eldest entry from the boundary cache.
+ */
+ public void evictOne() {
+ if (queue.isEmpty()) {
+ return;
+ }
+ Integer elementToDelete = queue.poll();
+ this.remove(elementToDelete);
+ }
+
+ public void evictThisAndAllBefore(int rowIdx) {
+ while (!queue.isEmpty() && queue.peek() <= rowIdx) {
+ evictOne();
+ }
+ }
+
+}
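
A short usage sketch of the class above (illustrative values, not part of the patch), showing the
TreeMap lookups plus FIFO eviction semantics:

    BoundaryCache cache = new BoundaryCache(3);          // maxSize must be above 1
    cache.put(0, 10);                                    // range starting at row 0
    cache.put(4, 20);                                    // next range starts at row 4
    cache.put(9, 30);                                    // cache is now full (3 entries)
    cache.putIfNotFull(12, 40);                          // returns false, nothing stored
    cache.put(12, 40);                                   // plain put evicts the eldest entry (row 0)
    Map.Entry<Integer, Object> e = cache.floorEntry(7);  // (4, 20): the range containing row 7
    cache.evictThisAndAllBefore(9);                      // drops the entries for rows 4 and 9
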
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/PTFPartition.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/PTFPartition.java
index f125f9b..e17068e 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/PTFPartition.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/PTFPartition.java
@@ -46,6 +46,7 @@ public class PTFPartition {
StructObjectInspector inputOI;
StructObjectInspector outputOI;
private final PTFRowContainer<List<Object>> elems;
+ private final BoundaryCache boundaryCache;
protected PTFPartition(Configuration cfg,
AbstractSerDe serDe, StructObjectInspector inputOI,
@@ -70,6 +71,8 @@ public class PTFPartition {
} else {
elems = null;
}
+ int boundaryCacheSize = HiveConf.getIntVar(cfg, ConfVars.HIVE_PTF_RANGECACHE_SIZE);
+ boundaryCache = boundaryCacheSize > 1 ? new BoundaryCache(boundaryCacheSize) : null;
}
public void reset() throws HiveException {
@@ -262,4 +265,8 @@ public class PTFPartition {
ObjectInspectorCopyOption.WRITABLE);
}
+ public BoundaryCache getBoundaryCache() {
+ return boundaryCache;
+ }
+
}
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/BasePartitionEvaluator.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/BasePartitionEvaluator.java
index d44604d..20dc862 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/BasePartitionEvaluator.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/BasePartitionEvaluator.java
@@ -256,6 +256,7 @@ public class BasePartitionEvaluator {
end = getRowBoundaryEnd(endB, currRow, p);
} else {
ValueBoundaryScanner vbs = ValueBoundaryScanner.getScanner(winFrame, nullsLast);
+ vbs.handleCache(currRow, p);
start = vbs.computeStart(currRow, p);
end = vbs.computeEnd(currRow, p);
}
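
The call order here is the crux: handleCache(...) runs once per row, before the start/end
computations that consult the cache. Roughly, as a sketch of the surrounding loop (names from
this patch, loop itself simplified):

    for (int row = 0; row < p.size(); row++) {
      vbs.handleCache(row, p);               // evict stale ranges, then fill the cache forward
      int start = vbs.computeStart(row, p);  // may now jump via cached boundaries
      int end = vbs.computeEnd(row, p);
      // aggregate rows [start, end) for this window frame...
    }
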
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/ValueBoundaryScanner.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/ValueBoundaryScanner.java
index e633edb..524812f 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/ValueBoundaryScanner.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/ValueBoundaryScanner.java
@@ -18,10 +18,15 @@
package org.apache.hadoop.hive.ql.udf.ptf;
+import java.util.Map;
+
+import org.apache.commons.lang3.tuple.ImmutablePair;
+import org.apache.commons.lang3.tuple.Pair;
import org.apache.hadoop.hive.common.type.Date;
import org.apache.hadoop.hive.common.type.HiveDecimal;
import org.apache.hadoop.hive.common.type.Timestamp;
import org.apache.hadoop.hive.common.type.TimestampTZ;
+import org.apache.hadoop.hive.ql.exec.BoundaryCache;
import org.apache.hadoop.hive.ql.exec.PTFPartition;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.PTFInvocationSpec.Order;
@@ -44,10 +49,258 @@ public abstract class ValueBoundaryScanner {
this.nullsLast = nullsLast;
}
+ public abstract Object computeValue(Object row) throws HiveException;
+
+ /**
+ * Checks if the distance of v2 to v1 is greater than the given amt.
+ * @return True if the value of v1 - v2 is greater than amt or either value is null.
+ */
+ public abstract boolean isDistanceGreater(Object v1, Object v2, int amt);
+
+ /**
+ * Checks if the values of v1 and v2 are the same.
+ * @return True if both values are the same or both are nulls.
+ */
+ public abstract boolean isEqual(Object v1, Object v2);
+
public abstract int computeStart(int rowIdx, PTFPartition p) throws HiveException;
public abstract int computeEnd(int rowIdx, PTFPartition p) throws HiveException;
+ /**
+ * Checks and maintains cache content - keeps the cache window around the current row,
+ * thereby making it follow the current progress.
+ * @param rowIdx current row
+ * @param p current partition for the PTF operator
+ * @throws HiveException
+ */
+ public void handleCache(int rowIdx, PTFPartition p) throws HiveException {
+ BoundaryCache cache = p.getBoundaryCache();
+ if (cache == null) {
+ return;
+ }
+
+ //No need to set up/fill the cache.
+ if (start.isUnbounded() && end.isUnbounded()) {
+ return;
+ }
+
+ //Start of partition.
+ if (rowIdx == 0) {
+ cache.clear();
+ }
+ if (cache.isComplete()) {
+ return;
+ }
+ if (cache.isEmpty()) {
+ fillCacheUntilEndOrFull(rowIdx, p);
+ return;
+ }
+
+ if (start.isPreceding()) {
+ if (start.isUnbounded()) {
+ if (end.isPreceding()) {
+ //We can wait with cache eviction until we're at the end of currently known ranges.
+ Map.Entry<Integer, Object> maxEntry = cache.getMaxEntry();
+ if (maxEntry != null && maxEntry.getKey() <= rowIdx) {
+ cache.evictOne();
+ }
+ } else {
+ //Starting from current row, all previous ranges can be evicted.
+ checkIfCacheCanEvict(rowIdx, p, true);
+ }
+ } else {
+ //We either evict when we're at the end of currently known ranges, or if not there yet and
+ // END is of FOLLOWING type: we should remove ranges preceding the current range beginning.
+ Map.Entry<Integer, Object> maxEntry = cache.getMaxEntry();
+ if (maxEntry != null && maxEntry.getKey() <= rowIdx) {
+ cache.evictOne();
+ } else if (end.isFollowing()) {
+ int startIdx = computeStart(rowIdx, p);
+ checkIfCacheCanEvict(startIdx - 1, p, true);
+ }
+ }
+ }
+
+ if (start.isCurrentRow()) {
+ //Starting from current row, all previous ranges before the previous range can be evicted.
+ checkIfCacheCanEvict(rowIdx, p, false);
+ }
+ if (start.isFollowing()) {
+ //Starting from current row, all previous ranges can be evicted.
+ checkIfCacheCanEvict(rowIdx, p, true);
+ }
+
+ fillCacheUntilEndOrFull(rowIdx, p);
+ }
+
+ /**
+ * Retrieves the range for rowIdx, then removes all previous range entries before it.
+ * @param rowIdx row index.
+ * @param p partition.
+ * @param willScanFwd false: removal starts only from the range before the previous one.
+ */
+ private void checkIfCacheCanEvict(int rowIdx, PTFPartition p, boolean willScanFwd) {
+ BoundaryCache cache = p.getBoundaryCache();
+ if (cache == null) {
+ return;
+ }
+ Map.Entry<Integer, Object> floorEntry = cache.floorEntry(rowIdx);
+ if (floorEntry != null) {
+ floorEntry = cache.floorEntry(floorEntry.getKey() - 1);
+ if (floorEntry != null) {
+ if (willScanFwd) {
+ cache.evictThisAndAllBefore(floorEntry.getKey());
+ } else {
+ floorEntry = cache.floorEntry(floorEntry.getKey() - 1);
+ if (floorEntry != null) {
+ cache.evictThisAndAllBefore(floorEntry.getKey());
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Inserts values into the cache starting from rowIdx in the current partition p. Stops if the
+ * cache reaches its maximum size or we run out of rows in p.
+ * @param rowIdx
+ * @param p
+ * @throws HiveException
+ */
+ private void fillCacheUntilEndOrFull(int rowIdx, PTFPartition p) throws HiveException {
+ BoundaryCache cache = p.getBoundaryCache();
+ if (cache == null || p.size() <= 0) {
+ return;
+ }
+
+ Object rowVal = null;
+
+ //If we continue building the cache, resume after the last cached entry.
+ Map.Entry<Integer, Object> ceilingEntry = cache.getMaxEntry();
+ if (ceilingEntry != null) {
+ rowIdx = ceilingEntry.getKey();
+ rowVal = ceilingEntry.getValue();
+ ++rowIdx;
+ }
+
+ Object lastRowVal = rowVal;
+
+ while (rowIdx < p.size() && !cache.isFull()) {
+ rowVal = computeValue(p.getAt(rowIdx));
+ if (!isEqual(rowVal, lastRowVal)){
+ cache.put(rowIdx, rowVal);
+ }
+ lastRowVal = rowVal;
+ ++rowIdx;
+ }
+ //Signaling end of all rows in a partition
+ if (cache.putIfNotFull(rowIdx, null)) {
+ cache.setComplete(true);
+ }
+ }
+
+ /**
+ * Uses cache content to jump backwards if possible. If not, it steps one row back.
+ * @param r
+ * @param p
+ * @return pair of (row we stepped/jumped onto ; row value at this position)
+ * @throws HiveException
+ */
+ protected Pair<Integer, Object> skipOrStepBack(int r, PTFPartition p)
+ throws HiveException {
+ Object rowVal = null;
+ BoundaryCache cache = p.getBoundaryCache();
+
+ Map.Entry<Integer, Object> floorEntry = null;
+ Map.Entry<Integer, Object> ceilingEntry = null;
+
+ if (cache != null) {
+ floorEntry = cache.floorEntry(r);
+ ceilingEntry = cache.ceilingEntry(r);
+ }
+
+ if (floorEntry != null && ceilingEntry != null) {
+ r = floorEntry.getKey() - 1;
+ floorEntry = cache.floorEntry(r);
+ if (floorEntry != null) {
+ rowVal = floorEntry.getValue();
+ } else if (r >= 0){
+ rowVal = computeValue(p.getAt(r));
+ }
+ } else {
+ r--;
+ if (r >= 0) {
+ rowVal = computeValue(p.getAt(r));
+ }
+ }
+ return new ImmutablePair<>(r, rowVal);
+ }
+
+ /**
+ * Uses cache content to jump forward if possible. If not, it steps one row forward.
+ * @param r
+ * @param p
+ * @return pair of (row we stepped/jumped onto ; row value at this position)
+ * @throws HiveException
+ */
+ protected Pair<Integer, Object> skipOrStepForward(int r, PTFPartition p)
+ throws HiveException {
+ Object rowVal = null;
+ BoundaryCache cache = p.getBoundaryCache();
+
+ Map.Entry<Integer, Object> floorEntry = null;
+ Map.Entry<Integer, Object> ceilingEntry = null;
+
+ if (cache != null) {
+ floorEntry = cache.floorEntry(r);
+ ceilingEntry = cache.ceilingEntry(r);
+ }
+
+ if (ceilingEntry != null && ceilingEntry.getKey().equals(r)){
+ ceilingEntry = cache.ceilingEntry(r + 1);
+ }
+ if (floorEntry != null && ceilingEntry != null) {
+ r = ceilingEntry.getKey();
+ rowVal = ceilingEntry.getValue();
+ } else {
+ r++;
+ if (r < p.size()) {
+ rowVal = computeValue(p.getAt(r));
+ }
+ }
+ return new ImmutablePair<>(r, rowVal);
+ }
+
+ /**
+ * Uses the cache to look up the row value. Computes it on the fly on cache miss.
+ * @param r
+ * @param p
+ * @return row value.
+ * @throws HiveException
+ */
+ protected Object computeValueUseCache(int r, PTFPartition p) throws HiveException {
+ BoundaryCache cache = p.getBoundaryCache();
+
+ Map.Entry<Integer, Object> floorEntry = null;
+ Map.Entry<Integer, Object> ceilingEntry = null;
+
+ if (cache != null) {
+ floorEntry = cache.floorEntry(r);
+ ceilingEntry = cache.ceilingEntry(r);
+ }
+
+ if (ceilingEntry != null && ceilingEntry.getKey().equals(r)){
+ return ceilingEntry.getValue();
+ }
+ if (floorEntry != null && ceilingEntry != null) {
+ return floorEntry.getValue();
+ } else {
+ return computeValue(p.getAt(r));
+ }
+ }
+
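
Taken together, these helpers replace the old one-row-at-a-time scan with range-sized jumps. A
standalone sketch of the forward-jump idea using a plain TreeMap (hypothetical boundary values,
mirroring skipOrStepForward above):

    TreeMap<Integer, Object> boundaries = new TreeMap<>();
    boundaries.put(4, 20);  // the range [4..8] carries order-by value 20
    boundaries.put(9, 30);  // the range starting at row 9 carries value 30
    int r = 4;
    Map.Entry<Integer, Object> ceiling = boundaries.ceilingEntry(r);
    if (ceiling != null && ceiling.getKey().equals(r)) {
      ceiling = boundaries.ceilingEntry(r + 1);  // re-aim past the boundary we stand on
    }
    if (boundaries.floorEntry(r) != null && ceiling != null) {
      r = ceiling.getKey();                      // jump straight to row 9, skipping rows 5..8
    } else {
      r++;                                       // cache miss: step a single row, as before
    }
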
public static ValueBoundaryScanner getScanner(WindowFrameDef winFrameDef, boolean nullsLast)
throws HiveException {
OrderDef orderDef = winFrameDef.getOrderDef();
@@ -108,6 +361,7 @@ abstract class SingleValueBoundaryScanner extends ValueBoundaryScanner {
| | | | | | such that R2.sk - R.sk > amt |
|------+----------------+----------------+----------+-------+-----------------------------------|
*/
+
@Override
public int computeStart(int rowIdx, PTFPartition p) throws HiveException {
switch(start.getDirection()) {
@@ -127,18 +381,17 @@ abstract class SingleValueBoundaryScanner extends ValueBoundaryScanner {
if ( amt == BoundarySpec.UNBOUNDED_AMOUNT ) {
return 0;
}
- Object sortKey = computeValue(p.getAt(rowIdx));
+ Object sortKey = computeValueUseCache(rowIdx, p);
if ( sortKey == null ) {
// Use Case 3.
if (nullsLast || expressionDef.getOrder() == Order.DESC) {
while ( sortKey == null && rowIdx >= 0 ) {
- --rowIdx;
- if ( rowIdx >= 0 ) {
- sortKey = computeValue(p.getAt(rowIdx));
- }
+ Pair<Integer, Object> stepResult = skipOrStepBack(rowIdx, p);
+ rowIdx = stepResult.getLeft();
+ sortKey = stepResult.getRight();
}
- return rowIdx+1;
+ return rowIdx + 1;
}
else { // Use Case 2.
if ( expressionDef.getOrder() == Order.ASC ) {
@@ -153,36 +406,34 @@ abstract class SingleValueBoundaryScanner extends ValueBoundaryScanner {
// Use Case 4.
if ( expressionDef.getOrder() == Order.DESC ) {
while (r >= 0 && !isDistanceGreater(rowVal, sortKey, amt) ) {
- r--;
- if ( r >= 0 ) {
- rowVal = computeValue(p.getAt(r));
- }
+ Pair<Integer, Object> stepResult = skipOrStepBack(r, p);
+ r = stepResult.getLeft();
+ rowVal = stepResult.getRight();
}
return r + 1;
}
else { // Use Case 5.
while (r >= 0 && !isDistanceGreater(sortKey, rowVal, amt) ) {
- r--;
- if ( r >= 0 ) {
- rowVal = computeValue(p.getAt(r));
- }
+ Pair<Integer, Object> stepResult = skipOrStepBack(r, p);
+ r = stepResult.getLeft();
+ rowVal = stepResult.getRight();
}
+
return r + 1;
}
}
protected int computeStartCurrentRow(int rowIdx, PTFPartition p) throws HiveException {
- Object sortKey = computeValue(p.getAt(rowIdx));
+ Object sortKey = computeValueUseCache(rowIdx, p);
// Use Case 6.
if ( sortKey == null ) {
while ( sortKey == null && rowIdx >= 0 ) {
- --rowIdx;
- if ( rowIdx >= 0 ) {
- sortKey = computeValue(p.getAt(rowIdx));
- }
+ Pair<Integer, Object> stepResult = skipOrStepBack(rowIdx, p);
+ rowIdx = stepResult.getLeft();
+ sortKey = stepResult.getRight();
}
- return rowIdx+1;
+ return rowIdx + 1;
}
Object rowVal = sortKey;
@@ -190,17 +441,16 @@ abstract class SingleValueBoundaryScanner extends ValueBoundaryScanner {
// Use Case 7.
while (r >= 0 && isEqual(rowVal, sortKey) ) {
- r--;
- if ( r >= 0 ) {
- rowVal = computeValue(p.getAt(r));
- }
+ Pair<Integer, Object> stepResult = skipOrStepBack(r, p);
+ r = stepResult.getLeft();
+ rowVal = stepResult.getRight();
}
return r + 1;
}
protected int computeStartFollowing(int rowIdx, PTFPartition p) throws HiveException {
int amt = start.getAmt();
- Object sortKey = computeValue(p.getAt(rowIdx));
+ Object sortKey = computeValueUseCache(rowIdx, p);
Object rowVal = sortKey;
int r = rowIdx;
@@ -212,10 +462,9 @@ abstract class SingleValueBoundaryScanner extends ValueBoundaryScanner {
}
else { // Use Case 10.
while (r < p.size() && rowVal == null ) {
- r++;
- if ( r < p.size() ) {
- rowVal = computeValue(p.getAt(r));
- }
+ Pair<Integer, Object> stepResult = skipOrStepForward(r, p);
+ r = stepResult.getLeft();
+ rowVal = stepResult.getRight();
}
return r;
}
@@ -224,19 +473,17 @@ abstract class SingleValueBoundaryScanner extends ValueBoundaryScanner {
// Use Case 11.
if ( expressionDef.getOrder() == Order.DESC) {
while (r < p.size() && !isDistanceGreater(sortKey, rowVal, amt) ) {
- r++;
- if ( r < p.size() ) {
- rowVal = computeValue(p.getAt(r));
- }
+ Pair<Integer, Object> stepResult = skipOrStepForward(r, p);
+ r = stepResult.getLeft();
+ rowVal = stepResult.getRight();
}
return r;
}
else { // Use Case 12.
while (r < p.size() && !isDistanceGreater(rowVal, sortKey, amt) ) {
- r++;
- if ( r < p.size() ) {
- rowVal = computeValue(p.getAt(r));
- }
+ Pair<Integer, Object> stepResult = skipOrStepForward(r, p);
+ r = stepResult.getLeft();
+ rowVal = stepResult.getRight();
}
return r;
}
@@ -292,7 +539,7 @@ abstract class SingleValueBoundaryScanner extends ValueBoundaryScanner {
// Use Case 1.
// amt == UNBOUNDED, is caught during translation
- Object sortKey = computeValue(p.getAt(rowIdx));
+ Object sortKey = computeValueUseCache(rowIdx, p);
if ( sortKey == null ) {
// Use Case 2.
@@ -310,34 +557,31 @@ abstract class SingleValueBoundaryScanner extends ValueBoundaryScanner {
// Use Case 4.
if ( expressionDef.getOrder() == Order.DESC ) {
while (r >= 0 && !isDistanceGreater(rowVal, sortKey, amt) ) {
- r--;
- if ( r >= 0 ) {
- rowVal = computeValue(p.getAt(r));
- }
+ Pair<Integer, Object> stepResult = skipOrStepBack(r, p);
+ r = stepResult.getLeft();
+ rowVal = stepResult.getRight();
}
return r + 1;
}
else { // Use Case 5.
while (r >= 0 && !isDistanceGreater(sortKey, rowVal, amt) ) {
- r--;
- if ( r >= 0 ) {
- rowVal = computeValue(p.getAt(r));
- }
+ Pair<Integer, Object> stepResult = skipOrStepBack(r, p);
+ r = stepResult.getLeft();
+ rowVal = stepResult.getRight();
}
return r + 1;
}
}
protected int computeEndCurrentRow(int rowIdx, PTFPartition p) throws HiveException {
- Object sortKey = computeValue(p.getAt(rowIdx));
+ Object sortKey = computeValueUseCache(rowIdx, p);
// Use Case 6.
if ( sortKey == null ) {
while ( sortKey == null && rowIdx < p.size() ) {
- ++rowIdx;
- if ( rowIdx < p.size() ) {
- sortKey = computeValue(p.getAt(rowIdx));
- }
+ Pair<Integer, Object> stepResult = skipOrStepForward(rowIdx, p);
+ rowIdx = stepResult.getLeft();
+ sortKey = stepResult.getRight();
}
return rowIdx;
}
@@ -347,10 +591,9 @@ abstract class SingleValueBoundaryScanner extends ValueBoundaryScanner {
// Use Case 7.
while (r < p.size() && isEqual(sortKey, rowVal) ) {
- r++;
- if ( r < p.size() ) {
- rowVal = computeValue(p.getAt(r));
- }
+ Pair<Integer, Object> stepResult = skipOrStepForward(r, p);
+ r = stepResult.getLeft();
+ rowVal = stepResult.getRight();
}
return r;
}
@@ -362,7 +605,7 @@ abstract class SingleValueBoundaryScanner extends ValueBoundaryScanner {
if ( amt == BoundarySpec.UNBOUNDED_AMOUNT ) {
return p.size();
}
- Object sortKey = computeValue(p.getAt(rowIdx));
+ Object sortKey = computeValueUseCache(rowIdx, p);
Object rowVal = sortKey;
int r = rowIdx;
@@ -374,10 +617,9 @@ abstract class SingleValueBoundaryScanner extends ValueBoundaryScanner {
}
else { // Use Case 10.
while (r < p.size() && rowVal == null ) {
- r++;
- if ( r < p.size() ) {
- rowVal = computeValue(p.getAt(r));
- }
+ Pair<Integer, Object> stepResult = skipOrStepForward(r, p);
+ r = stepResult.getLeft();
+ rowVal = stepResult.getRight();
}
return r;
}
@@ -386,19 +628,17 @@ abstract class SingleValueBoundaryScanner extends ValueBoundaryScanner {
// Use Case 11.
if ( expressionDef.getOrder() == Order.DESC) {
while (r < p.size() && !isDistanceGreater(sortKey, rowVal, amt) ) {
- r++;
- if ( r < p.size() ) {
- rowVal = computeValue(p.getAt(r));
- }
+ Pair<Integer, Object> stepResult = skipOrStepForward(r, p);
+ r = stepResult.getLeft();
+ rowVal = stepResult.getRight();
}
return r;
}
else { // Use Case 12.
while (r < p.size() && !isDistanceGreater(rowVal, sortKey, amt) ) {
- r++;
- if ( r < p.size() ) {
- rowVal = computeValue(p.getAt(r));
- }
+ Pair<Integer, Object> stepResult = skipOrStepForward(r, p);
+ r = stepResult.getLeft();
+ rowVal = stepResult.getRight();
}
return r;
}
@@ -717,15 +957,14 @@ class StringValueBoundaryScanner extends SingleValueBoundaryScanner {
}
protected int computeStartCurrentRow(int rowIdx, PTFPartition p) throws HiveException {
- Object[] sortKey = computeValues(p.getAt(rowIdx));
- Object[] rowVal = sortKey;
+ Object sortKey = computeValueUseCache(rowIdx, p);
+ Object rowVal = sortKey;
int r = rowIdx;
while (r >= 0 && isEqual(rowVal, sortKey) ) {
- r--;
- if ( r >= 0 ) {
- rowVal = computeValues(p.getAt(r));
- }
+ Pair<Integer, Object> stepResult = skipOrStepBack(r, p);
+ r = stepResult.getLeft();
+ rowVal = stepResult.getRight();
}
return r + 1;
}
@@ -741,6 +980,7 @@ class StringValueBoundaryScanner extends SingleValueBoundaryScanner {
| 2. | FOLLOWING | UNB | ANY | ANY | end = partition.size() |
|------+----------------+---------------+----------+-------+-----------------------------------|
*/
+
@Override
public int computeEnd(int rowIdx, PTFPartition p) throws HiveException {
switch(end.getDirection()) {
@@ -756,15 +996,14 @@ class StringValueBoundaryScanner extends SingleValueBoundaryScanner {
}
protected int computeEndCurrentRow(int rowIdx, PTFPartition p) throws HiveException {
- Object[] sortKey = computeValues(p.getAt(rowIdx));
- Object[] rowVal = sortKey;
+ Object sortKey = computeValueUseCache(rowIdx, p);
+ Object rowVal = sortKey;
int r = rowIdx;
while (r < p.size() && isEqual(sortKey, rowVal) ) {
- r++;
- if ( r < p.size() ) {
- rowVal = computeValues(p.getAt(r));
- }
+ Pair<Integer, Object> stepResult = skipOrStepForward(r, p);
+ r = stepResult.getLeft();
+ rowVal = stepResult.getRight();
}
return r;
}
@@ -778,7 +1017,8 @@ class StringValueBoundaryScanner extends SingleValueBoundaryScanner {
"FOLLOWING needs UNBOUNDED for RANGE with multiple expressions in
ORDER BY");
}
- public Object[] computeValues(Object row) throws HiveException {
+ @Override
+ public Object computeValue(Object row) throws HiveException {
Object[] objs = new Object[orderDef.getExpressions().size()];
for (int i = 0; i < objs.length; i++) {
Object o = orderDef.getExpressions().get(i).getExprEvaluator().evaluate(row);
@@ -787,7 +1027,14 @@ class StringValueBoundaryScanner extends SingleValueBoundaryScanner {
return objs;
}
- public boolean isEqual(Object[] v1, Object[] v2) {
+ @Override
+ public boolean isEqual(Object val1, Object val2) {
+ if (val1 == null || val2 == null) {
+ return (val1 == null && val2 == null);
+ }
+ Object[] v1 = (Object[]) val1;
+ Object[] v2 = (Object[]) val2;
+
assert v1.length == v2.length;
for (int i = 0; i < v1.length; i++) {
if (v1[i] == null && v2[i] == null) {
@@ -804,5 +1051,10 @@ class StringValueBoundaryScanner extends SingleValueBoundaryScanner {
}
return true;
}
+
+ @Override
+ public boolean isDistanceGreater(Object v1, Object v2, int amt) {
+ throw new UnsupportedOperationException("Only unbounded ranges supported");
+ }
}
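
For the multi-expression case above, the Object[] values now flow through the common Object-typed
contract. A standalone restatement of the null-safe comparison rule (a sketch under the semantics
shown in the hunks above, not the Hive class itself):

    static boolean rangesEqual(Object valA, Object valB) {
      if (valA == null || valB == null) {
        return valA == null && valB == null;  // two top-level nulls mark the same (null) range
      }
      Object[] a = (Object[]) valA;
      Object[] b = (Object[]) valB;
      assert a.length == b.length;
      for (int i = 0; i < a.length; i++) {
        if (a[i] == null && b[i] == null) {
          continue;                           // element-wise null == null keeps the ranges equal
        }
        if (a[i] == null || b[i] == null || !a[i].equals(b[i])) {
          return false;
        }
      }
      return true;
    }
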
diff --git a/ql/src/test/org/apache/hadoop/hive/ql/udf/ptf/TestBoundaryCache.java b/ql/src/test/org/apache/hadoop/hive/ql/udf/ptf/TestBoundaryCache.java
new file mode 100644
index 0000000..714c51b
--- /dev/null
+++ b/ql/src/test/org/apache/hadoop/hive/ql/udf/ptf/TestBoundaryCache.java
@@ -0,0 +1,295 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.udf.ptf;
+
+import java.util.Collections;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import org.apache.commons.lang3.tuple.ImmutablePair;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.hadoop.hive.ql.exec.BoundaryCache;
+import org.apache.hadoop.hive.ql.exec.PTFPartition;
+import org.apache.hadoop.hive.ql.parse.PTFInvocationSpec;
+import org.apache.hadoop.hive.ql.parse.WindowingSpec;
+import org.apache.hadoop.hive.ql.plan.ptf.BoundaryDef;
+import org.apache.hadoop.hive.ql.plan.ptf.OrderExpressionDef;
+import org.apache.hadoop.io.IntWritable;
+
+import com.google.common.collect.Lists;
+
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import static java.util.Optional.ofNullable;
+import static java.util.stream.Collectors.toCollection;
+import static java.util.stream.Collectors.toList;
+import static org.apache.hadoop.hive.ql.parse.PTFInvocationSpec.Order.ASC;
+import static org.apache.hadoop.hive.ql.parse.PTFInvocationSpec.Order.DESC;
+import static org.apache.hadoop.hive.ql.parse.WindowingSpec.BoundarySpec.UNBOUNDED_AMOUNT;
+import static org.apache.hadoop.hive.ql.parse.WindowingSpec.Direction.CURRENT;
+import static org.apache.hadoop.hive.ql.parse.WindowingSpec.Direction.FOLLOWING;
+import static org.apache.hadoop.hive.ql.parse.WindowingSpec.Direction.PRECEDING;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.when;
+
+/**
+ * Tests BoundaryCache used for RANGE windows in PTF functions.
+ */
+public class TestBoundaryCache {
+
+ private static final Logger LOG = LoggerFactory.getLogger(TestBoundaryCache.class);
+ private static final LinkedList<List<IntWritable>> TEST_PARTITION = new LinkedList<>();
+ //Null for using no cache at all, 2 is the minimum cache size; 5, 9 and 15 check with a smaller,
+ // exactly equal, and larger cache than needed.
+ private static final List<Integer> CACHE_SIZES = Lists.newArrayList(null, 2, 5, 9, 15);
+ private static final List<PTFInvocationSpec.Order> ORDERS = Lists.newArrayList(ASC, DESC);
+ private static final int ORDER_BY_COL = 2;
+
+ @BeforeClass
+ public static void setupTests() throws Exception {
+ //8 ranges, max cache content is 8+1=9 entries
+ addRow(TEST_PARTITION, 1, 1, -7);
+ addRow(TEST_PARTITION, 2, 1, -1);
+ addRow(TEST_PARTITION, 3, 1, -1);
+ addRow(TEST_PARTITION, 4, 1, 1);
+ addRow(TEST_PARTITION, 5, 1, 1);
+ addRow(TEST_PARTITION, 6, 1, 1);
+ addRow(TEST_PARTITION, 7, 1, 1);
+ addRow(TEST_PARTITION, 8, 1, 2);
+ addRow(TEST_PARTITION, 9, 1, 2);
+ addRow(TEST_PARTITION, 10, 1, 2);
+ addRow(TEST_PARTITION, 11, 1, 2);
+ addRow(TEST_PARTITION, 12, 1, 3);
+ addRow(TEST_PARTITION, 13, 1, 5);
+ addRow(TEST_PARTITION, 14, 1, 5);
+ addRow(TEST_PARTITION, 15, 1, 5);
+ addRow(TEST_PARTITION, 16, 1, 5);
+ addRow(TEST_PARTITION, 17, 1, 6);
+ addRow(TEST_PARTITION, 18, 1, 6);
+ addRow(TEST_PARTITION, 19, 1, 9);
+ addRow(TEST_PARTITION, 20, 1, null);
+ addRow(TEST_PARTITION, 21, 1, null);
+
+ }
+
+ @Test
+ public void testPrecedingUnboundedFollowingUnbounded() throws Exception {
+ runTest(PRECEDING, UNBOUNDED_AMOUNT, FOLLOWING, UNBOUNDED_AMOUNT);
+ }
+
+ @Test
+ public void testPrecedingUnboundedCurrentRow() throws Exception {
+ runTest(PRECEDING, UNBOUNDED_AMOUNT, CURRENT, 0);
+ }
+
+ @Test
+ public void testPrecedingUnboundedPreceding2() throws Exception {
+ runTest(PRECEDING, UNBOUNDED_AMOUNT, PRECEDING, 2);
+ }
+
+ @Test
+ public void testPreceding4Preceding1() throws Exception {
+ runTest(PRECEDING, 4, PRECEDING, 1);
+ }
+
+ @Test
+ public void testPreceding2CurrentRow() throws Exception {
+ runTest(PRECEDING, 2, CURRENT, 0);
+ }
+
+ @Test
+ public void testPreceding1Following100() throws Exception {
+ runTest(PRECEDING, 1, FOLLOWING, 100);
+ }
+
+ @Test
+ public void testCurrentRowFollowing3() throws Exception {
+ runTest(CURRENT, 0, FOLLOWING, 3);
+ }
+
+ @Test
+ public void testCurrentRowFollowingUnbounded() throws Exception {
+ runTest(CURRENT, 0, FOLLOWING, UNBOUNDED_AMOUNT);
+ }
+
+ @Test
+ public void testFollowing2Following4() throws Exception {
+ runTest(FOLLOWING, 2, FOLLOWING, 4);
+ }
+
+ @Test
+ public void testFollowing2FollowingUnbounded() throws Exception {
+ runTest(FOLLOWING, 2, FOLLOWING, UNBOUNDED_AMOUNT);
+ }
+
+ /**
+ * Executes a test on a given window definition. Such a test will be executed against the
+ * values set in ORDERS and CACHE_SIZES, validating ORDERS X CACHE_SIZES test cases. Cache size
+ * of null will be used to set up the baseline.
+ * @param startDirection
+ * @param startAmount
+ * @param endDirection
+ * @param endAmount
+ * @throws Exception
+ */
+ private void runTest(WindowingSpec.Direction startDirection, int startAmount,
+ WindowingSpec.Direction endDirection, int endAmount) throws Exception {
+
+ BoundaryDef startBoundary = new BoundaryDef(startDirection, startAmount);
+ BoundaryDef endBoundary = new BoundaryDef(endDirection, endAmount);
+ AtomicInteger readCounter = new AtomicInteger(0);
+
+ int[] expectedBoundaryStarts = new int[TEST_PARTITION.size()];
+ int[] expectedBoundaryEnds = new int[TEST_PARTITION.size()];
+ int expectedReadCountWithoutCache = -1;
+
+ for (PTFInvocationSpec.Order order : ORDERS) {
+ for (Integer cacheSize : CACHE_SIZES) {
+ LOG.info(Thread.currentThread().getStackTrace()[2].getMethodName());
+ LOG.info("Cache: " + cacheSize + " order: " + order);
+ BoundaryCache cache = cacheSize == null ? null : new BoundaryCache(cacheSize);
+ Pair<PTFPartition, ValueBoundaryScanner> mocks = setupMocks(TEST_PARTITION,
+ ORDER_BY_COL, startBoundary, endBoundary, order, cache, readCounter);
+ PTFPartition ptfPartition = mocks.getLeft();
+ ValueBoundaryScanner scanner = mocks.getRight();
+ for (int i = 0; i < TEST_PARTITION.size(); ++i) {
+ scanner.handleCache(i, ptfPartition);
+ int start = scanner.computeStart(i, ptfPartition);
+ int end = scanner.computeEnd(i, ptfPartition) - 1;
+ if (cache == null) {
+ //Cache-less version should be baseline
+ expectedBoundaryStarts[i] = start;
+ expectedBoundaryEnds[i] = end;
+ } else {
+ assertEquals(expectedBoundaryStarts[i], start);
+ assertEquals(expectedBoundaryEnds[i], end);
+ }
+ Integer col0 = ofNullable(TEST_PARTITION.get(i).get(0)).map(v -> v.get()).orElse(null);
+ Integer col1 = ofNullable(TEST_PARTITION.get(i).get(1)).map(v -> v.get()).orElse(null);
+ Integer col2 = ofNullable(TEST_PARTITION.get(i).get(2)).map(v -> v.get()).orElse(null);
+ LOG.info(String.format("%d|\t%d\t%d\t%d\t|%d-%d", i, col0, col1, col2, start, end));
+ }
+ if (cache == null) {
+ expectedReadCountWithoutCache = readCounter.get();
+ } else {
+ //Read count should be smaller with cache being used, but larger than the minimum of
+ // reading every row once.
+ assertTrue(expectedReadCountWithoutCache >= readCounter.get());
+ if (startAmount != UNBOUNDED_AMOUNT || endAmount != UNBOUNDED_AMOUNT) {
+ assertTrue(TEST_PARTITION.size() <= readCounter.get());
+ }
+ }
+ readCounter.set(0);
+ }
+ }
+ }
+
+ /**
+ * Sets up mock and spy objects used for testing.
+ * @param partition The real partition containing row values.
+ * @param orderByCol Index of column in the row used for separating ranges.
+ * @param start Window definition.
+ * @param end Window definition.
+ * @param order Window definition.
+ * @param cache BoundaryCache instance; it may come in various sizes.
+ * @param readCounter counts how many times reading was invoked
+ * @return Mocked PTFPartition instance and ValueBoundaryScanner spy.
+ * @throws Exception
+ */
+ private static Pair<PTFPartition, ValueBoundaryScanner> setupMocks(
+ List<List<IntWritable>> partition, int orderByCol, BoundaryDef start, BoundaryDef end,
+ PTFInvocationSpec.Order order, BoundaryCache cache,
+ AtomicInteger readCounter) throws Exception {
+ PTFPartition partitionMock = mock(PTFPartition.class);
+ doAnswer(invocationOnMock -> {
+ int idx = invocationOnMock.getArgumentAt(0, Integer.class);
+ return partition.get(idx);
+ }).when(partitionMock).getAt(any(Integer.class));
+ doAnswer(invocationOnMock -> {
+ return partition.size();
+ }).when(partitionMock).size();
+ when(partitionMock.getBoundaryCache()).thenReturn(cache);
+
+ OrderExpressionDef orderDef = mock(OrderExpressionDef.class);
+ when(orderDef.getOrder()).thenReturn(order);
+
+ ValueBoundaryScanner scan = new LongValueBoundaryScanner(start, end, orderDef, order == ASC);
+ ValueBoundaryScanner scannerSpy = spy(scan);
+ doAnswer(invocationOnMock -> {
+ readCounter.incrementAndGet();
+ List<IntWritable> row = invocationOnMock.getArgumentAt(0, List.class);
+ return row.get(orderByCol);
+ }).when(scannerSpy).computeValue(any(Object.class));
+ doAnswer(invocationOnMock -> {
+ IntWritable v1 = invocationOnMock.getArgumentAt(0, IntWritable.class);
+ IntWritable v2 = invocationOnMock.getArgumentAt(1, IntWritable.class);
+ return (v1 != null && v2 != null) ? v1.get() == v2.get() : v1 == null && v2 == null;
+ }).when(scannerSpy).isEqual(any(Object.class), any(Object.class));
+ doAnswer(invocationOnMock -> {
+ IntWritable v1 = invocationOnMock.getArgumentAt(0, IntWritable.class);
+ IntWritable v2 = invocationOnMock.getArgumentAt(1, IntWritable.class);
+ Integer amt = invocationOnMock.getArgumentAt(2, Integer.class);
+ return (v1 != null && v2 != null) ? (v1.get() - v2.get()) > amt : v1 != null || v2 != null;
+ }).when(scannerSpy).isDistanceGreater(any(Object.class), any(Object.class), any(Integer.class));
+
+ setOrderOnTestPartitions(order);
+ return new ImmutablePair<>(partitionMock, scannerSpy);
+
+ }
+
+ private static void addRow(List<List<IntWritable>> partition, Integer col0, Integer col1,
+ Integer col2) {
+ partition.add(Lists.newArrayList(
+ col0 != null ? new IntWritable(col0) : null,
+ col1 != null ? new IntWritable(col1) : null,
+ col2 != null ? new IntWritable(col2) : null
+ ));
+ }
+
+ /**
+ * Reverses order on actual data if needed, based on order parameter.
+ * @param order
+ */
+ private static void setOrderOnTestPartitions(PTFInvocationSpec.Order order) {
+ LinkedList<List<IntWritable>> notNulls = TEST_PARTITION.stream().filter(
+ r -> r.get(ORDER_BY_COL) != null).collect(toCollection(LinkedList::new));
+ List<List<IntWritable>> nulls = TEST_PARTITION.stream().filter(
+ r -> r.get(ORDER_BY_COL) == null).collect(toList());
+
+ boolean isAscCurrently = notNulls.getFirst().get(ORDER_BY_COL).get() <
+ notNulls.getLast().get(ORDER_BY_COL).get();
+
+ if ((ASC.equals(order) && !isAscCurrently) || (DESC.equals(order) && isAscCurrently)) {
+ Collections.reverse(notNulls);
+ TEST_PARTITION.clear();
+ TEST_PARTITION.addAll(notNulls);
+ TEST_PARTITION.addAll(nulls);
+ }
+ }
+
+}
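
To exercise the cache end to end, the test above can be run on its own. A typical Maven
invocation from the Hive source root (module selection inferred from the file paths in this
patch):

    mvn test -Dtest=TestBoundaryCache -pl ql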