This is an automated email from the ASF dual-hosted git repository.

szita pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/hive.git


The following commit(s) were added to refs/heads/master by this push:
     new a8bd1eb  HIVE-21217: Optimize range calculation for PTF (Adam Szita, 
reviewed by Peter Vary)
a8bd1eb is described below

commit a8bd1eb09e3feb10d68075c6bc676ec333f498da
Author: Adam Szita <[email protected]>
AuthorDate: Wed Feb 20 10:26:57 2019 +0100

    HIVE-21217: Optimize range calculation for PTF (Adam Szita, reviewed by 
Peter Vary)
---
 .../java/org/apache/hadoop/hive/conf/HiveConf.java |   4 +
 .../apache/hadoop/hive/ql/exec/BoundaryCache.java  | 124 ++++++
 .../apache/hadoop/hive/ql/exec/PTFPartition.java   |   7 +
 .../hive/ql/udf/ptf/BasePartitionEvaluator.java    |   1 +
 .../hive/ql/udf/ptf/ValueBoundaryScanner.java      | 416 +++++++++++++++++----
 .../hadoop/hive/ql/udf/ptf/TestBoundaryCache.java  | 295 +++++++++++++++
 6 files changed, 765 insertions(+), 82 deletions(-)

diff --git a/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java 
b/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
index 4a86b0a..11f165a 100644
--- a/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
+++ b/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
@@ -1634,6 +1634,10 @@ public class HiveConf extends Configuration {
         + "the evaluation of certain joins, since we will not be emitting rows 
which are thrown away by "
         + "a Filter operator straight away. However, currently vectorization 
does not support them, thus "
         + "enabling it is only recommended when vectorization is disabled."),
+    HIVE_PTF_RANGECACHE_SIZE("hive.ptf.rangecache.size", 10000,
+        "Size of the cache used on reducer side, that stores boundaries of 
ranges within a PTF " +
+        "partition. Used if a query specifies a RANGE type window including an 
orderby clause. " +
+        "Set this to 0 to disable this cache."),
 
     // CBO related
     HIVE_CBO_ENABLED("hive.cbo.enable", true, "Flag to control enabling Cost 
Based Optimizations using Calcite framework."),
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/BoundaryCache.java 
b/ql/src/java/org/apache/hadoop/hive/ql/exec/BoundaryCache.java
new file mode 100644
index 0000000..7cf278c
--- /dev/null
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/BoundaryCache.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec;
+
+import java.util.LinkedList;
+import java.util.Map;
+import java.util.TreeMap;
+
+/**
+ * Cache for storing boundaries found within a partition - used for PTF 
functions.
+ * Stores key-value pairs where key is the row index in the partition from 
which a range begins,
+ * value is the corresponding row value (based on what the user specified in 
the orderby column).
+ */
+public class BoundaryCache extends TreeMap<Integer, Object> {
+
+  private boolean isComplete = false;
+  private final int maxSize;
+  private final LinkedList<Integer> queue = new LinkedList<>();
+
+  public BoundaryCache(int maxSize) {
+    if (maxSize <= 1) {
+      throw new IllegalArgumentException("Cache size of 1 or below doesn't make sense.");
+    }
+    this.maxSize = maxSize;
+  }
+
+  /**
+   * Checks whether the last range(s) of the partition are loaded into the cache.
+   * @return true if the cache already contains the final range(s) of the partition.
+   */
+  public boolean isComplete() {
+    return isComplete;
+  }
+
+  public void setComplete(boolean complete) {
+    isComplete = complete;
+  }
+
+  @Override
+  public Object put(Integer key, Object value) {
+    Object result = super.put(key, value);
+    //Every new element is added to FIFO too.
+    if (result == null) {
+      queue.add(key);
+    }
+    //If FIFO size reaches maxSize we evict the eldest entry.
+    if (queue.size() > maxSize) {
+      evictOne();
+    }
+    return result;
+  }
+
+  /**
+   * Puts a new key-value pair into the cache, unless the cache is already full.
+   * @param key row index in the partition at which a range begins.
+   * @param value row value corresponding to that index.
+   * @return false if the queue was full and the put failed, true otherwise.
+   */
+  public Boolean putIfNotFull(Integer key, Object value) {
+    if (isFull()) {
+      return false;
+    } else {
+      put(key, value);
+      return true;
+    }
+  }
+
+  /**
+   * Checks if cache is full.
+   * @return true if full, false otherwise.
+   */
+  public Boolean isFull() {
+    return queue.size() >= maxSize;
+  }
+
+  @Override
+  public void clear() {
+    this.isComplete = false;
+    this.queue.clear();
+    super.clear();
+  }
+
+  /**
+   * Returns entry corresponding to highest row index.
+   * @return max entry.
+   */
+  public Map.Entry<Integer, Object> getMaxEntry() {
+    return floorEntry(Integer.MAX_VALUE);
+  }
+
+  /**
+   * Removes eldest entry from the boundary cache.
+   */
+  public void evictOne() {
+    if (queue.isEmpty()) {
+      return;
+    }
+    Integer elementToDelete = queue.poll();
+    this.remove(elementToDelete);
+  }
+
+  public void evictThisAndAllBefore(int rowIdx) {
+    while (queue.peek() <= rowIdx) {
+      evictOne();
+    }
+  }
+
+}
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/PTFPartition.java 
b/ql/src/java/org/apache/hadoop/hive/ql/exec/PTFPartition.java
index f125f9b..e17068e 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/PTFPartition.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/PTFPartition.java
@@ -46,6 +46,7 @@ public class PTFPartition {
   StructObjectInspector inputOI;
   StructObjectInspector outputOI;
   private final PTFRowContainer<List<Object>> elems;
+  private final BoundaryCache boundaryCache;
 
   protected PTFPartition(Configuration cfg,
       AbstractSerDe serDe, StructObjectInspector inputOI,
@@ -70,6 +71,8 @@ public class PTFPartition {
     } else {
       elems = null;
     }
+    int boundaryCacheSize = HiveConf.getIntVar(cfg, 
ConfVars.HIVE_PTF_RANGECACHE_SIZE);
+    boundaryCache = boundaryCacheSize > 1 ? new 
BoundaryCache(boundaryCacheSize) : null;
   }
 
   public void reset() throws HiveException {
@@ -262,4 +265,8 @@ public class PTFPartition {
         ObjectInspectorCopyOption.WRITABLE);
   }
 
+  public BoundaryCache getBoundaryCache() {
+    return boundaryCache;
+  }
+
 }
diff --git 
a/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/BasePartitionEvaluator.java 
b/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/BasePartitionEvaluator.java
index d44604d..20dc862 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/BasePartitionEvaluator.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/BasePartitionEvaluator.java
@@ -256,6 +256,7 @@ public class BasePartitionEvaluator {
       end = getRowBoundaryEnd(endB, currRow, p);
     } else {
       ValueBoundaryScanner vbs = ValueBoundaryScanner.getScanner(winFrame, 
nullsLast);
+      vbs.handleCache(currRow, p);
       start = vbs.computeStart(currRow, p);
       end = vbs.computeEnd(currRow, p);
     }
diff --git 
a/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/ValueBoundaryScanner.java 
b/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/ValueBoundaryScanner.java
index e633edb..524812f 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/ValueBoundaryScanner.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/ValueBoundaryScanner.java
@@ -18,10 +18,15 @@
 
 package org.apache.hadoop.hive.ql.udf.ptf;
 
+import java.util.Map;
+
+import org.apache.commons.lang3.tuple.ImmutablePair;
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.hadoop.hive.common.type.Date;
 import org.apache.hadoop.hive.common.type.HiveDecimal;
 import org.apache.hadoop.hive.common.type.Timestamp;
 import org.apache.hadoop.hive.common.type.TimestampTZ;
+import org.apache.hadoop.hive.ql.exec.BoundaryCache;
 import org.apache.hadoop.hive.ql.exec.PTFPartition;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.hive.ql.parse.PTFInvocationSpec.Order;
@@ -44,10 +49,258 @@ public abstract class ValueBoundaryScanner {
     this.nullsLast = nullsLast;
   }
 
+  public abstract Object computeValue(Object row) throws HiveException;
+
+  /**
+   * Checks if the distance of v2 to v1 is greater than the given amt.
+   * @return True if the value of v1 - v2 is greater than amt or either value 
is null.
+   */
+  public abstract boolean isDistanceGreater(Object v1, Object v2, int amt);
+
+  /**
+   * Checks if the values of v1 or v2 are the same.
+   * @return True if both values are the same or both are nulls.
+   */
+  public abstract boolean isEqual(Object v1, Object v2);
+
   public abstract int computeStart(int rowIdx, PTFPartition p) throws 
HiveException;
 
   public abstract int computeEnd(int rowIdx, PTFPartition p) throws 
HiveException;
 
+  /**
+   * Checks and maintains cache content - optimizes cache window to always be 
around current row
+   * thereby makes it follow the current progress.
+   * @param rowIdx current row
+   * @param p current partition for the PTF operator
+   * @throws HiveException
+   */
+  public void handleCache(int rowIdx, PTFPartition p) throws HiveException {
+    BoundaryCache cache = p.getBoundaryCache();
+    if (cache == null) {
+      return;
+    }
+
+    //No need to setup/fill cache.
+    if (start.isUnbounded() && end.isUnbounded()) {
+      return;
+    }
+
+    //Start of partition.
+    if (rowIdx == 0) {
+      cache.clear();
+    }
+    if (cache.isComplete()) {
+      return;
+    }
+    if (cache.isEmpty()) {
+      fillCacheUntilEndOrFull(rowIdx, p);
+      return;
+    }
+
+    if (start.isPreceding()) {
+      if (start.isUnbounded()) {
+        if (end.isPreceding()) {
+          //We can wait with cache eviction until we're at the end of 
currently known ranges.
+          Map.Entry<Integer, Object> maxEntry = cache.getMaxEntry();
+          if (maxEntry != null && maxEntry.getKey() <= rowIdx) {
+            cache.evictOne();
+          }
+        } else {
+          //Starting from current row, all previous ranges can be evicted.
+          checkIfCacheCanEvict(rowIdx, p, true);
+        }
+      } else {
+        //We either evict when we're at the end of currently known ranges, or 
if not there yet and
+        // END is of FOLLOWING type: we should remove ranges preceding the 
current range beginning.
+        Map.Entry<Integer, Object> maxEntry = cache.getMaxEntry();
+        if (maxEntry != null && maxEntry.getKey() <= rowIdx) {
+          cache.evictOne();
+        } else if (end.isFollowing()) {
+          int startIdx = computeStart(rowIdx, p);
+          checkIfCacheCanEvict(startIdx - 1, p, true);
+        }
+      }
+    }
+
+    if (start.isCurrentRow()) {
+      //Starting from current row, all previous ranges before the previous 
range can be evicted.
+      checkIfCacheCanEvict(rowIdx, p, false);
+    }
+    if (start.isFollowing()) {
+      //Starting from current row, all previous ranges can be evicted.
+      checkIfCacheCanEvict(rowIdx, p, true);
+    }
+
+    fillCacheUntilEndOrFull(rowIdx, p);
+  }
+
+  /**
+   * Retrieves the range containing rowIdx, then removes all earlier range
+   * entries before it.
+   * @param rowIdx row index.
+   * @param p partition.
+   * @param willScanFwd if false, the range immediately preceding rowIdx's range
+   * is also kept, and removal starts one range further back.
+   */
+  private void checkIfCacheCanEvict(int rowIdx, PTFPartition p, boolean 
willScanFwd) {
+    BoundaryCache cache = p.getBoundaryCache();
+    if (cache == null) {
+      return;
+    }
+    Map.Entry<Integer, Object> floorEntry = cache.floorEntry(rowIdx);
+    if (floorEntry != null) {
+      floorEntry = cache.floorEntry(floorEntry.getKey() - 1);
+      if (floorEntry != null) {
+        if (willScanFwd) {
+          cache.evictThisAndAllBefore(floorEntry.getKey());
+        } else {
+          floorEntry = cache.floorEntry(floorEntry.getKey() - 1);
+          if (floorEntry != null) {
+            cache.evictThisAndAllBefore(floorEntry.getKey());
+          }
+        }
+      }
+    }
+  }
+
+  /**
+   * Inserts values into cache starting from rowIdx in the current partition 
p. Stops if cache
+   * reaches its maximum size or we get out of rows in p.
+   * @param rowIdx
+   * @param p
+   * @throws HiveException
+   */
+  private void fillCacheUntilEndOrFull(int rowIdx, PTFPartition p) throws 
HiveException {
+    BoundaryCache cache = p.getBoundaryCache();
+    if (cache == null || p.size() <= 0) {
+      return;
+    }
+
+    Object rowVal = null;
+
+    //If we continue building cache
+    Map.Entry<Integer, Object> ceilingEntry = cache.getMaxEntry();
+    if (ceilingEntry != null) {
+      rowIdx = ceilingEntry.getKey();
+      rowVal = ceilingEntry.getValue();
+      ++rowIdx;
+    }
+
+    Object lastRowVal = rowVal;
+
+    while (rowIdx < p.size() && !cache.isFull()) {
+      rowVal = computeValue(p.getAt(rowIdx));
+      if (!isEqual(rowVal, lastRowVal)){
+        cache.put(rowIdx, rowVal);
+      }
+      lastRowVal = rowVal;
+      ++rowIdx;
+
+    }
+    //Signaling end of all rows in a partition
+    if (cache.putIfNotFull(rowIdx, null)) {
+      cache.setComplete(true);
+    }
+  }
+
+  /**
+   * Uses cache content to jump backwards if possible. If not, it steps one 
back.
+   * @param r
+   * @param p
+   * @return pair of (row we stepped/jumped onto ; row value at this position)
+   * @throws HiveException
+   */
+  protected Pair<Integer, Object> skipOrStepBack(int r, PTFPartition p)
+          throws HiveException {
+    Object rowVal = null;
+    BoundaryCache cache = p.getBoundaryCache();
+
+    Map.Entry<Integer, Object> floorEntry = null;
+    Map.Entry<Integer, Object> ceilingEntry = null;
+
+    if (cache != null) {
+      floorEntry = cache.floorEntry(r);
+      ceilingEntry = cache.ceilingEntry(r);
+    }
+
+    if (floorEntry != null && ceilingEntry != null) {
+      r = floorEntry.getKey() - 1;
+      floorEntry = cache.floorEntry(r);
+      if (floorEntry != null) {
+        rowVal = floorEntry.getValue();
+      } else if (r >= 0){
+        rowVal = computeValue(p.getAt(r));
+      }
+    } else {
+      r--;
+      if (r >= 0) {
+        rowVal = computeValue(p.getAt(r));
+      }
+    }
+    return new ImmutablePair<>(r, rowVal);
+  }
+
+  /**
+   * Uses cache content to jump forward if possible. If not, it steps one 
forward.
+   * @param r
+   * @param p
+   * @return pair of (row we stepped/jumped onto ; row value at this position)
+   * @throws HiveException
+   */
+  protected Pair<Integer, Object> skipOrStepForward(int r, PTFPartition p)
+          throws HiveException {
+    Object rowVal = null;
+    BoundaryCache cache = p.getBoundaryCache();
+
+    Map.Entry<Integer, Object> floorEntry = null;
+    Map.Entry<Integer, Object> ceilingEntry = null;
+
+    if (cache != null) {
+      floorEntry = cache.floorEntry(r);
+      ceilingEntry = cache.ceilingEntry(r);
+    }
+
+    if (ceilingEntry != null && ceilingEntry.getKey().equals(r)){
+      ceilingEntry = cache.ceilingEntry(r + 1);
+    }
+    if (floorEntry != null && ceilingEntry != null) {
+      r = ceilingEntry.getKey();
+      rowVal = ceilingEntry.getValue();
+    } else {
+      r++;
+      if (r < p.size()) {
+        rowVal = computeValue(p.getAt(r));
+      }
+    }
+    return new ImmutablePair<>(r, rowVal);
+  }
+
+  /**
+   * Uses cache to lookup row value. Computes it on the fly on cache miss.
+   * @param r
+   * @param p
+   * @return row value.
+   * @throws HiveException
+   */
+  protected Object computeValueUseCache(int r, PTFPartition p) throws 
HiveException {
+    BoundaryCache cache = p.getBoundaryCache();
+
+    Map.Entry<Integer, Object> floorEntry = null;
+    Map.Entry<Integer, Object> ceilingEntry = null;
+
+    if (cache != null) {
+      floorEntry = cache.floorEntry(r);
+      ceilingEntry = cache.ceilingEntry(r);
+    }
+
+    if (ceilingEntry != null && ceilingEntry.getKey().equals(r)){
+      return ceilingEntry.getValue();
+    }
+    if (floorEntry != null && ceilingEntry != null) {
+      return floorEntry.getValue();
+    } else {
+      return computeValue(p.getAt(r));
+    }
+  }
+
   public static ValueBoundaryScanner getScanner(WindowFrameDef winFrameDef, 
boolean nullsLast)
       throws HiveException {
     OrderDef orderDef = winFrameDef.getOrderDef();
@@ -108,6 +361,7 @@ abstract class SingleValueBoundaryScanner extends 
ValueBoundaryScanner {
 |      |                |                |          |       | such that R2.sk 
- R.sk > amt      |
 
|------+----------------+----------------+----------+-------+-----------------------------------|
    */
+
   @Override
   public int computeStart(int rowIdx, PTFPartition p) throws HiveException {
     switch(start.getDirection()) {
@@ -127,18 +381,17 @@ abstract class SingleValueBoundaryScanner extends 
ValueBoundaryScanner {
     if ( amt == BoundarySpec.UNBOUNDED_AMOUNT ) {
       return 0;
     }
-    Object sortKey = computeValue(p.getAt(rowIdx));
+    Object sortKey = computeValueUseCache(rowIdx, p);
 
     if ( sortKey == null ) {
       // Use Case 3.
       if (nullsLast || expressionDef.getOrder() == Order.DESC) {
         while ( sortKey == null && rowIdx >= 0 ) {
-          --rowIdx;
-          if ( rowIdx >= 0 ) {
-            sortKey = computeValue(p.getAt(rowIdx));
-          }
+          Pair<Integer, Object> stepResult = skipOrStepBack(rowIdx, p);
+          rowIdx = stepResult.getLeft();
+          sortKey = stepResult.getRight();
         }
-        return rowIdx+1;
+        return rowIdx + 1;
       }
       else { // Use Case 2.
         if ( expressionDef.getOrder() == Order.ASC ) {
@@ -153,36 +406,34 @@ abstract class SingleValueBoundaryScanner extends 
ValueBoundaryScanner {
     // Use Case 4.
     if ( expressionDef.getOrder() == Order.DESC ) {
       while (r >= 0 && !isDistanceGreater(rowVal, sortKey, amt) ) {
-        r--;
-        if ( r >= 0 ) {
-          rowVal = computeValue(p.getAt(r));
-        }
+        Pair<Integer, Object> stepResult = skipOrStepBack(r, p);
+        r = stepResult.getLeft();
+        rowVal = stepResult.getRight();
       }
       return r + 1;
     }
     else { // Use Case 5.
       while (r >= 0 && !isDistanceGreater(sortKey, rowVal, amt) ) {
-        r--;
-        if ( r >= 0 ) {
-          rowVal = computeValue(p.getAt(r));
-        }
+        Pair<Integer, Object> stepResult = skipOrStepBack(r, p);
+        r = stepResult.getLeft();
+        rowVal = stepResult.getRight();
       }
+
       return r + 1;
     }
   }
 
   protected int computeStartCurrentRow(int rowIdx, PTFPartition p) throws 
HiveException {
-    Object sortKey = computeValue(p.getAt(rowIdx));
+    Object sortKey = computeValueUseCache(rowIdx, p);
 
     // Use Case 6.
     if ( sortKey == null ) {
       while ( sortKey == null && rowIdx >= 0 ) {
-        --rowIdx;
-        if ( rowIdx >= 0 ) {
-          sortKey = computeValue(p.getAt(rowIdx));
-        }
+        Pair<Integer, Object> stepResult = skipOrStepBack(rowIdx, p);
+        rowIdx = stepResult.getLeft();
+        sortKey = stepResult.getRight();
       }
-      return rowIdx+1;
+      return rowIdx + 1;
     }
 
     Object rowVal = sortKey;
@@ -190,17 +441,16 @@ abstract class SingleValueBoundaryScanner extends 
ValueBoundaryScanner {
 
     // Use Case 7.
     while (r >= 0 && isEqual(rowVal, sortKey) ) {
-      r--;
-      if ( r >= 0 ) {
-        rowVal = computeValue(p.getAt(r));
-      }
+      Pair<Integer, Object> stepResult = skipOrStepBack(r, p);
+      r = stepResult.getLeft();
+      rowVal = stepResult.getRight();
     }
     return r + 1;
   }
 
   protected int computeStartFollowing(int rowIdx, PTFPartition p) throws 
HiveException {
     int amt = start.getAmt();
-    Object sortKey = computeValue(p.getAt(rowIdx));
+    Object sortKey = computeValueUseCache(rowIdx, p);
 
     Object rowVal = sortKey;
     int r = rowIdx;
@@ -212,10 +462,9 @@ abstract class SingleValueBoundaryScanner extends 
ValueBoundaryScanner {
       }
       else { // Use Case 10.
         while (r < p.size() && rowVal == null ) {
-          r++;
-          if ( r < p.size() ) {
-            rowVal = computeValue(p.getAt(r));
-          }
+          Pair<Integer, Object> stepResult = skipOrStepForward(r, p);
+          r = stepResult.getLeft();
+          rowVal = stepResult.getRight();
         }
         return r;
       }
@@ -224,19 +473,17 @@ abstract class SingleValueBoundaryScanner extends 
ValueBoundaryScanner {
     // Use Case 11.
     if ( expressionDef.getOrder() == Order.DESC) {
       while (r < p.size() && !isDistanceGreater(sortKey, rowVal, amt) ) {
-        r++;
-        if ( r < p.size() ) {
-          rowVal = computeValue(p.getAt(r));
-        }
+        Pair<Integer, Object> stepResult = skipOrStepForward(r, p);
+        r = stepResult.getLeft();
+        rowVal = stepResult.getRight();
       }
       return r;
     }
     else { // Use Case 12.
       while (r < p.size() && !isDistanceGreater(rowVal, sortKey, amt) ) {
-        r++;
-        if ( r < p.size() ) {
-          rowVal = computeValue(p.getAt(r));
-        }
+        Pair<Integer, Object> stepResult = skipOrStepForward(r, p);
+        r = stepResult.getLeft();
+        rowVal = stepResult.getRight();
       }
       return r;
     }
@@ -292,7 +539,7 @@ abstract class SingleValueBoundaryScanner extends 
ValueBoundaryScanner {
     // Use Case 1.
     // amt == UNBOUNDED, is caught during translation
 
-    Object sortKey = computeValue(p.getAt(rowIdx));
+    Object sortKey = computeValueUseCache(rowIdx, p);
 
     if ( sortKey == null ) {
       // Use Case 2.
@@ -310,34 +557,31 @@ abstract class SingleValueBoundaryScanner extends 
ValueBoundaryScanner {
     // Use Case 4.
     if ( expressionDef.getOrder() == Order.DESC ) {
       while (r >= 0 && !isDistanceGreater(rowVal, sortKey, amt) ) {
-        r--;
-        if ( r >= 0 ) {
-          rowVal = computeValue(p.getAt(r));
-        }
+        Pair<Integer, Object> stepResult = skipOrStepBack(r, p);
+        r = stepResult.getLeft();
+        rowVal = stepResult.getRight();
       }
       return r + 1;
     }
     else { // Use Case 5.
       while (r >= 0 && !isDistanceGreater(sortKey, rowVal, amt) ) {
-        r--;
-        if ( r >= 0 ) {
-          rowVal = computeValue(p.getAt(r));
-        }
+        Pair<Integer, Object> stepResult = skipOrStepBack(r, p);
+        r = stepResult.getLeft();
+        rowVal = stepResult.getRight();
       }
       return r + 1;
     }
   }
 
   protected int computeEndCurrentRow(int rowIdx, PTFPartition p) throws 
HiveException {
-    Object sortKey = computeValue(p.getAt(rowIdx));
+    Object sortKey = computeValueUseCache(rowIdx, p);
 
     // Use Case 6.
     if ( sortKey == null ) {
       while ( sortKey == null && rowIdx < p.size() ) {
-        ++rowIdx;
-        if ( rowIdx < p.size() ) {
-          sortKey = computeValue(p.getAt(rowIdx));
-        }
+        Pair<Integer, Object> stepResult = skipOrStepForward(rowIdx, p);
+        rowIdx = stepResult.getLeft();
+        sortKey = stepResult.getRight();
       }
       return rowIdx;
     }
@@ -347,10 +591,9 @@ abstract class SingleValueBoundaryScanner extends 
ValueBoundaryScanner {
 
     // Use Case 7.
     while (r < p.size() && isEqual(sortKey, rowVal) ) {
-      r++;
-      if ( r < p.size() ) {
-        rowVal = computeValue(p.getAt(r));
-      }
+      Pair<Integer, Object> stepResult = skipOrStepForward(r, p);
+      r = stepResult.getLeft();
+      rowVal = stepResult.getRight();
     }
     return r;
   }
@@ -362,7 +605,7 @@ abstract class SingleValueBoundaryScanner extends 
ValueBoundaryScanner {
     if ( amt == BoundarySpec.UNBOUNDED_AMOUNT ) {
       return p.size();
     }
-    Object sortKey = computeValue(p.getAt(rowIdx));
+    Object sortKey = computeValueUseCache(rowIdx, p);
 
     Object rowVal = sortKey;
     int r = rowIdx;
@@ -374,10 +617,9 @@ abstract class SingleValueBoundaryScanner extends 
ValueBoundaryScanner {
       }
       else { // Use Case 10.
         while (r < p.size() && rowVal == null ) {
-          r++;
-          if ( r < p.size() ) {
-            rowVal = computeValue(p.getAt(r));
-          }
+          Pair<Integer, Object> stepResult = skipOrStepForward(r, p);
+          r = stepResult.getLeft();
+          rowVal = stepResult.getRight();
         }
         return r;
       }
@@ -386,19 +628,17 @@ abstract class SingleValueBoundaryScanner extends 
ValueBoundaryScanner {
     // Use Case 11.
     if ( expressionDef.getOrder() == Order.DESC) {
       while (r < p.size() && !isDistanceGreater(sortKey, rowVal, amt) ) {
-        r++;
-        if ( r < p.size() ) {
-          rowVal = computeValue(p.getAt(r));
-        }
+        Pair<Integer, Object> stepResult = skipOrStepForward(r, p);
+        r = stepResult.getLeft();
+        rowVal = stepResult.getRight();
       }
       return r;
     }
     else { // Use Case 12.
       while (r < p.size() && !isDistanceGreater(rowVal, sortKey, amt) ) {
-        r++;
-        if ( r < p.size() ) {
-          rowVal = computeValue(p.getAt(r));
-        }
+        Pair<Integer, Object> stepResult = skipOrStepForward(r, p);
+        r = stepResult.getLeft();
+        rowVal = stepResult.getRight();
       }
       return r;
     }
@@ -717,15 +957,14 @@ class StringValueBoundaryScanner extends 
SingleValueBoundaryScanner {
   }
 
   protected int computeStartCurrentRow(int rowIdx, PTFPartition p) throws 
HiveException {
-    Object[] sortKey = computeValues(p.getAt(rowIdx));
-    Object[] rowVal = sortKey;
+    Object sortKey = computeValueUseCache(rowIdx, p);
+    Object rowVal = sortKey;
     int r = rowIdx;
 
     while (r >= 0 && isEqual(rowVal, sortKey) ) {
-      r--;
-      if ( r >= 0 ) {
-        rowVal = computeValues(p.getAt(r));
-      }
+      Pair<Integer, Object> stepResult = skipOrStepBack(r, p);
+      r = stepResult.getLeft();
+      rowVal = stepResult.getRight();
     }
     return r + 1;
   }
@@ -741,6 +980,7 @@ class StringValueBoundaryScanner extends 
SingleValueBoundaryScanner {
 |   2. | FOLLOWING      | UNB           | ANY      | ANY   | end = 
partition.size()            |
 
|------+----------------+---------------+----------+-------+-----------------------------------|
    */
+
   @Override
   public int computeEnd(int rowIdx, PTFPartition p) throws HiveException {
     switch(end.getDirection()) {
@@ -756,15 +996,14 @@ class StringValueBoundaryScanner extends 
SingleValueBoundaryScanner {
   }
 
   protected int computeEndCurrentRow(int rowIdx, PTFPartition p) throws 
HiveException {
-    Object[] sortKey = computeValues(p.getAt(rowIdx));
-    Object[] rowVal = sortKey;
+    Object sortKey = computeValueUseCache(rowIdx, p);
+    Object rowVal = sortKey;
     int r = rowIdx;
 
     while (r < p.size() && isEqual(sortKey, rowVal) ) {
-      r++;
-      if ( r < p.size() ) {
-        rowVal = computeValues(p.getAt(r));
-      }
+      Pair<Integer, Object> stepResult = skipOrStepForward(r, p);
+      r = stepResult.getLeft();
+      rowVal = stepResult.getRight();
     }
     return r;
   }
@@ -778,7 +1017,8 @@ class StringValueBoundaryScanner extends 
SingleValueBoundaryScanner {
             "FOLLOWING needs UNBOUNDED for RANGE with multiple expressions in 
ORDER BY");
   }
 
-  public Object[] computeValues(Object row) throws HiveException {
+  @Override
+  public Object computeValue(Object row) throws HiveException {
     Object[] objs = new Object[orderDef.getExpressions().size()];
     for (int i = 0; i < objs.length; i++) {
       Object o = 
orderDef.getExpressions().get(i).getExprEvaluator().evaluate(row);
@@ -787,7 +1027,14 @@ class StringValueBoundaryScanner extends 
SingleValueBoundaryScanner {
     return objs;
   }
 
-  public boolean isEqual(Object[] v1, Object[] v2) {
+  @Override
+  public boolean isEqual(Object val1, Object val2) {
+    if (val1 == null || val2 == null) {
+      return (val1 == null && val2 == null);
+    }
+    Object[] v1 = (Object[]) val1;
+    Object[] v2 = (Object[]) val2;
+
     assert v1.length == v2.length;
     for (int i = 0; i < v1.length; i++) {
       if (v1[i] == null && v2[i] == null) {
@@ -804,5 +1051,10 @@ class StringValueBoundaryScanner extends 
SingleValueBoundaryScanner {
     }
     return true;
   }
+
+  @Override
+  public boolean isDistanceGreater(Object v1, Object v2, int amt) {
+    throw new UnsupportedOperationException("Only unbounded ranges supported");
+  }
 }
 
diff --git 
a/ql/src/test/org/apache/hadoop/hive/ql/udf/ptf/TestBoundaryCache.java 
b/ql/src/test/org/apache/hadoop/hive/ql/udf/ptf/TestBoundaryCache.java
new file mode 100644
index 0000000..714c51b
--- /dev/null
+++ b/ql/src/test/org/apache/hadoop/hive/ql/udf/ptf/TestBoundaryCache.java
@@ -0,0 +1,295 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.udf.ptf;
+
+import java.util.Collections;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import org.apache.commons.lang3.tuple.ImmutablePair;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.hadoop.hive.ql.exec.BoundaryCache;
+import org.apache.hadoop.hive.ql.exec.PTFPartition;
+import org.apache.hadoop.hive.ql.parse.PTFInvocationSpec;
+import org.apache.hadoop.hive.ql.parse.WindowingSpec;
+import org.apache.hadoop.hive.ql.plan.ptf.BoundaryDef;
+import org.apache.hadoop.hive.ql.plan.ptf.OrderExpressionDef;
+import org.apache.hadoop.io.IntWritable;
+
+import com.google.common.collect.Lists;
+
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import static java.util.Optional.ofNullable;
+import static java.util.stream.Collectors.toCollection;
+import static java.util.stream.Collectors.toList;
+import static org.apache.hadoop.hive.ql.parse.PTFInvocationSpec.Order.ASC;
+import static org.apache.hadoop.hive.ql.parse.PTFInvocationSpec.Order.DESC;
+import static 
org.apache.hadoop.hive.ql.parse.WindowingSpec.BoundarySpec.UNBOUNDED_AMOUNT;
+import static org.apache.hadoop.hive.ql.parse.WindowingSpec.Direction.CURRENT;
+import static 
org.apache.hadoop.hive.ql.parse.WindowingSpec.Direction.FOLLOWING;
+import static 
org.apache.hadoop.hive.ql.parse.WindowingSpec.Direction.PRECEDING;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.when;
+
+/**
+ * Tests BoundaryCache used for RANGE windows in PTF functions.
+ */
+public class TestBoundaryCache {
+
+  private static final Logger LOG = 
LoggerFactory.getLogger(TestBoundaryCache.class);
+  private static final LinkedList<List<IntWritable>> TEST_PARTITION = new 
LinkedList<>();
+  //Null for using no cache at all, 2 is the minimum cache size, 5-9-15 for 
checking with smaller,
+  // exactly equal and larger cache than needed.
+  private static final List<Integer> CACHE_SIZES = Lists.newArrayList(null, 2, 
5, 9, 15);
+  private static final List<PTFInvocationSpec.Order> ORDERS = 
Lists.newArrayList(ASC, DESC);
+  private static final int ORDER_BY_COL = 2;
+
+  @BeforeClass
+  public static void setupTests() throws Exception {
+    //8 ranges, max cache content is 8+1=9 entries
+    addRow(TEST_PARTITION, 1, 1, -7);
+    addRow(TEST_PARTITION, 2, 1, -1);
+    addRow(TEST_PARTITION, 3, 1, -1);
+    addRow(TEST_PARTITION, 4, 1, 1);
+    addRow(TEST_PARTITION, 5, 1, 1);
+    addRow(TEST_PARTITION, 6, 1, 1);
+    addRow(TEST_PARTITION, 7, 1, 1);
+    addRow(TEST_PARTITION, 8, 1, 2);
+    addRow(TEST_PARTITION, 9, 1, 2);
+    addRow(TEST_PARTITION, 10, 1, 2);
+    addRow(TEST_PARTITION, 11, 1, 2);
+    addRow(TEST_PARTITION, 12, 1, 3);
+    addRow(TEST_PARTITION, 13, 1, 5);
+    addRow(TEST_PARTITION, 14, 1, 5);
+    addRow(TEST_PARTITION, 15, 1, 5);
+    addRow(TEST_PARTITION, 16, 1, 5);
+    addRow(TEST_PARTITION, 17, 1, 6);
+    addRow(TEST_PARTITION, 18, 1, 6);
+    addRow(TEST_PARTITION, 19, 1, 9);
+    addRow(TEST_PARTITION, 20, 1, null);
+    addRow(TEST_PARTITION, 21, 1, null);
+
+  }
+
+  @Test
+  public void testPrecedingUnboundedFollowingUnbounded() throws Exception {
+    runTest(PRECEDING, UNBOUNDED_AMOUNT, FOLLOWING, UNBOUNDED_AMOUNT);
+  }
+
+  @Test
+  public void testPrecedingUnboundedCurrentRow() throws Exception {
+    runTest(PRECEDING, UNBOUNDED_AMOUNT, CURRENT, 0);
+  }
+
+  @Test
+  public void testPrecedingUnboundedPreceding2() throws Exception {
+    runTest(PRECEDING, UNBOUNDED_AMOUNT, PRECEDING, 2);
+  }
+
+  @Test
+  public void testPreceding4Preceding1() throws Exception {
+    runTest(PRECEDING, 4, PRECEDING, 1);
+  }
+
+  @Test
+  public void testPreceding2CurrentRow() throws Exception {
+    runTest(PRECEDING, 2, CURRENT, 0);
+  }
+
+  @Test
+  public void testPreceding2Following100() throws Exception {
+    runTest(PRECEDING, 1, FOLLOWING, 100);
+  }
+
+  @Test
+  public void testCurrentRowFollowing3() throws Exception {
+    runTest(CURRENT, 0, FOLLOWING, 3);
+  }
+
+  @Test
+  public void testCurrentRowFFollowingUnbounded() throws Exception {
+    runTest(CURRENT, 0, FOLLOWING, UNBOUNDED_AMOUNT);
+  }
+
+  @Test
+  public void testFollowing2Following4() throws Exception {
+    runTest(FOLLOWING, 2, FOLLOWING, 4);
+  }
+
+  @Test
+  public void testFollowing2FollowingUnbounded() throws Exception {
+    runTest(FOLLOWING, 2, FOLLOWING, UNBOUNDED_AMOUNT);
+  }
+
+  /**
+   * Executes test on a given window definition. Such a test will be executed 
against the values set
+   * in ORDERS and CACHE_SIZES, validating ORDERS X CACHE_SIZES test cases. 
Cache size of null will
+   * be used to set up the baseline.
+   * @param startDirection
+   * @param startAmount
+   * @param endDirection
+   * @param endAmount
+   * @throws Exception
+   */
+  private void runTest(WindowingSpec.Direction startDirection, int startAmount,
+                       WindowingSpec.Direction endDirection, int endAmount) 
throws Exception {
+
+    BoundaryDef startBoundary = new BoundaryDef(startDirection, startAmount);
+    BoundaryDef endBoundary = new BoundaryDef(endDirection, endAmount);
+    AtomicInteger readCounter = new AtomicInteger(0);
+
+    int[] expectedBoundaryStarts = new int[TEST_PARTITION.size()];
+    int[] expectedBoundaryEnds = new int[TEST_PARTITION.size()];
+    int expectedReadCountWithoutCache = -1;
+
+    for (PTFInvocationSpec.Order order : ORDERS) {
+      for (Integer cacheSize : CACHE_SIZES) {
+        LOG.info(Thread.currentThread().getStackTrace()[2].getMethodName());
+        LOG.info("Cache: " + cacheSize + " order: " + order);
+        BoundaryCache cache = cacheSize == null ? null : new 
BoundaryCache(cacheSize);
+        Pair<PTFPartition, ValueBoundaryScanner> mocks = 
setupMocks(TEST_PARTITION,
+                ORDER_BY_COL, startBoundary, endBoundary, order, cache, 
readCounter);
+        PTFPartition ptfPartition = mocks.getLeft();
+        ValueBoundaryScanner scanner = mocks.getRight();
+        for (int i = 0; i < TEST_PARTITION.size(); ++i) {
+          scanner.handleCache(i, ptfPartition);
+          int start = scanner.computeStart(i, ptfPartition);
+          int end = scanner.computeEnd(i, ptfPartition) - 1;
+          if (cache == null) {
+            //Cache-less version should be baseline
+            expectedBoundaryStarts[i] = start;
+            expectedBoundaryEnds[i] = end;
+          } else {
+            assertEquals(expectedBoundaryStarts[i], start);
+            assertEquals(expectedBoundaryEnds[i], end);
+          }
+          Integer col0 = ofNullable(TEST_PARTITION.get(i).get(0)).map(v -> 
v.get()).orElse(null);
+          Integer col1 = ofNullable(TEST_PARTITION.get(i).get(1)).map(v -> 
v.get()).orElse(null);
+          Integer col2 = ofNullable(TEST_PARTITION.get(i).get(2)).map(v -> 
v.get()).orElse(null);
+          LOG.info(String.format("%d|\t%d\t%d\t%d\t|%d-%d", i, col0, col1, 
col2, start, end));
+        }
+        if (cache == null) {
+          expectedReadCountWithoutCache = readCounter.get();
+        } else {
+          //Read count should be no larger with the cache in use, but no smaller 
than the minimum of
+          // reading every row once.
+          assertTrue(expectedReadCountWithoutCache >= readCounter.get());
+          if (startAmount != UNBOUNDED_AMOUNT || endAmount != 
UNBOUNDED_AMOUNT) {
+            assertTrue(TEST_PARTITION.size() <= readCounter.get());
+          }
+        }
+        readCounter.set(0);
+      }
+    }
+  }
+
+  /**
+   * Sets up mock and spy objects used for testing.
+   * @param partition The real partition containing row values.
+   * @param orderByCol Index of column in the row used for separating ranges.
+   * @param start Window definition.
+   * @param end Window definition.
+   * @param order Window definition.
+   * @param cache BoundaryCache instance, it may come in various sizes.
+   * @param readCounter counts how many times reading was invoked
+   * @return Mocked PTFPartition instance and ValueBoundaryScanner spy.
+   * @throws Exception
+   */
+  private static Pair<PTFPartition, ValueBoundaryScanner> setupMocks(
+          List<List<IntWritable>> partition, int orderByCol, BoundaryDef 
start, BoundaryDef end,
+          PTFInvocationSpec.Order order, BoundaryCache cache,
+          AtomicInteger readCounter) throws Exception {
+    PTFPartition partitionMock = mock(PTFPartition.class);
+    doAnswer(invocationOnMock -> {
+      int idx = invocationOnMock.getArgumentAt(0, Integer.class);
+      return partition.get(idx);
+    }).when(partitionMock).getAt(any(Integer.class));
+    doAnswer(invocationOnMock -> {
+      return partition.size();
+    }).when(partitionMock).size();
+    when(partitionMock.getBoundaryCache()).thenReturn(cache);
+
+    OrderExpressionDef orderDef = mock(OrderExpressionDef.class);
+    when(orderDef.getOrder()).thenReturn(order);
+
+    ValueBoundaryScanner scan = new LongValueBoundaryScanner(start, end, 
orderDef, order == ASC);
+    ValueBoundaryScanner scannerSpy = spy(scan);
+    doAnswer(invocationOnMock -> {
+      readCounter.incrementAndGet();
+      List<IntWritable> row = invocationOnMock.getArgumentAt(0, List.class);
+      return row.get(orderByCol);
+    }).when(scannerSpy).computeValue(any(Object.class));
+    doAnswer(invocationOnMock -> {
+      IntWritable v1 = invocationOnMock.getArgumentAt(0, IntWritable.class);
+      IntWritable v2 = invocationOnMock.getArgumentAt(1, IntWritable.class);
+      return (v1 != null && v2 != null) ? v1.get() == v2.get() : v1 == null && 
v2 == null;
+    }).when(scannerSpy).isEqual(any(Object.class), any(Object.class));
+    doAnswer(invocationOnMock -> {
+      IntWritable v1 = invocationOnMock.getArgumentAt(0, IntWritable.class);
+      IntWritable v2 = invocationOnMock.getArgumentAt(1, IntWritable.class);
+      Integer amt = invocationOnMock.getArgumentAt(2, Integer.class);
+      return (v1 != null && v2 != null) ? (v1.get() - v2.get()) > amt :  v1 != 
null || v2 != null;
+    }).when(scannerSpy).isDistanceGreater(any(Object.class), 
any(Object.class), any(Integer.class));
+
+    setOrderOnTestPartitions(order);
+    return new ImmutablePair<>(partitionMock, scannerSpy);
+
+  }
+
+  private static void addRow(List<List<IntWritable>> partition, Integer col0, 
Integer col1,
+                             Integer col2) {
+    partition.add(Lists.newArrayList(
+            col0 != null ? new IntWritable(col0) : null,
+            col1 != null ? new IntWritable(col1) : null,
+            col2 != null ? new IntWritable(col2) : null
+    ));
+  }
+
+  /**
+   * Reverses order on actual data if needed, based on order parameter.
+   * @param order
+   */
+  private static void setOrderOnTestPartitions(PTFInvocationSpec.Order order) {
+    LinkedList<List<IntWritable>> notNulls = TEST_PARTITION.stream().filter(
+        r -> r.get(ORDER_BY_COL) != 
null).collect(toCollection(LinkedList::new));
+    List<List<IntWritable>> nulls = TEST_PARTITION.stream().filter(
+        r -> r.get(ORDER_BY_COL) == null).collect(toList());
+
+    boolean isAscCurrently = notNulls.getFirst().get(ORDER_BY_COL).get() <
+            notNulls.getLast().get(ORDER_BY_COL).get();
+
+    if ((ASC.equals(order) && !isAscCurrently) || (DESC.equals(order) && 
isAscCurrently)) {
+      Collections.reverse(notNulls);
+      TEST_PARTITION.clear();
+      TEST_PARTITION.addAll(notNulls);
+      TEST_PARTITION.addAll(nulls);
+    }
+  }
+
+}

Reply via email to