This is an automated email from the ASF dual-hosted git repository.

ulyssesyou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 7e217cfde [GLUTEN-5625][VL] Support window range frame (#5626)
7e217cfde is described below

commit 7e217cfdeae96337bcf59a082e66d5297f534f41
Author: WangGuangxin <[email protected]>
AuthorDate: Wed Jun 12 10:09:23 2024 +0800

    [GLUTEN-5625][VL] Support window range frame (#5626)
---
 .../clickhouse/CHSparkPlanExecApi.scala            | 22 +++---
 .../gluten/backendsapi/velox/VeloxBackend.scala    | 20 ++----
 .../org/apache/gluten/execution/TestOperator.scala | 41 ++++++++++-
 cpp/velox/substrait/SubstraitToVeloxPlan.cc        | 30 ++++++--
 cpp/velox/substrait/SubstraitToVeloxPlan.h         |  6 ++
 docs/developers/SubstraitModifications.md          |  1 +
 .../substrait/expression/ExpressionBuilder.java    | 22 +++---
 .../substrait/expression/WindowFunctionNode.java   | 62 ++++++++++++++---
 .../substrait/proto/substrait/algebra.proto        | 26 ++++---
 .../gluten/backendsapi/BackendSettingsApi.scala    |  2 +
 .../gluten/backendsapi/SparkPlanExecApi.scala      | 35 ++++++----
 .../gluten/execution/WindowExecTransformer.scala   | 13 ----
 .../columnar/rewrite/PullOutPreProject.scala       | 15 +++-
 .../apache/gluten/utils/PullOutProjectHelper.scala | 63 ++++++++++++++++-
 .../expressions/PreComputeRangeFrameBound.scala    | 80 ++++++++++++++++++++++
 15 files changed, 348 insertions(+), 90 deletions(-)

diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index a8a05c40f..d7faa07a5 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -66,6 +66,7 @@ import org.apache.commons.lang3.ClassUtils
 import java.lang.{Long => JLong}
 import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
 
+import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
 
 class CHSparkPlanExecApi extends SparkPlanExecApi {
@@ -727,9 +728,10 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
               new JArrayList[ExpressionNode](),
               columnName,
               ConverterUtils.getTypeNode(aggWindowFunc.dataType, 
aggWindowFunc.nullable),
-              WindowExecTransformer.getFrameBound(frame.upper),
-              WindowExecTransformer.getFrameBound(frame.lower),
-              frame.frameType.sql
+              frame.upper,
+              frame.lower,
+              frame.frameType.sql,
+              originalInputAttributes.asJava
             )
             windowExpressionNodes.add(windowFunctionNode)
           case aggExpression: AggregateExpression =>
@@ -753,9 +755,10 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
               childrenNodeList,
               columnName,
               ConverterUtils.getTypeNode(aggExpression.dataType, 
aggExpression.nullable),
-              WindowExecTransformer.getFrameBound(frame.upper),
-              WindowExecTransformer.getFrameBound(frame.lower),
-              frame.frameType.sql
+              frame.upper,
+              frame.lower,
+              frame.frameType.sql,
+              originalInputAttributes.asJava
             )
             windowExpressionNodes.add(windowFunctionNode)
           case wf @ (Lead(_, _, _, _) | Lag(_, _, _, _)) =>
@@ -802,9 +805,10 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
               childrenNodeList,
               columnName,
               ConverterUtils.getTypeNode(offsetWf.dataType, offsetWf.nullable),
-              WindowExecTransformer.getFrameBound(frame.upper),
-              WindowExecTransformer.getFrameBound(frame.lower),
-              frame.frameType.sql
+              frame.upper,
+              frame.lower,
+              frame.frameType.sql,
+              originalInputAttributes.asJava
             )
             windowExpressionNodes.add(windowFunctionNode)
           case _ =>
diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
index f06929fff..21e6246d1 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
@@ -296,15 +296,9 @@ object VeloxBackendSettings extends BackendSettingsApi {
             case _ => throw new GlutenNotSupportException(s"$func is not 
supported.")
           }
 
-          // Block the offloading by checking Velox's current limitations
-          // when literal bound type is used for RangeFrame.
           def checkLimitations(swf: SpecifiedWindowFrame, orderSpec: 
Seq[SortOrder]): Unit = {
-            def doCheck(bound: Expression, isUpperBound: Boolean): Unit = {
+            def doCheck(bound: Expression): Unit = {
               bound match {
-                case e if e.foldable =>
-                  throw new GlutenNotSupportException(
-                    "Window frame of type RANGE does" +
-                      " not support constant arguments in velox backend")
                 case _: SpecialFrameBoundary =>
                 case e if e.foldable =>
                   orderSpec.foreach(
@@ -325,17 +319,11 @@ object VeloxBackendSettings extends BackendSettingsApi {
                             "Only integral type & date type are" +
                               " supported for sort key when literal bound type 
is used!")
                       })
-                  val rawValue = e.eval().toString.toLong
-                  if (isUpperBound && rawValue < 0) {
-                    throw new GlutenNotSupportException("Negative upper bound 
is not supported!")
-                  } else if (!isUpperBound && rawValue > 0) {
-                    throw new GlutenNotSupportException("Positive lower bound 
is not supported!")
-                  }
                 case _ =>
               }
             }
-            doCheck(swf.upper, true)
-            doCheck(swf.lower, false)
+            doCheck(swf.upper)
+            doCheck(swf.lower)
           }
 
           windowExpression.windowSpec.frameSpecification match {
@@ -495,4 +483,6 @@ object VeloxBackendSettings extends BackendSettingsApi {
   override def supportColumnarArrowUdf(): Boolean = true
 
   override def generateHdfsConfForLibhdfs(): Boolean = true
+
+  override def needPreComputeRangeFrameBoundary(): Boolean = true
 }
diff --git 
a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala 
b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala
index ae8d64a09..3cf485aac 100644
--- 
a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala
+++ 
b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala
@@ -212,17 +212,56 @@ class TestOperator extends 
VeloxWholeStageTransformerSuite with AdaptiveSparkPla
     Seq("sort", "streaming").foreach {
       windowType =>
         withSQLConf("spark.gluten.sql.columnar.backend.velox.window.type" -> 
windowType) {
+          runQueryAndCompare(
+            "select max(l_partkey) over" +
+              " (partition by l_suppkey order by l_orderkey" +
+              " RANGE BETWEEN 1 PRECEDING AND CURRENT ROW), " +
+              "min(l_comment) over" +
+              " (partition by l_suppkey order by l_linenumber" +
+              " RANGE BETWEEN 1 PRECEDING AND CURRENT ROW) from lineitem ") {
+            checkSparkOperatorMatch[WindowExecTransformer]
+          }
+
           runQueryAndCompare(
             "select max(l_partkey) over" +
               " (partition by l_suppkey order by l_orderkey" +
               " RANGE BETWEEN CURRENT ROW AND 2 FOLLOWING) from lineitem ") {
-            checkSparkOperatorMatch[WindowExec]
+            checkSparkOperatorMatch[WindowExecTransformer]
           }
 
           runQueryAndCompare(
             "select max(l_partkey) over" +
               " (partition by l_suppkey order by l_orderkey" +
               " RANGE BETWEEN 6 PRECEDING AND CURRENT ROW) from lineitem ") {
+            checkSparkOperatorMatch[WindowExecTransformer]
+          }
+
+          runQueryAndCompare(
+            "select max(l_partkey) over" +
+              " (partition by l_suppkey order by l_orderkey" +
+              " RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING) from lineitem ") {
+            checkSparkOperatorMatch[WindowExecTransformer]
+          }
+
+          runQueryAndCompare(
+            "select max(l_partkey) over" +
+              " (partition by l_suppkey order by l_orderkey" +
+              " RANGE BETWEEN 6 PRECEDING AND 3 PRECEDING) from lineitem ") {
+            checkSparkOperatorMatch[WindowExecTransformer]
+          }
+
+          runQueryAndCompare(
+            "select max(l_partkey) over" +
+              " (partition by l_suppkey order by l_orderkey" +
+              " RANGE BETWEEN 3 FOLLOWING AND 6 FOLLOWING) from lineitem ") {
+            checkSparkOperatorMatch[WindowExecTransformer]
+          }
+
+          // DecimalType as order by column is not supported
+          runQueryAndCompare(
+            "select min(l_comment) over" +
+              " (partition by l_suppkey order by l_discount" +
+              " RANGE BETWEEN 1 PRECEDING AND CURRENT ROW) from lineitem ") {
             checkSparkOperatorMatch[WindowExec]
           }
 
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.cc 
b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
index b82eead2c..4e875d479 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlan.cc
+++ b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
@@ -823,10 +823,11 @@ core::PlanNodePtr 
SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
       nextPlanNodeId(), replicated, unnest, std::move(unnestNames), 
ordinalityName, childNode);
 }
 
-const core::WindowNode::Frame createWindowFrame(
+const core::WindowNode::Frame SubstraitToVeloxPlanConverter::createWindowFrame(
     const ::substrait::Expression_WindowFunction_Bound& lower_bound,
     const ::substrait::Expression_WindowFunction_Bound& upper_bound,
-    const ::substrait::WindowType& type) {
+    const ::substrait::WindowType& type,
+    const RowTypePtr& inputType) {
   core::WindowNode::Frame frame;
   switch (type) {
     case ::substrait::WindowType::ROWS:
@@ -839,9 +840,22 @@ const core::WindowNode::Frame createWindowFrame(
       VELOX_FAIL("the window type only support ROWS and RANGE, and the input 
type is ", std::to_string(type));
   }
 
-  auto boundTypeConversion = [](::substrait::Expression_WindowFunction_Bound 
boundType)
+  auto specifiedBound =
+      [&](bool hasOffset, int64_t offset, const ::substrait::Expression& 
columnRef) -> core::TypedExprPtr {
+    if (hasOffset) {
+      VELOX_CHECK(
+          frame.type != core::WindowNode::WindowType::kRange,
+          "for RANGE frame offset, we should pre-calculate the range frame 
boundary and pass the column reference, but got a constant offset.")
+      return std::make_shared<core::ConstantTypedExpr>(BIGINT(), 
variant(offset));
+    } else {
+      VELOX_CHECK(
+          frame.type != core::WindowNode::WindowType::kRows, "for ROW frame 
offset, we should pass a constant offset.")
+      return exprConverter_->toVeloxExpr(columnRef, inputType);
+    }
+  };
+
+  auto boundTypeConversion = [&](::substrait::Expression_WindowFunction_Bound 
boundType)
       -> std::tuple<core::WindowNode::BoundType, core::TypedExprPtr> {
-    // TODO: support non-literal expression.
     if (boundType.has_current_row()) {
       return std::make_tuple(core::WindowNode::BoundType::kCurrentRow, 
nullptr);
     } else if (boundType.has_unbounded_following()) {
@@ -849,13 +863,15 @@ const core::WindowNode::Frame createWindowFrame(
     } else if (boundType.has_unbounded_preceding()) {
       return std::make_tuple(core::WindowNode::BoundType::kUnboundedPreceding, 
nullptr);
     } else if (boundType.has_following()) {
+      auto following = boundType.following();
       return std::make_tuple(
           core::WindowNode::BoundType::kFollowing,
-          std::make_shared<core::ConstantTypedExpr>(BIGINT(), 
variant(boundType.following().offset())));
+          specifiedBound(following.has_offset(), following.offset(), 
following.ref()));
     } else if (boundType.has_preceding()) {
+      auto preceding = boundType.preceding();
       return std::make_tuple(
           core::WindowNode::BoundType::kPreceding,
-          std::make_shared<core::ConstantTypedExpr>(BIGINT(), 
variant(boundType.preceding().offset())));
+          specifiedBound(preceding.has_offset(), preceding.offset(), 
preceding.ref()));
     } else {
       VELOX_FAIL("The BoundType is not supported.");
     }
@@ -906,7 +922,7 @@ core::PlanNodePtr 
SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
     windowColumnNames.push_back(windowFunction.column_name());
 
     windowNodeFunctions.push_back(
-        {std::move(windowCall), std::move(createWindowFrame(lowerBound, 
upperBound, type)), ignoreNulls});
+        {std::move(windowCall), std::move(createWindowFrame(lowerBound, 
upperBound, type, inputType)), ignoreNulls});
   }
 
   // Construct partitionKeys
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.h 
b/cpp/velox/substrait/SubstraitToVeloxPlan.h
index 567ebb215..3a0e677af 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlan.h
+++ b/cpp/velox/substrait/SubstraitToVeloxPlan.h
@@ -555,6 +555,12 @@ class SubstraitToVeloxPlanConverter {
     return toVeloxPlan(rel.input());
   }
 
+  const core::WindowNode::Frame createWindowFrame(
+      const ::substrait::Expression_WindowFunction_Bound& lower_bound,
+      const ::substrait::Expression_WindowFunction_Bound& upper_bound,
+      const ::substrait::WindowType& type,
+      const RowTypePtr& inputType);
+
   /// The unique identification for each PlanNode.
   int planNodeId_ = 0;
 
diff --git a/docs/developers/SubstraitModifications.md 
b/docs/developers/SubstraitModifications.md
index 38406425a..24a9c1a21 100644
--- a/docs/developers/SubstraitModifications.md
+++ b/docs/developers/SubstraitModifications.md
@@ -27,6 +27,7 @@ changed `Unbounded` in `WindowFunction` into 
`Unbounded_Preceding` and `Unbounde
 * Added `PartitionColumn` in 
`LocalFiles`([#2405](https://github.com/apache/incubator-gluten/pull/2405)).
 * Added `WriteRel` 
([#3690](https://github.com/apache/incubator-gluten/pull/3690)).
 * Added `TopNRel` 
([#5409](https://github.com/apache/incubator-gluten/pull/5409)).
+* Added `ref` field in window bound `Preceding` and `Following` 
([#5626](https://github.com/apache/incubator-gluten/pull/5626)).
 
 ## Modifications to type.proto
 
diff --git 
a/gluten-core/src/main/java/org/apache/gluten/substrait/expression/ExpressionBuilder.java
 
b/gluten-core/src/main/java/org/apache/gluten/substrait/expression/ExpressionBuilder.java
index 5d106938c..e322e1528 100644
--- 
a/gluten-core/src/main/java/org/apache/gluten/substrait/expression/ExpressionBuilder.java
+++ 
b/gluten-core/src/main/java/org/apache/gluten/substrait/expression/ExpressionBuilder.java
@@ -21,6 +21,8 @@ import org.apache.gluten.expression.ConverterUtils;
 import org.apache.gluten.substrait.type.*;
 
 import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.expressions.Attribute;
+import org.apache.spark.sql.catalyst.expressions.Expression;
 import org.apache.spark.sql.catalyst.util.ArrayData;
 import org.apache.spark.sql.catalyst.util.MapData;
 import org.apache.spark.sql.types.*;
@@ -264,9 +266,10 @@ public class ExpressionBuilder {
       List<ExpressionNode> expressionNodes,
       String columnName,
       TypeNode outputTypeNode,
-      String upperBound,
-      String lowerBound,
-      String frameType) {
+      Expression upperBound,
+      Expression lowerBound,
+      String frameType,
+      List<Attribute> originalInputAttributes) {
     return makeWindowFunction(
         functionId,
         expressionNodes,
@@ -275,7 +278,8 @@ public class ExpressionBuilder {
         upperBound,
         lowerBound,
         frameType,
-        false);
+        false,
+        originalInputAttributes);
   }
 
   public static WindowFunctionNode makeWindowFunction(
@@ -283,10 +287,11 @@ public class ExpressionBuilder {
       List<ExpressionNode> expressionNodes,
       String columnName,
       TypeNode outputTypeNode,
-      String upperBound,
-      String lowerBound,
+      Expression upperBound,
+      Expression lowerBound,
       String frameType,
-      boolean ignoreNulls) {
+      boolean ignoreNulls,
+      List<Attribute> originalInputAttributes) {
     return new WindowFunctionNode(
         functionId,
         expressionNodes,
@@ -295,6 +300,7 @@ public class ExpressionBuilder {
         upperBound,
         lowerBound,
         frameType,
-        ignoreNulls);
+        ignoreNulls,
+        originalInputAttributes);
   }
 }
diff --git 
a/gluten-core/src/main/java/org/apache/gluten/substrait/expression/WindowFunctionNode.java
 
b/gluten-core/src/main/java/org/apache/gluten/substrait/expression/WindowFunctionNode.java
index 67d0d6e57..b9f1fbc12 100644
--- 
a/gluten-core/src/main/java/org/apache/gluten/substrait/expression/WindowFunctionNode.java
+++ 
b/gluten-core/src/main/java/org/apache/gluten/substrait/expression/WindowFunctionNode.java
@@ -16,17 +16,24 @@
  */
 package org.apache.gluten.substrait.expression;
 
+import org.apache.gluten.exception.GlutenException;
+import org.apache.gluten.expression.ExpressionConverter;
 import org.apache.gluten.substrait.type.TypeNode;
 
 import io.substrait.proto.Expression;
 import io.substrait.proto.FunctionArgument;
 import io.substrait.proto.FunctionOption;
 import io.substrait.proto.WindowType;
+import org.apache.spark.sql.catalyst.expressions.Attribute;
+import org.apache.spark.sql.catalyst.expressions.PreComputeRangeFrameBound;
 
 import java.io.Serializable;
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
 
+import scala.collection.JavaConverters;
+
 public class WindowFunctionNode implements Serializable {
   private final Integer functionId;
   private final List<ExpressionNode> expressionNodes = new ArrayList<>();
@@ -34,23 +41,26 @@ public class WindowFunctionNode implements Serializable {
   private final String columnName;
   private final TypeNode outputTypeNode;
 
-  private final String upperBound;
+  private final org.apache.spark.sql.catalyst.expressions.Expression 
upperBound;
 
-  private final String lowerBound;
+  private final org.apache.spark.sql.catalyst.expressions.Expression 
lowerBound;
 
   private final String frameType;
 
   private final boolean ignoreNulls;
 
+  private final List<Attribute> originalInputAttributes;
+
   WindowFunctionNode(
       Integer functionId,
       List<ExpressionNode> expressionNodes,
       String columnName,
       TypeNode outputTypeNode,
-      String upperBound,
-      String lowerBound,
+      org.apache.spark.sql.catalyst.expressions.Expression upperBound,
+      org.apache.spark.sql.catalyst.expressions.Expression lowerBound,
       String frameType,
-      boolean ignoreNulls) {
+      boolean ignoreNulls,
+      List<Attribute> originalInputAttributes) {
     this.functionId = functionId;
     this.expressionNodes.addAll(expressionNodes);
     this.columnName = columnName;
@@ -59,11 +69,13 @@ public class WindowFunctionNode implements Serializable {
     this.lowerBound = lowerBound;
     this.frameType = frameType;
     this.ignoreNulls = ignoreNulls;
+    this.originalInputAttributes = originalInputAttributes;
   }
 
   private Expression.WindowFunction.Bound.Builder setBound(
-      Expression.WindowFunction.Bound.Builder builder, String boundType) {
-    switch (boundType) {
+      Expression.WindowFunction.Bound.Builder builder,
+      org.apache.spark.sql.catalyst.expressions.Expression boundType) {
+    switch (boundType.sql()) {
       case ("CURRENT ROW"):
         Expression.WindowFunction.Bound.CurrentRow.Builder currentRowBuilder =
             Expression.WindowFunction.Bound.CurrentRow.newBuilder();
@@ -80,8 +92,36 @@ public class WindowFunctionNode implements Serializable {
         builder.setUnboundedFollowing(followingBuilder.build());
         break;
       default:
-        try {
-          Long offset = Long.valueOf(boundType);
+        if (boundType instanceof PreComputeRangeFrameBound) {
+          // Used only when backend is velox and frame type is RANGE.
+          if (!frameType.equals("RANGE")) {
+            throw new GlutenException(
+                "Only Range frame supports PreComputeRangeFrameBound, but got 
" + frameType);
+          }
+          ExpressionNode refNode =
+              ExpressionConverter.replaceWithExpressionTransformer(
+                      ((PreComputeRangeFrameBound) 
boundType).child().toAttribute(),
+                      
JavaConverters.asScalaIteratorConverter(originalInputAttributes.iterator())
+                          .asScala()
+                          .toSeq())
+                  .doTransform(new HashMap<String, Long>());
+          Long offset = Long.valueOf(boundType.eval(null).toString());
+          if (offset < 0) {
+            Expression.WindowFunction.Bound.Preceding.Builder 
refPrecedingBuilder =
+                Expression.WindowFunction.Bound.Preceding.newBuilder();
+            refPrecedingBuilder.setRef(refNode.toProtobuf());
+            builder.setPreceding(refPrecedingBuilder.build());
+          } else {
+            Expression.WindowFunction.Bound.Following.Builder 
refFollowingBuilder =
+                Expression.WindowFunction.Bound.Following.newBuilder();
+            refFollowingBuilder.setRef(refNode.toProtobuf());
+            builder.setFollowing(refFollowingBuilder.build());
+          }
+        } else if (boundType.foldable()) {
+          // Used when
+          // 1. Velox backend and frame type is ROW
+          // 2. Clickhouse backend
+          Long offset = Long.valueOf(boundType.eval(null).toString());
           if (offset < 0) {
             Expression.WindowFunction.Bound.Preceding.Builder 
offsetPrecedingBuilder =
                 Expression.WindowFunction.Bound.Preceding.newBuilder();
@@ -93,9 +133,9 @@ public class WindowFunctionNode implements Serializable {
             offsetFollowingBuilder.setOffset(offset);
             builder.setFollowing(offsetFollowingBuilder.build());
           }
-        } catch (NumberFormatException e) {
+        } else {
           throw new UnsupportedOperationException(
-              "Unsupported Window Function Frame Type:" + boundType);
+              "Unsupported Window Function Frame Bound Type: " + boundType);
         }
     }
     return builder;
diff --git 
a/gluten-core/src/main/resources/substrait/proto/substrait/algebra.proto 
b/gluten-core/src/main/resources/substrait/proto/substrait/algebra.proto
index 877493439..0e51baf5a 100644
--- a/gluten-core/src/main/resources/substrait/proto/substrait/algebra.proto
+++ b/gluten-core/src/main/resources/substrait/proto/substrait/algebra.proto
@@ -996,18 +996,28 @@ message Expression {
     message Bound {
       // Defines that the bound extends this far back from the current record.
       message Preceding {
-        // A strictly positive integer specifying the number of records that
-        // the window extends back from the current record. Required. Use
-        // CurrentRow for offset zero and Following for negative offsets.
-        int64 offset = 1;
+        oneof kind {
+          // A strictly positive integer specifying the number of records that
+          // the window extends back from the current record. Use
+          // CurrentRow for offset zero and Following for negative offsets.
+          int64 offset = 1;
+
+          // the reference to pre-project range frame boundary.
+          Expression ref = 2;
+        }
       }
 
       // Defines that the bound extends this far ahead of the current record.
       message Following {
-        // A strictly positive integer specifying the number of records that
-        // the window extends ahead of the current record. Required. Use
-        // CurrentRow for offset zero and Preceding for negative offsets.
-        int64 offset = 1;
+        oneof kind {
+          // A strictly positive integer specifying the number of records that
+          // the window extends ahead of the current record. Use
+          // CurrentRow for offset zero and Preceding for negative offsets.
+          int64 offset = 1;
+
+          // the reference to pre-project range frame boundary.
+          Expression ref = 2;
+        }
       }
 
       // Defines that the bound extends to or from the current record.
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
index d18273af2..b7a3bc1b6 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
@@ -152,4 +152,6 @@ trait BackendSettingsApi {
   def supportColumnarArrowUdf(): Boolean = false
 
   def generateHdfsConfForLibhdfs(): Boolean = false
+
+  def needPreComputeRangeFrameBoundary(): Boolean = false
 }
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
index 8a1baae51..8bc8e136b 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
@@ -529,9 +529,10 @@ trait SparkPlanExecApi {
               new JArrayList[ExpressionNode](),
               columnName,
               ConverterUtils.getTypeNode(aggWindowFunc.dataType, 
aggWindowFunc.nullable),
-              WindowExecTransformer.getFrameBound(frame.upper),
-              WindowExecTransformer.getFrameBound(frame.lower),
-              frame.frameType.sql
+              frame.upper,
+              frame.lower,
+              frame.frameType.sql,
+              originalInputAttributes.asJava
             )
             windowExpressionNodes.add(windowFunctionNode)
           case aggExpression: AggregateExpression =>
@@ -554,9 +555,10 @@ trait SparkPlanExecApi {
               childrenNodeList,
               columnName,
               ConverterUtils.getTypeNode(aggExpression.dataType, 
aggExpression.nullable),
-              WindowExecTransformer.getFrameBound(frame.upper),
-              WindowExecTransformer.getFrameBound(frame.lower),
-              frame.frameType.sql
+              frame.upper,
+              frame.lower,
+              frame.frameType.sql,
+              originalInputAttributes.asJava
             )
             windowExpressionNodes.add(windowFunctionNode)
           case wf @ (_: Lead | _: Lag) =>
@@ -590,10 +592,11 @@ trait SparkPlanExecApi {
               childrenNodeList,
               columnName,
               ConverterUtils.getTypeNode(offsetWf.dataType, offsetWf.nullable),
-              WindowExecTransformer.getFrameBound(frame.upper),
-              WindowExecTransformer.getFrameBound(frame.lower),
+              frame.upper,
+              frame.lower,
               frame.frameType.sql,
-              offsetWf.ignoreNulls
+              offsetWf.ignoreNulls,
+              originalInputAttributes.asJava
             )
             windowExpressionNodes.add(windowFunctionNode)
           case wf @ NthValue(input, offset: Literal, ignoreNulls: Boolean) =>
@@ -609,10 +612,11 @@ trait SparkPlanExecApi {
               childrenNodeList,
               columnName,
               ConverterUtils.getTypeNode(wf.dataType, wf.nullable),
-              frame.upper.sql,
-              frame.lower.sql,
+              frame.upper,
+              frame.lower,
               frame.frameType.sql,
-              ignoreNulls
+              ignoreNulls,
+              originalInputAttributes.asJava
             )
             windowExpressionNodes.add(windowFunctionNode)
           case wf @ NTile(buckets: Expression) =>
@@ -625,9 +629,10 @@ trait SparkPlanExecApi {
               childrenNodeList,
               columnName,
               ConverterUtils.getTypeNode(wf.dataType, wf.nullable),
-              frame.upper.sql,
-              frame.lower.sql,
-              frame.frameType.sql
+              frame.upper,
+              frame.lower,
+              frame.frameType.sql,
+              originalInputAttributes.asJava
             )
             windowExpressionNodes.add(windowFunctionNode)
           case _ =>
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala
index ef6a767b5..6832221a4 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala
@@ -197,16 +197,3 @@ case class WindowExecTransformer(
   override protected def withNewChildInternal(newChild: SparkPlan): 
WindowExecTransformer =
     copy(child = newChild)
 }
-
-object WindowExecTransformer {
-
-  /** Gets lower/upper bound represented in string. */
-  def getFrameBound(bound: Expression): String = {
-    // The lower/upper can be either a foldable Expression or a 
SpecialFrameBoundary.
-    if (bound.foldable) {
-      bound.eval().toString
-    } else {
-      bound.sql
-    }
-  }
-}
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala
index 50dc55423..73b8ab260 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala
@@ -75,6 +75,17 @@ object PullOutPreProject extends RewriteSingleNode with 
PullOutProjectHelper {
               case _ => false
             }
           case _ => false
+        }.isDefined) ||
+        window.windowExpression.exists(_.find {
+          case we: WindowExpression =>
+            we.windowSpec.frameSpecification match {
+              case swf: SpecifiedWindowFrame
+                  if needPreComputeRangeFrame(swf) && 
supportPreComputeRangeFrame(
+                    we.windowSpec.orderSpec) =>
+                true
+              case _ => false
+            }
+          case _ => false
         }.isDefined)
       case plan if SparkShimLoader.getSparkShims.isWindowGroupLimitExec(plan) 
=>
         val window = SparkShimLoader.getSparkShims
@@ -174,7 +185,9 @@ object PullOutPreProject extends RewriteSingleNode with 
PullOutProjectHelper {
 
       // Handle windowExpressions.
       val newWindowExpressions = window.windowExpression.toIndexedSeq.map {
-        _.transform { case we: WindowExpression => rewriteWindowExpression(we, 
expressionMap) }
+        _.transform {
+          case we: WindowExpression => rewriteWindowExpression(we, 
newOrderSpec, expressionMap)
+        }
       }
 
       val newWindow = window.copy(
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala 
b/gluten-core/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala
index 505f13f26..12055f9e9 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala
@@ -16,11 +16,13 @@
  */
 package org.apache.gluten.utils
 
-import org.apache.gluten.exception.GlutenNotSupportException
+import org.apache.gluten.backendsapi.BackendsApiManager
+import org.apache.gluten.exception.{GlutenException, GlutenNotSupportException}
 
 import org.apache.spark.sql.catalyst.expressions._
 import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, 
AggregateFunction}
 import org.apache.spark.sql.execution.aggregate._
+import org.apache.spark.sql.types.{ByteType, DateType, IntegerType, LongType, 
ShortType}
 
 import java.util.concurrent.atomic.AtomicInteger
 
@@ -143,8 +145,49 @@ trait PullOutProjectHelper {
     ae.copy(aggregateFunction = newAggFunc, filter = newFilter)
   }
 
+  /**
+   * Whether a single RANGE frame boundary still has to be pre-computed: only constant
+   * (foldable) boundaries that have not already been wrapped in a
+   * [[PreComputeRangeFrameBound]] qualify.
+   */
+  private def needPreComputeRangeFrameBoundary(bound: Expression): Boolean = bound match {
+    case _: PreComputeRangeFrameBound => false
+    case other => other.foldable
+  }
+
+  /**
+   * Rewrites a constant RANGE frame boundary into a [[PreComputeRangeFrameBound]] whose child is
+   * an alias of the pre-computed boundary value. The value is derived from the ORDER BY key the
+   * same way vanilla Spark's window operator does it:
+   *   - the offset's sign is flipped when the ORDER BY is descending;
+   *   - DATE keys use `DateAdd`, all other (integral) keys use `Add`.
+   * Boundaries that are already rewritten or are not foldable are returned unchanged.
+   *
+   * @param bound
+   *   the lower or upper boundary of a RANGE frame
+   * @param orderSpec
+   *   the single ORDER BY key of the window
+   * @param expressionMap
+   *   pre-projection expressions keyed by their canonicalized form; updated in place
+   * @return
+   *   the (possibly rewritten) boundary
+   */
+  private def preComputeRangeFrameBoundary(
+      bound: Expression,
+      orderSpec: SortOrder,
+      expressionMap: mutable.HashMap[Expression, NamedExpression]): Expression = {
+    bound match {
+      case _: PreComputeRangeFrameBound => bound
+      case _ if !bound.foldable => bound
+      case _ =>
+        // If the ORDER BY key was itself pulled out into this pre-projection, `orderSpec.child`
+        // is that alias's attribute, which the same projection cannot reference; use the
+        // aliased expression instead. (Assumes pulled-out keys are stored as Alias values.)
+        val orderExpr = orderSpec.child match {
+          case attr: Attribute =>
+            expressionMap.values
+              .collectFirst { case alias: Alias if alias.exprId == attr.exprId => alias.child }
+              .getOrElse(attr)
+          case other => other
+        }
+        // Same as vanilla Spark: flip the sign of the offset for a descending order.
+        val offset = orderSpec.direction match {
+          case Descending => UnaryMinus(bound)
+          case Ascending => bound
+        }
+        // `Add(date, int)` is ill-typed; DATE keys need `DateAdd`.
+        val frameBoundExpr = orderExpr.dataType match {
+          case DateType => DateAdd(orderExpr, offset)
+          case _ => Add(orderExpr, offset)
+        }
+        // Key by the projected expression itself so the map keeps its
+        // "canonicalized expression -> alias of that expression" invariant.
+        expressionMap.getOrElseUpdate(
+          frameBoundExpr.canonicalized,
+          Alias(frameBoundExpr, generatePreAliasName)()) match {
+          case alias: Alias => PreComputeRangeFrameBound(alias, bound)
+          case other =>
+            throw new GlutenException(
+              s"Expected an Alias for pre-computed range frame boundary, got: $other")
+        }
+    }
+  }
+
+  /**
+   * Whether `swf` is a RANGE frame with at least one constant, not-yet-rewritten boundary that
+   * the current backend needs pre-computed into a projected column.
+   */
+  protected def needPreComputeRangeFrame(swf: SpecifiedWindowFrame): Boolean = {
+    val backendNeedsPreCompute = BackendsApiManager.getSettings.needPreComputeRangeFrameBoundary
+    backendNeedsPreCompute && swf.frameType == RangeFrame &&
+      Seq(swf.lower, swf.upper).exists(needPreComputeRangeFrameBoundary)
+  }
+
+  /**
+   * Whether every ORDER BY key has a type for which a RANGE frame boundary can be
+   * pre-computed.
+   */
+  protected def supportPreComputeRangeFrame(sortOrders: Seq[SortOrder]): Boolean =
+    sortOrders.map(_.dataType).forall {
+      // Only integral type & date type are supported for sort key with Range Frame
+      case ByteType | ShortType | IntegerType | LongType | DateType => true
+      case _ => false
+    }
+
   protected def rewriteWindowExpression(
       we: WindowExpression,
+      orderSpecs: Seq[SortOrder],
       expressionMap: mutable.HashMap[Expression, NamedExpression]): 
WindowExpression = {
     val newWindowFunc = we.windowFunction match {
       case windowFunc: WindowFunction =>
@@ -156,6 +199,22 @@ trait PullOutProjectHelper {
       case ae: AggregateExpression => rewriteAggregateExpression(ae, 
expressionMap)
       case other => other
     }
-    we.copy(windowFunction = newWindowFunc)
+
+    val newWindowSpec = we.windowSpec.frameSpecification match {
+      case swf: SpecifiedWindowFrame if needPreComputeRangeFrame(swf) =>
+        // This is guaranteed by Spark, but we still check it here
+        if (orderSpecs.size != 1) {
+          throw new GlutenException(
+            s"A range window frame with value boundaries expects one and only 
one " +
+              s"order by expression: ${orderSpecs.mkString(",")}")
+        }
+        val orderSpec = orderSpecs.head
+        val lowerFrameCol = preComputeRangeFrameBoundary(swf.lower, orderSpec, 
expressionMap)
+        val upperFrameCol = preComputeRangeFrameBoundary(swf.upper, orderSpec, 
expressionMap)
+        val newFrame = swf.copy(lower = lowerFrameCol, upper = upperFrameCol)
+        we.windowSpec.copy(frameSpecification = newFrame)
+      case _ => we.windowSpec
+    }
+    we.copy(windowFunction = newWindowFunc, windowSpec = newWindowSpec)
   }
 }
diff --git 
a/gluten-core/src/main/scala/org/apache/spark/sql/catalyst/expressions/PreComputeRangeFrameBound.scala
 
b/gluten-core/src/main/scala/org/apache/spark/sql/catalyst/expressions/PreComputeRangeFrameBound.scala
new file mode 100644
index 000000000..73c1cb3de
--- /dev/null
+++ 
b/gluten-core/src/main/scala/org/apache/spark/sql/catalyst/expressions/PreComputeRangeFrameBound.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, 
ExprCode}
+import org.apache.spark.sql.types.{DataType, Metadata}
+
+/**
+ * Stand-in for a constant (non-[[SpecialFrameBoundary]]) RANGE frame boundary whose value has
+ * been pre-computed into a projected column, because Velox does not support a constant offset
+ * for RANGE frames.
+ *
+ * Towards Spark it behaves exactly like the original boundary: it is foldable, and both `eval`
+ * and code generation delegate to `originalBound`. As a result, the window operator still works
+ * correctly if it falls back to vanilla Spark. Its name, id, type and nullability come from
+ * `child`.
+ *
+ * @param child
+ *   The alias of the pre-computed projection column
+ * @param originalBound
+ *   The original boundary, a foldable expression
+ */
+case class PreComputeRangeFrameBound(child: Alias, originalBound: Expression)
+  extends UnaryExpression
+  with NamedExpression {
+
+  // Evaluation is delegated to the original constant boundary.
+
+  override def foldable: Boolean = true
+
+  override lazy val resolved: Boolean = originalBound.resolved
+
+  override def eval(input: InternalRow): Any = originalBound.eval(input)
+
+  override def genCode(ctx: CodegenContext): ExprCode = originalBound.genCode(ctx)
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+    originalBound.genCode(ctx)
+
+  // Naming, typing and rendering are delegated to the pre-computed column's alias.
+
+  override def name: String = child.name
+  override def exprId: ExprId = child.exprId
+  override def qualifier: Seq[String] = child.qualifier
+  override def metadata: Metadata = child.metadata
+  override def dataType: DataType = child.dataType
+  override def nullable: Boolean = child.nullable
+  override def toAttribute: Attribute = child.toAttribute
+  override def sql: String = child.sql
+  override def toString: String = child.toString
+
+  override def newInstance(): NamedExpression = {
+    val freshAlias = child.newInstance().asInstanceOf[Alias]
+    PreComputeRangeFrameBound(freshAlias, originalBound)
+  }
+
+  // Equality and hashing look only at `child`; `originalBound` is not compared.
+  override def equals(other: Any): Boolean = other match {
+    case that: PreComputeRangeFrameBound => child.equals(that.child)
+    case _ => false
+  }
+
+  override def hashCode(): Int = child.hashCode()
+
+  override protected def withNewChildInternal(newChild: Expression): PreComputeRangeFrameBound =
+    copy(child = newChild.asInstanceOf[Alias])
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to