This is an automated email from the ASF dual-hosted git repository.
ulyssesyou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 7e217cfde [GLUTEN-5625][VL] Support window range frame (#5626)
7e217cfde is described below
commit 7e217cfdeae96337bcf59a082e66d5297f534f41
Author: WangGuangxin <[email protected]>
AuthorDate: Wed Jun 12 10:09:23 2024 +0800
[GLUTEN-5625][VL] Support window range frame (#5626)
---
.../clickhouse/CHSparkPlanExecApi.scala | 22 +++---
.../gluten/backendsapi/velox/VeloxBackend.scala | 20 ++----
.../org/apache/gluten/execution/TestOperator.scala | 41 ++++++++++-
cpp/velox/substrait/SubstraitToVeloxPlan.cc | 30 ++++++--
cpp/velox/substrait/SubstraitToVeloxPlan.h | 6 ++
docs/developers/SubstraitModifications.md | 1 +
.../substrait/expression/ExpressionBuilder.java | 22 +++---
.../substrait/expression/WindowFunctionNode.java | 62 ++++++++++++++---
.../substrait/proto/substrait/algebra.proto | 26 ++++---
.../gluten/backendsapi/BackendSettingsApi.scala | 2 +
.../gluten/backendsapi/SparkPlanExecApi.scala | 35 ++++++----
.../gluten/execution/WindowExecTransformer.scala | 13 ----
.../columnar/rewrite/PullOutPreProject.scala | 15 +++-
.../apache/gluten/utils/PullOutProjectHelper.scala | 63 ++++++++++++++++-
.../expressions/PreComputeRangeFrameBound.scala | 80 ++++++++++++++++++++++
15 files changed, 348 insertions(+), 90 deletions(-)
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index a8a05c40f..d7faa07a5 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -66,6 +66,7 @@ import org.apache.commons.lang3.ClassUtils
import java.lang.{Long => JLong}
import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
+import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
class CHSparkPlanExecApi extends SparkPlanExecApi {
@@ -727,9 +728,10 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
new JArrayList[ExpressionNode](),
columnName,
ConverterUtils.getTypeNode(aggWindowFunc.dataType,
aggWindowFunc.nullable),
- WindowExecTransformer.getFrameBound(frame.upper),
- WindowExecTransformer.getFrameBound(frame.lower),
- frame.frameType.sql
+ frame.upper,
+ frame.lower,
+ frame.frameType.sql,
+ originalInputAttributes.asJava
)
windowExpressionNodes.add(windowFunctionNode)
case aggExpression: AggregateExpression =>
@@ -753,9 +755,10 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
childrenNodeList,
columnName,
ConverterUtils.getTypeNode(aggExpression.dataType,
aggExpression.nullable),
- WindowExecTransformer.getFrameBound(frame.upper),
- WindowExecTransformer.getFrameBound(frame.lower),
- frame.frameType.sql
+ frame.upper,
+ frame.lower,
+ frame.frameType.sql,
+ originalInputAttributes.asJava
)
windowExpressionNodes.add(windowFunctionNode)
case wf @ (Lead(_, _, _, _) | Lag(_, _, _, _)) =>
@@ -802,9 +805,10 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
childrenNodeList,
columnName,
ConverterUtils.getTypeNode(offsetWf.dataType, offsetWf.nullable),
- WindowExecTransformer.getFrameBound(frame.upper),
- WindowExecTransformer.getFrameBound(frame.lower),
- frame.frameType.sql
+ frame.upper,
+ frame.lower,
+ frame.frameType.sql,
+ originalInputAttributes.asJava
)
windowExpressionNodes.add(windowFunctionNode)
case _ =>
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
index f06929fff..21e6246d1 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
@@ -296,15 +296,9 @@ object VeloxBackendSettings extends BackendSettingsApi {
case _ => throw new GlutenNotSupportException(s"$func is not
supported.")
}
- // Block the offloading by checking Velox's current limitations
- // when literal bound type is used for RangeFrame.
def checkLimitations(swf: SpecifiedWindowFrame, orderSpec:
Seq[SortOrder]): Unit = {
- def doCheck(bound: Expression, isUpperBound: Boolean): Unit = {
+ def doCheck(bound: Expression): Unit = {
bound match {
- case e if e.foldable =>
- throw new GlutenNotSupportException(
- "Window frame of type RANGE does" +
- " not support constant arguments in velox backend")
case _: SpecialFrameBoundary =>
case e if e.foldable =>
orderSpec.foreach(
@@ -325,17 +319,11 @@ object VeloxBackendSettings extends BackendSettingsApi {
"Only integral type & date type are" +
" supported for sort key when literal bound type
is used!")
})
- val rawValue = e.eval().toString.toLong
- if (isUpperBound && rawValue < 0) {
- throw new GlutenNotSupportException("Negative upper bound
is not supported!")
- } else if (!isUpperBound && rawValue > 0) {
- throw new GlutenNotSupportException("Positive lower bound
is not supported!")
- }
case _ =>
}
}
- doCheck(swf.upper, true)
- doCheck(swf.lower, false)
+ doCheck(swf.upper)
+ doCheck(swf.lower)
}
windowExpression.windowSpec.frameSpecification match {
@@ -495,4 +483,6 @@ object VeloxBackendSettings extends BackendSettingsApi {
override def supportColumnarArrowUdf(): Boolean = true
override def generateHdfsConfForLibhdfs(): Boolean = true
+
+ override def needPreComputeRangeFrameBoundary(): Boolean = true
}
diff --git
a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala
b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala
index ae8d64a09..3cf485aac 100644
---
a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala
+++
b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala
@@ -212,17 +212,56 @@ class TestOperator extends
VeloxWholeStageTransformerSuite with AdaptiveSparkPla
Seq("sort", "streaming").foreach {
windowType =>
withSQLConf("spark.gluten.sql.columnar.backend.velox.window.type" ->
windowType) {
+ runQueryAndCompare(
+ "select max(l_partkey) over" +
+ " (partition by l_suppkey order by l_orderkey" +
+ " RANGE BETWEEN 1 PRECEDING AND CURRENT ROW), " +
+ "min(l_comment) over" +
+ " (partition by l_suppkey order by l_linenumber" +
+ " RANGE BETWEEN 1 PRECEDING AND CURRENT ROW) from lineitem ") {
+ checkSparkOperatorMatch[WindowExecTransformer]
+ }
+
runQueryAndCompare(
"select max(l_partkey) over" +
" (partition by l_suppkey order by l_orderkey" +
" RANGE BETWEEN CURRENT ROW AND 2 FOLLOWING) from lineitem ") {
- checkSparkOperatorMatch[WindowExec]
+ checkSparkOperatorMatch[WindowExecTransformer]
}
runQueryAndCompare(
"select max(l_partkey) over" +
" (partition by l_suppkey order by l_orderkey" +
" RANGE BETWEEN 6 PRECEDING AND CURRENT ROW) from lineitem ") {
+ checkSparkOperatorMatch[WindowExecTransformer]
+ }
+
+ runQueryAndCompare(
+ "select max(l_partkey) over" +
+ " (partition by l_suppkey order by l_orderkey" +
+ " RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING) from lineitem ") {
+ checkSparkOperatorMatch[WindowExecTransformer]
+ }
+
+ runQueryAndCompare(
+ "select max(l_partkey) over" +
+ " (partition by l_suppkey order by l_orderkey" +
+ " RANGE BETWEEN 6 PRECEDING AND 3 PRECEDING) from lineitem ") {
+ checkSparkOperatorMatch[WindowExecTransformer]
+ }
+
+ runQueryAndCompare(
+ "select max(l_partkey) over" +
+ " (partition by l_suppkey order by l_orderkey" +
+ " RANGE BETWEEN 3 FOLLOWING AND 6 FOLLOWING) from lineitem ") {
+ checkSparkOperatorMatch[WindowExecTransformer]
+ }
+
+ // DecimalType as order by column is not supported
+ runQueryAndCompare(
+ "select min(l_comment) over" +
+ " (partition by l_suppkey order by l_discount" +
+ " RANGE BETWEEN 1 PRECEDING AND CURRENT ROW) from lineitem ") {
checkSparkOperatorMatch[WindowExec]
}
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.cc
b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
index b82eead2c..4e875d479 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlan.cc
+++ b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
@@ -823,10 +823,11 @@ core::PlanNodePtr
SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
nextPlanNodeId(), replicated, unnest, std::move(unnestNames),
ordinalityName, childNode);
}
-const core::WindowNode::Frame createWindowFrame(
+const core::WindowNode::Frame SubstraitToVeloxPlanConverter::createWindowFrame(
const ::substrait::Expression_WindowFunction_Bound& lower_bound,
const ::substrait::Expression_WindowFunction_Bound& upper_bound,
- const ::substrait::WindowType& type) {
+ const ::substrait::WindowType& type,
+ const RowTypePtr& inputType) {
core::WindowNode::Frame frame;
switch (type) {
case ::substrait::WindowType::ROWS:
@@ -839,9 +840,22 @@ const core::WindowNode::Frame createWindowFrame(
VELOX_FAIL("the window type only support ROWS and RANGE, and the input
type is ", std::to_string(type));
}
- auto boundTypeConversion = [](::substrait::Expression_WindowFunction_Bound
boundType)
+ auto specifiedBound =
+ [&](bool hasOffset, int64_t offset, const ::substrait::Expression&
columnRef) -> core::TypedExprPtr {
+ if (hasOffset) {
+ VELOX_CHECK(
+ frame.type != core::WindowNode::WindowType::kRange,
+ "for RANGE frame offset, we should pre-calculate the range frame
boundary and pass the column reference, but got a constant offset.")
+ return std::make_shared<core::ConstantTypedExpr>(BIGINT(),
variant(offset));
+ } else {
+ VELOX_CHECK(
+ frame.type != core::WindowNode::WindowType::kRows, "for ROW frame
offset, we should pass a constant offset.")
+ return exprConverter_->toVeloxExpr(columnRef, inputType);
+ }
+ };
+
+ auto boundTypeConversion = [&](::substrait::Expression_WindowFunction_Bound
boundType)
-> std::tuple<core::WindowNode::BoundType, core::TypedExprPtr> {
- // TODO: support non-literal expression.
if (boundType.has_current_row()) {
return std::make_tuple(core::WindowNode::BoundType::kCurrentRow,
nullptr);
} else if (boundType.has_unbounded_following()) {
@@ -849,13 +863,15 @@ const core::WindowNode::Frame createWindowFrame(
} else if (boundType.has_unbounded_preceding()) {
return std::make_tuple(core::WindowNode::BoundType::kUnboundedPreceding,
nullptr);
} else if (boundType.has_following()) {
+ auto following = boundType.following();
return std::make_tuple(
core::WindowNode::BoundType::kFollowing,
- std::make_shared<core::ConstantTypedExpr>(BIGINT(),
variant(boundType.following().offset())));
+ specifiedBound(following.has_offset(), following.offset(),
following.ref()));
} else if (boundType.has_preceding()) {
+ auto preceding = boundType.preceding();
return std::make_tuple(
core::WindowNode::BoundType::kPreceding,
- std::make_shared<core::ConstantTypedExpr>(BIGINT(),
variant(boundType.preceding().offset())));
+ specifiedBound(preceding.has_offset(), preceding.offset(),
preceding.ref()));
} else {
VELOX_FAIL("The BoundType is not supported.");
}
@@ -906,7 +922,7 @@ core::PlanNodePtr
SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
windowColumnNames.push_back(windowFunction.column_name());
windowNodeFunctions.push_back(
- {std::move(windowCall), std::move(createWindowFrame(lowerBound,
upperBound, type)), ignoreNulls});
+ {std::move(windowCall), std::move(createWindowFrame(lowerBound,
upperBound, type, inputType)), ignoreNulls});
}
// Construct partitionKeys
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.h
b/cpp/velox/substrait/SubstraitToVeloxPlan.h
index 567ebb215..3a0e677af 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlan.h
+++ b/cpp/velox/substrait/SubstraitToVeloxPlan.h
@@ -555,6 +555,12 @@ class SubstraitToVeloxPlanConverter {
return toVeloxPlan(rel.input());
}
+ const core::WindowNode::Frame createWindowFrame(
+ const ::substrait::Expression_WindowFunction_Bound& lower_bound,
+ const ::substrait::Expression_WindowFunction_Bound& upper_bound,
+ const ::substrait::WindowType& type,
+ const RowTypePtr& inputType);
+
/// The unique identification for each PlanNode.
int planNodeId_ = 0;
diff --git a/docs/developers/SubstraitModifications.md
b/docs/developers/SubstraitModifications.md
index 38406425a..24a9c1a21 100644
--- a/docs/developers/SubstraitModifications.md
+++ b/docs/developers/SubstraitModifications.md
@@ -27,6 +27,7 @@ changed `Unbounded` in `WindowFunction` into
`Unbounded_Preceding` and `Unbounde
* Added `PartitionColumn` in
`LocalFiles`([#2405](https://github.com/apache/incubator-gluten/pull/2405)).
* Added `WriteRel`
([#3690](https://github.com/apache/incubator-gluten/pull/3690)).
* Added `TopNRel`
([#5409](https://github.com/apache/incubator-gluten/pull/5409)).
+* Added `ref` field in window bound `Preceding` and `Following`
([#5626](https://github.com/apache/incubator-gluten/pull/5626)).
## Modifications to type.proto
diff --git
a/gluten-core/src/main/java/org/apache/gluten/substrait/expression/ExpressionBuilder.java
b/gluten-core/src/main/java/org/apache/gluten/substrait/expression/ExpressionBuilder.java
index 5d106938c..e322e1528 100644
---
a/gluten-core/src/main/java/org/apache/gluten/substrait/expression/ExpressionBuilder.java
+++
b/gluten-core/src/main/java/org/apache/gluten/substrait/expression/ExpressionBuilder.java
@@ -21,6 +21,8 @@ import org.apache.gluten.expression.ConverterUtils;
import org.apache.gluten.substrait.type.*;
import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.expressions.Attribute;
+import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.util.ArrayData;
import org.apache.spark.sql.catalyst.util.MapData;
import org.apache.spark.sql.types.*;
@@ -264,9 +266,10 @@ public class ExpressionBuilder {
List<ExpressionNode> expressionNodes,
String columnName,
TypeNode outputTypeNode,
- String upperBound,
- String lowerBound,
- String frameType) {
+ Expression upperBound,
+ Expression lowerBound,
+ String frameType,
+ List<Attribute> originalInputAttributes) {
return makeWindowFunction(
functionId,
expressionNodes,
@@ -275,7 +278,8 @@ public class ExpressionBuilder {
upperBound,
lowerBound,
frameType,
- false);
+ false,
+ originalInputAttributes);
}
public static WindowFunctionNode makeWindowFunction(
@@ -283,10 +287,11 @@ public class ExpressionBuilder {
List<ExpressionNode> expressionNodes,
String columnName,
TypeNode outputTypeNode,
- String upperBound,
- String lowerBound,
+ Expression upperBound,
+ Expression lowerBound,
String frameType,
- boolean ignoreNulls) {
+ boolean ignoreNulls,
+ List<Attribute> originalInputAttributes) {
return new WindowFunctionNode(
functionId,
expressionNodes,
@@ -295,6 +300,7 @@ public class ExpressionBuilder {
upperBound,
lowerBound,
frameType,
- ignoreNulls);
+ ignoreNulls,
+ originalInputAttributes);
}
}
diff --git
a/gluten-core/src/main/java/org/apache/gluten/substrait/expression/WindowFunctionNode.java
b/gluten-core/src/main/java/org/apache/gluten/substrait/expression/WindowFunctionNode.java
index 67d0d6e57..b9f1fbc12 100644
---
a/gluten-core/src/main/java/org/apache/gluten/substrait/expression/WindowFunctionNode.java
+++
b/gluten-core/src/main/java/org/apache/gluten/substrait/expression/WindowFunctionNode.java
@@ -16,17 +16,24 @@
*/
package org.apache.gluten.substrait.expression;
+import org.apache.gluten.exception.GlutenException;
+import org.apache.gluten.expression.ExpressionConverter;
import org.apache.gluten.substrait.type.TypeNode;
import io.substrait.proto.Expression;
import io.substrait.proto.FunctionArgument;
import io.substrait.proto.FunctionOption;
import io.substrait.proto.WindowType;
+import org.apache.spark.sql.catalyst.expressions.Attribute;
+import org.apache.spark.sql.catalyst.expressions.PreComputeRangeFrameBound;
import java.io.Serializable;
import java.util.ArrayList;
+import java.util.HashMap;
import java.util.List;
+import scala.collection.JavaConverters;
+
public class WindowFunctionNode implements Serializable {
private final Integer functionId;
private final List<ExpressionNode> expressionNodes = new ArrayList<>();
@@ -34,23 +41,26 @@ public class WindowFunctionNode implements Serializable {
private final String columnName;
private final TypeNode outputTypeNode;
- private final String upperBound;
+ private final org.apache.spark.sql.catalyst.expressions.Expression
upperBound;
- private final String lowerBound;
+ private final org.apache.spark.sql.catalyst.expressions.Expression
lowerBound;
private final String frameType;
private final boolean ignoreNulls;
+ private final List<Attribute> originalInputAttributes;
+
WindowFunctionNode(
Integer functionId,
List<ExpressionNode> expressionNodes,
String columnName,
TypeNode outputTypeNode,
- String upperBound,
- String lowerBound,
+ org.apache.spark.sql.catalyst.expressions.Expression upperBound,
+ org.apache.spark.sql.catalyst.expressions.Expression lowerBound,
String frameType,
- boolean ignoreNulls) {
+ boolean ignoreNulls,
+ List<Attribute> originalInputAttributes) {
this.functionId = functionId;
this.expressionNodes.addAll(expressionNodes);
this.columnName = columnName;
@@ -59,11 +69,13 @@ public class WindowFunctionNode implements Serializable {
this.lowerBound = lowerBound;
this.frameType = frameType;
this.ignoreNulls = ignoreNulls;
+ this.originalInputAttributes = originalInputAttributes;
}
private Expression.WindowFunction.Bound.Builder setBound(
- Expression.WindowFunction.Bound.Builder builder, String boundType) {
- switch (boundType) {
+ Expression.WindowFunction.Bound.Builder builder,
+ org.apache.spark.sql.catalyst.expressions.Expression boundType) {
+ switch (boundType.sql()) {
case ("CURRENT ROW"):
Expression.WindowFunction.Bound.CurrentRow.Builder currentRowBuilder =
Expression.WindowFunction.Bound.CurrentRow.newBuilder();
@@ -80,8 +92,36 @@ public class WindowFunctionNode implements Serializable {
builder.setUnboundedFollowing(followingBuilder.build());
break;
default:
- try {
- Long offset = Long.valueOf(boundType);
+ if (boundType instanceof PreComputeRangeFrameBound) {
+ // Used only when backend is velox and frame type is RANGE.
+ if (!frameType.equals("RANGE")) {
+ throw new GlutenException(
+ "Only Range frame supports PreComputeRangeFrameBound, but got
" + frameType);
+ }
+ ExpressionNode refNode =
+ ExpressionConverter.replaceWithExpressionTransformer(
+ ((PreComputeRangeFrameBound)
boundType).child().toAttribute(),
+
JavaConverters.asScalaIteratorConverter(originalInputAttributes.iterator())
+ .asScala()
+ .toSeq())
+ .doTransform(new HashMap<String, Long>());
+ Long offset = Long.valueOf(boundType.eval(null).toString());
+ if (offset < 0) {
+ Expression.WindowFunction.Bound.Preceding.Builder
refPrecedingBuilder =
+ Expression.WindowFunction.Bound.Preceding.newBuilder();
+ refPrecedingBuilder.setRef(refNode.toProtobuf());
+ builder.setPreceding(refPrecedingBuilder.build());
+ } else {
+ Expression.WindowFunction.Bound.Following.Builder
refFollowingBuilder =
+ Expression.WindowFunction.Bound.Following.newBuilder();
+ refFollowingBuilder.setRef(refNode.toProtobuf());
+ builder.setFollowing(refFollowingBuilder.build());
+ }
+ } else if (boundType.foldable()) {
+ // Used when
+ // 1. Velox backend and frame type is ROW
+ // 2. Clickhouse backend
+ Long offset = Long.valueOf(boundType.eval(null).toString());
if (offset < 0) {
Expression.WindowFunction.Bound.Preceding.Builder
offsetPrecedingBuilder =
Expression.WindowFunction.Bound.Preceding.newBuilder();
@@ -93,9 +133,9 @@ public class WindowFunctionNode implements Serializable {
offsetFollowingBuilder.setOffset(offset);
builder.setFollowing(offsetFollowingBuilder.build());
}
- } catch (NumberFormatException e) {
+ } else {
throw new UnsupportedOperationException(
- "Unsupported Window Function Frame Type:" + boundType);
+ "Unsupported Window Function Frame Bound Type: " + boundType);
}
}
return builder;
diff --git
a/gluten-core/src/main/resources/substrait/proto/substrait/algebra.proto
b/gluten-core/src/main/resources/substrait/proto/substrait/algebra.proto
index 877493439..0e51baf5a 100644
--- a/gluten-core/src/main/resources/substrait/proto/substrait/algebra.proto
+++ b/gluten-core/src/main/resources/substrait/proto/substrait/algebra.proto
@@ -996,18 +996,28 @@ message Expression {
message Bound {
// Defines that the bound extends this far back from the current record.
message Preceding {
- // A strictly positive integer specifying the number of records that
- // the window extends back from the current record. Required. Use
- // CurrentRow for offset zero and Following for negative offsets.
- int64 offset = 1;
+ oneof kind {
+ // A strictly positive integer specifying the number of records that
+ // the window extends back from the current record. Use
+ // CurrentRow for offset zero and Following for negative offsets.
+ int64 offset = 1;
+
+ // The reference to the pre-projected range frame boundary.
+ Expression ref = 2;
+ }
}
// Defines that the bound extends this far ahead of the current record.
message Following {
- // A strictly positive integer specifying the number of records that
- // the window extends ahead of the current record. Required. Use
- // CurrentRow for offset zero and Preceding for negative offsets.
- int64 offset = 1;
+ oneof kind {
+ // A strictly positive integer specifying the number of records that
+ // the window extends ahead of the current record. Use
+ // CurrentRow for offset zero and Preceding for negative offsets.
+ int64 offset = 1;
+
+ // The reference to the pre-projected range frame boundary.
+ Expression ref = 2;
+ }
}
// Defines that the bound extends to or from the current record.
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
index d18273af2..b7a3bc1b6 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
@@ -152,4 +152,6 @@ trait BackendSettingsApi {
def supportColumnarArrowUdf(): Boolean = false
def generateHdfsConfForLibhdfs(): Boolean = false
+
+ def needPreComputeRangeFrameBoundary(): Boolean = false
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
index 8a1baae51..8bc8e136b 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
@@ -529,9 +529,10 @@ trait SparkPlanExecApi {
new JArrayList[ExpressionNode](),
columnName,
ConverterUtils.getTypeNode(aggWindowFunc.dataType,
aggWindowFunc.nullable),
- WindowExecTransformer.getFrameBound(frame.upper),
- WindowExecTransformer.getFrameBound(frame.lower),
- frame.frameType.sql
+ frame.upper,
+ frame.lower,
+ frame.frameType.sql,
+ originalInputAttributes.asJava
)
windowExpressionNodes.add(windowFunctionNode)
case aggExpression: AggregateExpression =>
@@ -554,9 +555,10 @@ trait SparkPlanExecApi {
childrenNodeList,
columnName,
ConverterUtils.getTypeNode(aggExpression.dataType,
aggExpression.nullable),
- WindowExecTransformer.getFrameBound(frame.upper),
- WindowExecTransformer.getFrameBound(frame.lower),
- frame.frameType.sql
+ frame.upper,
+ frame.lower,
+ frame.frameType.sql,
+ originalInputAttributes.asJava
)
windowExpressionNodes.add(windowFunctionNode)
case wf @ (_: Lead | _: Lag) =>
@@ -590,10 +592,11 @@ trait SparkPlanExecApi {
childrenNodeList,
columnName,
ConverterUtils.getTypeNode(offsetWf.dataType, offsetWf.nullable),
- WindowExecTransformer.getFrameBound(frame.upper),
- WindowExecTransformer.getFrameBound(frame.lower),
+ frame.upper,
+ frame.lower,
frame.frameType.sql,
- offsetWf.ignoreNulls
+ offsetWf.ignoreNulls,
+ originalInputAttributes.asJava
)
windowExpressionNodes.add(windowFunctionNode)
case wf @ NthValue(input, offset: Literal, ignoreNulls: Boolean) =>
@@ -609,10 +612,11 @@ trait SparkPlanExecApi {
childrenNodeList,
columnName,
ConverterUtils.getTypeNode(wf.dataType, wf.nullable),
- frame.upper.sql,
- frame.lower.sql,
+ frame.upper,
+ frame.lower,
frame.frameType.sql,
- ignoreNulls
+ ignoreNulls,
+ originalInputAttributes.asJava
)
windowExpressionNodes.add(windowFunctionNode)
case wf @ NTile(buckets: Expression) =>
@@ -625,9 +629,10 @@ trait SparkPlanExecApi {
childrenNodeList,
columnName,
ConverterUtils.getTypeNode(wf.dataType, wf.nullable),
- frame.upper.sql,
- frame.lower.sql,
- frame.frameType.sql
+ frame.upper,
+ frame.lower,
+ frame.frameType.sql,
+ originalInputAttributes.asJava
)
windowExpressionNodes.add(windowFunctionNode)
case _ =>
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala
b/gluten-core/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala
index ef6a767b5..6832221a4 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala
@@ -197,16 +197,3 @@ case class WindowExecTransformer(
override protected def withNewChildInternal(newChild: SparkPlan):
WindowExecTransformer =
copy(child = newChild)
}
-
-object WindowExecTransformer {
-
- /** Gets lower/upper bound represented in string. */
- def getFrameBound(bound: Expression): String = {
- // The lower/upper can be either a foldable Expression or a
SpecialFrameBoundary.
- if (bound.foldable) {
- bound.eval().toString
- } else {
- bound.sql
- }
- }
-}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala
index 50dc55423..73b8ab260 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala
@@ -75,6 +75,17 @@ object PullOutPreProject extends RewriteSingleNode with
PullOutProjectHelper {
case _ => false
}
case _ => false
+ }.isDefined) ||
+ window.windowExpression.exists(_.find {
+ case we: WindowExpression =>
+ we.windowSpec.frameSpecification match {
+ case swf: SpecifiedWindowFrame
+ if needPreComputeRangeFrame(swf) &&
supportPreComputeRangeFrame(
+ we.windowSpec.orderSpec) =>
+ true
+ case _ => false
+ }
+ case _ => false
}.isDefined)
case plan if SparkShimLoader.getSparkShims.isWindowGroupLimitExec(plan)
=>
val window = SparkShimLoader.getSparkShims
@@ -174,7 +185,9 @@ object PullOutPreProject extends RewriteSingleNode with
PullOutProjectHelper {
// Handle windowExpressions.
val newWindowExpressions = window.windowExpression.toIndexedSeq.map {
- _.transform { case we: WindowExpression => rewriteWindowExpression(we,
expressionMap) }
+ _.transform {
+ case we: WindowExpression => rewriteWindowExpression(we,
newOrderSpec, expressionMap)
+ }
}
val newWindow = window.copy(
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala
b/gluten-core/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala
index 505f13f26..12055f9e9 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala
@@ -16,11 +16,13 @@
*/
package org.apache.gluten.utils
-import org.apache.gluten.exception.GlutenNotSupportException
+import org.apache.gluten.backendsapi.BackendsApiManager
+import org.apache.gluten.exception.{GlutenException, GlutenNotSupportException}
import org.apache.spark.sql.catalyst.expressions._
import
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
AggregateFunction}
import org.apache.spark.sql.execution.aggregate._
+import org.apache.spark.sql.types.{ByteType, DateType, IntegerType, LongType,
ShortType}
import java.util.concurrent.atomic.AtomicInteger
@@ -143,8 +145,49 @@ trait PullOutProjectHelper {
ae.copy(aggregateFunction = newAggFunc, filter = newFilter)
}
+ private def needPreComputeRangeFrameBoundary(bound: Expression): Boolean = {
+ bound match {
+ case _: PreComputeRangeFrameBound => false
+ case _ if !bound.foldable => false
+ case _ => true
+ }
+ }
+
+ private def preComputeRangeFrameBoundary(
+ bound: Expression,
+ orderSpec: SortOrder,
+ expressionMap: mutable.HashMap[Expression, NamedExpression]): Expression
= {
+ bound match {
+ case _: PreComputeRangeFrameBound => bound
+ case _ if !bound.foldable => bound
+ case _ if bound.foldable =>
+ val a = expressionMap
+ .getOrElseUpdate(
+ bound.canonicalized,
+ Alias(Add(orderSpec.child, bound), generatePreAliasName)())
+ PreComputeRangeFrameBound(a.asInstanceOf[Alias], bound)
+ }
+ }
+
+ protected def needPreComputeRangeFrame(swf: SpecifiedWindowFrame): Boolean =
{
+ BackendsApiManager.getSettings.needPreComputeRangeFrameBoundary &&
+ swf.frameType == RangeFrame &&
+ (needPreComputeRangeFrameBoundary(swf.lower) ||
needPreComputeRangeFrameBoundary(swf.upper))
+ }
+
+ protected def supportPreComputeRangeFrame(sortOrders: Seq[SortOrder]):
Boolean = {
+ sortOrders.forall {
+ _.dataType match {
+ case ByteType | ShortType | IntegerType | LongType | DateType => true
+ // Only integral type & date type are supported for sort key with
Range Frame
+ case _ => false
+ }
+ }
+ }
+
protected def rewriteWindowExpression(
we: WindowExpression,
+ orderSpecs: Seq[SortOrder],
expressionMap: mutable.HashMap[Expression, NamedExpression]):
WindowExpression = {
val newWindowFunc = we.windowFunction match {
case windowFunc: WindowFunction =>
@@ -156,6 +199,22 @@ trait PullOutProjectHelper {
case ae: AggregateExpression => rewriteAggregateExpression(ae,
expressionMap)
case other => other
}
- we.copy(windowFunction = newWindowFunc)
+
+ val newWindowSpec = we.windowSpec.frameSpecification match {
+ case swf: SpecifiedWindowFrame if needPreComputeRangeFrame(swf) =>
+ // This is guaranteed by Spark, but we still check it here
+ if (orderSpecs.size != 1) {
+ throw new GlutenException(
+ s"A range window frame with value boundaries expects one and only
one " +
+ s"order by expression: ${orderSpecs.mkString(",")}")
+ }
+ val orderSpec = orderSpecs.head
+ val lowerFrameCol = preComputeRangeFrameBoundary(swf.lower, orderSpec,
expressionMap)
+ val upperFrameCol = preComputeRangeFrameBoundary(swf.upper, orderSpec,
expressionMap)
+ val newFrame = swf.copy(lower = lowerFrameCol, upper = upperFrameCol)
+ we.windowSpec.copy(frameSpecification = newFrame)
+ case _ => we.windowSpec
+ }
+ we.copy(windowFunction = newWindowFunc, windowSpec = newWindowSpec)
}
}
diff --git
a/gluten-core/src/main/scala/org/apache/spark/sql/catalyst/expressions/PreComputeRangeFrameBound.scala
b/gluten-core/src/main/scala/org/apache/spark/sql/catalyst/expressions/PreComputeRangeFrameBound.scala
new file mode 100644
index 000000000..73c1cb3de
--- /dev/null
+++
b/gluten-core/src/main/scala/org/apache/spark/sql/catalyst/expressions/PreComputeRangeFrameBound.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext,
ExprCode}
+import org.apache.spark.sql.types.{DataType, Metadata}
+
+/**
+ * Represents a pre-computed boundary for a range frame when the boundary is
not a SpecialFrameBoundary,
+ * since Velox doesn't support a constant offset for range frames. It acts
like the original boundary:
+ * it is foldable and generates the same result when eval is invoked, so that
if the WindowExec
+ * falls back to vanilla Spark it can still work correctly.
+ * @param child
+ * The alias to the pre-computed projection column
+ * @param originalBound
+ * The original boundary which is a foldable expression
+ */
+case class PreComputeRangeFrameBound(child: Alias, originalBound: Expression)
+ extends UnaryExpression
+ with NamedExpression {
+
+ override def foldable: Boolean = true
+
+ override def eval(input: InternalRow): Any = originalBound.eval(input)
+
+ override def genCode(ctx: CodegenContext): ExprCode =
originalBound.genCode(ctx)
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode):
ExprCode =
+ originalBound.genCode(ctx)
+
+ override def name: String = child.name
+
+ override def exprId: ExprId = child.exprId
+
+ override def qualifier: Seq[String] = child.qualifier
+
+ override def newInstance(): NamedExpression =
+ PreComputeRangeFrameBound(child.newInstance().asInstanceOf[Alias],
originalBound)
+
+ override lazy val resolved: Boolean = originalBound.resolved
+
+ override def dataType: DataType = child.dataType
+
+ override def nullable: Boolean = child.nullable
+
+ override def metadata: Metadata = child.metadata
+
+ override def toAttribute: Attribute = child.toAttribute
+
+ override def toString: String = child.toString
+
+ override def hashCode(): Int = child.hashCode()
+
+ override def equals(other: Any): Boolean = other match {
+ case a: PreComputeRangeFrameBound =>
+ child.equals(a.child)
+ case _ => false
+ }
+
+ override def sql: String = child.sql
+
+ override protected def withNewChildInternal(newChild: Expression):
PreComputeRangeFrameBound =
+ copy(child = newChild.asInstanceOf[Alias])
+
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]