This is an automated email from the ASF dual-hosted git repository.
philo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 92ca76d22 [VL] Support lead window function (#4902)
92ca76d22 is described below
commit 92ca76d22c8e446f790701a2f252de60c6633273
Author: Xiduo You <[email protected]>
AuthorDate: Mon Mar 11 20:50:01 2024 +0800
[VL] Support lead window function (#4902)
---
.../backendsapi/velox/VeloxBackend.scala | 4 +-
.../io/glutenproject/execution/TestOperator.scala | 6 +++
cpp/velox/substrait/SubstraitToVeloxPlan.cc | 27 ++++-------
cpp/velox/substrait/SubstraitToVeloxPlan.h | 3 --
.../substrait/expression/ExpressionBuilder.java | 29 +++++++++++-
.../substrait/expression/WindowFunctionNode.java | 12 ++++-
.../backendsapi/SparkPlanExecApi.scala | 50 ++++++++++++++------
.../sql/GlutenDataFrameWindowFunctionsSuite.scala | 7 +++
.../utils/velox/VeloxTestSettings.scala | 10 ++++
.../sql/GlutenDataFrameWindowFunctionsSuite.scala | 55 ++++++++++++++++++++++
.../utils/velox/VeloxTestSettings.scala | 10 ++++
.../sql/GlutenDataFrameWindowFunctionsSuite.scala | 55 ++++++++++++++++++++++
12 files changed, 228 insertions(+), 40 deletions(-)
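For context, Spark's `lead(input, offset, default)` returns the value of `input` at the row `offset` rows after the current row within its window partition, and `default` (null unless specified) when no such row exists. A minimal sketch of the query shape this commit makes offloadable, assuming a local SparkSession (all names here are illustrative):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.lead

    object LeadExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[2]").appName("lead-example").getOrCreate()
        import spark.implicits._

        val df = Seq(("a", 1), ("a", 2), ("b", 3)).toDF("key", "value")
        val w = Window.partitionBy("key").orderBy("value")

        // lead(value, 1) looks one row ahead within each partition; the last
        // row of a partition has no successor, so it yields null.
        df.select($"key", $"value", lead($"value", 1).over(w).as("next_value")).show()
        spark.stop()
      }
    }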
diff --git a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/VeloxBackend.scala b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/VeloxBackend.scala
index c316d50f0..bab5e68ec 100644
--- a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/VeloxBackend.scala
+++ b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/VeloxBackend.scala
@@ -26,7 +26,7 @@ import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat
 import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat.{DwrfReadFormat, OrcReadFormat, ParquetReadFormat}
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
-import org.apache.spark.sql.catalyst.expressions.{Alias, CumeDist, DenseRank, Descending, Expression, Lag, Literal, NamedExpression, NthValue, NTile, PercentRank, Rand, RangeFrame, Rank, RowNumber, SortOrder, SpecialFrameBoundary, SpecifiedWindowFrame}
+import org.apache.spark.sql.catalyst.expressions.{Alias, CumeDist, DenseRank, Descending, Expression, Lag, Lead, Literal, NamedExpression, NthValue, NTile, PercentRank, Rand, RangeFrame, Rank, RowNumber, SortOrder, SpecialFrameBoundary, SpecifiedWindowFrame}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Count, Sum}
 import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
@@ -317,7 +317,7 @@ object BackendSettings extends BackendSettingsApi {
         }
         windowExpression.windowFunction match {
           case _: RowNumber | _: AggregateExpression | _: Rank | _: CumeDist | _: DenseRank |
-              _: PercentRank | _: NthValue | _: NTile | _: Lag =>
+              _: PercentRank | _: NthValue | _: NTile | _: Lag | _: Lead =>
           case _ =>
             allSupported = false
         }
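The check above is a plain type-based whitelist: a window expression is offloaded only when its function matches one of the listed patterns, and anything else keeps the whole window on vanilla Spark. A simplified, self-contained sketch of the same idea (the `WindowFunc` hierarchy below is illustrative, not Gluten's API):

    // Illustrative stand-ins for Spark's window function expression classes.
    sealed trait WindowFunc
    final case class RowNumber() extends WindowFunc
    final case class Lag(offset: Int) extends WindowFunc
    final case class Lead(offset: Int) extends WindowFunc
    final case class SomethingElse() extends WindowFunc

    object SupportCheck {
      // Mirrors the whitelist style in BackendSettings: any function that does
      // not match the first case marks the window as unsupported.
      def allSupported(funcs: Seq[WindowFunc]): Boolean =
        funcs.forall {
          case _: RowNumber | _: Lag | _: Lead => true
          case _ => false
        }

      def main(args: Array[String]): Unit = {
        println(allSupported(Seq(RowNumber(), Lead(1)))) // true
        println(allSupported(Seq(SomethingElse())))      // false
      }
    }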
diff --git a/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala b/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala
index f2b6fe8b3..961a0c3de 100644
--- a/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala
+++ b/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala
@@ -283,6 +283,12 @@ class TestOperator extends VeloxWholeStageTransformerSuite {
       assertWindowOffloaded
     }
 
+    runQueryAndCompare(
+      "select lead(l_orderkey) over" +
+        " (partition by l_suppkey order by l_orderkey) from lineitem ") {
+      assertWindowOffloaded
+    }
+
     // Test same partition/ordering keys.
    runQueryAndCompare(
      "select avg(l_partkey) over" +
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.cc b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
index 4f0bdfa25..100dcb9eb 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlan.cc
+++ b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
@@ -826,7 +826,6 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
// Parse measures and get the window expressions.
// Each measure represents one window expression.
- bool ignoreNulls = false;
std::vector<core::WindowNode::Function> windowNodeFunctions;
std::vector<std::string> windowColumnNames;
@@ -837,23 +836,15 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
     std::vector<core::TypedExprPtr> windowParams;
     auto& argumentList = windowFunction.arguments();
     windowParams.reserve(argumentList.size());
+    const auto& options = windowFunction.options();
     // For functions in kOffsetWindowFunctions (see Spark OffsetWindowFunctions),
-    // we expect the last arg is passed for setting ignoreNulls.
-    if (kOffsetWindowFunctions.find(funcName) != kOffsetWindowFunctions.end()) {
-      int i = 0;
-      for (; i < argumentList.size() - 1; i++) {
-        windowParams.emplace_back(exprConverter_->toVeloxExpr(argumentList[i].value(), inputType));
-      }
-      auto constantTypedExpr = exprConverter_->toVeloxExpr(argumentList[i].value().literal());
-      auto variant = constantTypedExpr->value();
-      if (!variant.hasValue()) {
-        VELOX_FAIL("Value is expected in variant for setting ignoreNulls.");
-      }
-      ignoreNulls = variant.value<bool>();
-    } else {
-      for (const auto& arg : argumentList) {
-        windowParams.emplace_back(exprConverter_->toVeloxExpr(arg.value(), inputType));
-      }
+    // we expect the first option name to be `ignoreNulls` when ignoreNulls is true.
+    bool ignoreNulls = false;
+    if (!options.empty() && options.at(0).name() == "ignoreNulls") {
+      ignoreNulls = true;
+    }
+    for (const auto& arg : argumentList) {
+      windowParams.emplace_back(exprConverter_->toVeloxExpr(arg.value(), inputType));
     }
     auto windowVeloxType = SubstraitParser::parseType(windowFunction.output_type());
     auto windowCall = std::make_shared<const core::CallTypedExpr>(windowVeloxType, std::move(windowParams), funcName);
@@ -864,7 +855,7 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
     windowColumnNames.push_back(windowFunction.column_name());
 
     windowNodeFunctions.push_back(
-        {std::move(windowCall), createWindowFrame(lowerBound, upperBound, type), ignoreNulls});
+        {std::move(windowCall), std::move(createWindowFrame(lowerBound, upperBound, type)), ignoreNulls});
   }
 
   // Construct partitionKeys
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.h b/cpp/velox/substrait/SubstraitToVeloxPlan.h
index adc3b5ec0..21a318b91 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlan.h
+++ b/cpp/velox/substrait/SubstraitToVeloxPlan.h
@@ -26,9 +26,6 @@
 namespace gluten {
 class ResultIterator;
 
-// Holds names of Spark OffsetWindowFunctions.
-static const std::unordered_set<std::string> kOffsetWindowFunctions = {"nth_value", "lag"};
-
 struct SplitInfo {
   /// Whether the split comes from arrow array stream node.
   bool isStream = false;
diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/expression/ExpressionBuilder.java b/gluten-core/src/main/java/io/glutenproject/substrait/expression/ExpressionBuilder.java
index 81194fedc..6197ded9e 100644
--- a/gluten-core/src/main/java/io/glutenproject/substrait/expression/ExpressionBuilder.java
+++ b/gluten-core/src/main/java/io/glutenproject/substrait/expression/ExpressionBuilder.java
@@ -266,7 +266,34 @@ public class ExpressionBuilder {
String upperBound,
String lowerBound,
String frameType) {
+ return makeWindowFunction(
+ functionId,
+ expressionNodes,
+ columnName,
+ outputTypeNode,
+ upperBound,
+ lowerBound,
+ frameType,
+ false);
+ }
+
+ public static WindowFunctionNode makeWindowFunction(
+ Integer functionId,
+ List<ExpressionNode> expressionNodes,
+ String columnName,
+ TypeNode outputTypeNode,
+ String upperBound,
+ String lowerBound,
+ String frameType,
+ boolean ignoreNulls) {
return new WindowFunctionNode(
-        functionId, expressionNodes, columnName, outputTypeNode, upperBound, lowerBound, frameType);
+ functionId,
+ expressionNodes,
+ columnName,
+ outputTypeNode,
+ upperBound,
+ lowerBound,
+ frameType,
+ ignoreNulls);
}
}
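The existing seven-argument overload is kept as a thin wrapper that supplies `ignoreNulls = false`, so callers that predate the flag compile unchanged. In Scala the same backward-compatible evolution is usually expressed with a default parameter; a minimal illustrative sketch (not Gluten's actual builder):

    // Illustrative node type, not Gluten's WindowFunctionNode.
    final case class WindowFnNode(
        functionId: Int,
        columnName: String,
        frameType: String,
        ignoreNulls: Boolean)

    object WindowFnBuilder {
      // The new parameter is appended with a default, so old call sites keep
      // working, mirroring the Java overload pair in ExpressionBuilder.
      def makeWindowFunction(
          functionId: Int,
          columnName: String,
          frameType: String,
          ignoreNulls: Boolean = false): WindowFnNode =
        WindowFnNode(functionId, columnName, frameType, ignoreNulls)

      def main(args: Array[String]): Unit = {
        println(makeWindowFunction(1, "w0", "ROWS"))                     // defaults to false
        println(makeWindowFunction(1, "w0", "ROWS", ignoreNulls = true)) // explicit opt-in
      }
    }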
diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/expression/WindowFunctionNode.java b/gluten-core/src/main/java/io/glutenproject/substrait/expression/WindowFunctionNode.java
index 3d4054196..8d69e31ad 100644
--- a/gluten-core/src/main/java/io/glutenproject/substrait/expression/WindowFunctionNode.java
+++ b/gluten-core/src/main/java/io/glutenproject/substrait/expression/WindowFunctionNode.java
@@ -20,6 +20,7 @@ import io.glutenproject.substrait.type.TypeNode;
import io.substrait.proto.Expression;
import io.substrait.proto.FunctionArgument;
+import io.substrait.proto.FunctionOption;
import io.substrait.proto.WindowType;
import java.io.Serializable;
@@ -39,6 +40,8 @@ public class WindowFunctionNode implements Serializable {
private final String frameType;
+ private final boolean ignoreNulls;
+
WindowFunctionNode(
Integer functionId,
List<ExpressionNode> expressionNodes,
@@ -46,7 +49,8 @@ public class WindowFunctionNode implements Serializable {
TypeNode outputTypeNode,
String upperBound,
String lowerBound,
- String frameType) {
+ String frameType,
+ boolean ignoreNulls) {
this.functionId = functionId;
this.expressionNodes.addAll(expressionNodes);
this.columnName = columnName;
@@ -54,6 +58,7 @@ public class WindowFunctionNode implements Serializable {
this.upperBound = upperBound;
this.lowerBound = lowerBound;
this.frameType = frameType;
+ this.ignoreNulls = ignoreNulls;
}
private Expression.WindowFunction.Bound.Builder setBound(
@@ -114,7 +119,10 @@ public class WindowFunctionNode implements Serializable {
   public Expression.WindowFunction toProtobuf() {
     Expression.WindowFunction.Builder windowBuilder = Expression.WindowFunction.newBuilder();
     windowBuilder.setFunctionReference(functionId);
-
+    if (ignoreNulls) {
+      FunctionOption option = FunctionOption.newBuilder().setName("ignoreNulls").build();
+      windowBuilder.addOptions(option);
+    }
     for (ExpressionNode expressionNode : expressionNodes) {
       FunctionArgument.Builder functionArgument = FunctionArgument.newBuilder();
       functionArgument.setValue(expressionNode.toProtobuf());
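Together with the C++ change in SubstraitToVeloxPlan.cc, this establishes a simple contract: when `ignoreNulls` is true, a single `FunctionOption` named `ignoreNulls` is attached to the serialized window function, and its absence means false. A minimal sketch of that round trip, using plain case classes as stand-ins for the Substrait protobuf types:

    // Stand-ins for the Substrait protobuf messages; illustrative only.
    final case class FunctionOption(name: String)
    final case class WindowFunction(functionReference: Int, options: List[FunctionOption])

    object IgnoreNullsOption {
      // Producer side (WindowFunctionNode.toProtobuf): attach the option only
      // when the flag is set.
      def encode(functionId: Int, ignoreNulls: Boolean): WindowFunction = {
        val opts = if (ignoreNulls) List(FunctionOption("ignoreNulls")) else Nil
        WindowFunction(functionId, opts)
      }

      // Consumer side (SubstraitToVeloxPlan.cc): the option's presence is the signal.
      def decode(fn: WindowFunction): Boolean =
        fn.options.headOption.exists(_.name == "ignoreNulls")

      def main(args: Array[String]): Unit = {
        assert(decode(encode(1, ignoreNulls = true)))
        assert(!decode(encode(1, ignoreNulls = false)))
        println("ignoreNulls option round-trips")
      }
    }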
diff --git a/gluten-core/src/main/scala/io/glutenproject/backendsapi/SparkPlanExecApi.scala b/gluten-core/src/main/scala/io/glutenproject/backendsapi/SparkPlanExecApi.scala
index f34e784b3..4379745cc 100644
--- a/gluten-core/src/main/scala/io/glutenproject/backendsapi/SparkPlanExecApi.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/backendsapi/SparkPlanExecApi.scala
@@ -43,7 +43,7 @@ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins.BuildSideRelation
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.hive.HiveTableScanExecTransformer
-import org.apache.spark.sql.types.{BooleanType, LongType, NullType, StructType}
+import org.apache.spark.sql.types.{LongType, NullType, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch
import java.lang.{Long => JLong}
@@ -508,7 +508,7 @@ trait SparkPlanExecApi {
frame.frameType.sql
)
windowExpressionNodes.add(windowFunctionNode)
- case wf @ (Lead(_, _, _, _) | Lag(_, _, _, _)) =>
+ case wf @ (_: Lead | _: Lag) =>
val offsetWf = wf.asInstanceOf[FrameLessOffsetWindowFunction]
val frame = offsetWf.frame.asInstanceOf[SpecifiedWindowFrame]
val childrenNodeList = new JArrayList[ExpressionNode]()
@@ -519,11 +519,26 @@ trait SparkPlanExecApi {
                 attributeSeq = originalInputAttributes)
               .doTransform(args))
           // Spark only accepts a foldable offset. Convert it to a LongType literal.
-          val offsetNode = ExpressionBuilder.makeLiteral(
-            // Velox always expects a positive offset.
-            Math.abs(offsetWf.offset.eval(EmptyRow).asInstanceOf[Int].toLong),
-            LongType,
-            false)
+          var offset = offsetWf.offset.eval(EmptyRow).asInstanceOf[Int]
+          if (wf.isInstanceOf[Lead]) {
+            if (offset < 0) {
+              // Velox always expects a non-negative offset.
+              throw new UnsupportedOperationException(
+                s"${wf.nodeName} does not support negative offset: $offset")
+            }
+          } else {
+            // For Lag, Spark stores `-inputOffset` as the offset, so a positive
+            // internal offset corresponds to a negative user-specified offset,
+            // which we forbid here.
+            if (offset > 0) {
+              // Velox always expects a non-negative offset.
+              throw new UnsupportedOperationException(
+                s"${wf.nodeName} does not support negative offset: ${-offset}")
+            }
+            // Undo Spark's negation to recover the original input offset.
+            offset = -offset
+          }
+          val offsetNode = ExpressionBuilder.makeLiteral(offset.toLong, LongType, false)
           childrenNodeList.add(offsetNode)
           // NullType means null is the default value. Don't pass it to native.
           if (offsetWf.default.dataType != NullType) {
@@ -534,9 +549,16 @@ trait SparkPlanExecApi {
                 attributeSeq = originalInputAttributes)
               .doTransform(args))
           }
-          // Always adds ignoreNulls at the end.
-          childrenNodeList.add(
-            ExpressionBuilder.makeLiteral(offsetWf.ignoreNulls, BooleanType, false))
+          val ignoreNulls = if (offset == 0) {
+            // This is a workaround for the Velox backend: Velox has a bug when
+            // ignoreNulls is true and the offset is 0.
+            // Logically, ignoreNulls is meaningless when the offset is 0, so
+            // this workaround is safe.
+            // TODO: remove this once Velox fixes it.
+            false
+          } else {
+            offsetWf.ignoreNulls
+          }
           val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
             WindowFunctionsBuilder.create(args, offsetWf).toInt,
             childrenNodeList,
@@ -544,7 +566,8 @@ trait SparkPlanExecApi {
ConverterUtils.getTypeNode(offsetWf.dataType, offsetWf.nullable),
WindowExecTransformer.getFrameBound(frame.upper),
WindowExecTransformer.getFrameBound(frame.lower),
- frame.frameType.sql
+ frame.frameType.sql,
+ ignoreNulls
)
windowExpressionNodes.add(windowFunctionNode)
case wf @ NthValue(input, offset: Literal, ignoreNulls: Boolean) =>
@@ -555,8 +578,6 @@ trait SparkPlanExecApi {
             .replaceWithExpressionTransformer(input, attributeSeq = originalInputAttributes)
             .doTransform(args))
           childrenNodeList.add(LiteralTransformer(offset).doTransform(args))
-          // Always adds ignoreNulls at the end.
-          childrenNodeList.add(ExpressionBuilder.makeLiteral(ignoreNulls, BooleanType, false))
           val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
             WindowFunctionsBuilder.create(args, wf).toInt,
             childrenNodeList,
@@ -564,7 +585,8 @@ trait SparkPlanExecApi {
ConverterUtils.getTypeNode(wf.dataType, wf.nullable),
frame.upper.sql,
frame.lower.sql,
- frame.frameType.sql
+ frame.frameType.sql,
+ ignoreNulls
)
windowExpressionNodes.add(windowFunctionNode)
case wf @ NTile(buckets: Expression) =>
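The offset handling above hinges on a Spark internal detail: `Lag` stores the negated user offset while `Lead` stores it unchanged, and Velox accepts only non-negative offsets for both functions. A self-contained sketch of that normalization (illustrative types, not Gluten's code):

    // Illustrative stand-ins: `internalOffset` models what Spark stores.
    sealed trait OffsetFunc { def internalOffset: Int }
    final case class Lead(internalOffset: Int) extends OffsetFunc
    final case class Lag(internalOffset: Int) extends OffsetFunc // Spark stores -userOffset

    object OffsetNormalization {
      // Returns the non-negative offset handed to Velox, or throws for the
      // negative user-specified offsets the backend rejects.
      def veloxOffset(wf: OffsetFunc): Long = wf match {
        case Lead(n) if n >= 0 => n.toLong
        case Lag(n) if n <= 0 => (-n).toLong // undo Spark's negation
        case other =>
          throw new UnsupportedOperationException(s"negative offset is unsupported: $other")
      }

      def main(args: Array[String]): Unit = {
        println(veloxOffset(Lead(2))) // 2
        println(veloxOffset(Lag(-3))) // 3, i.e. the user wrote lag(col, 3)
      }
    }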
diff --git a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenDataFrameWindowFunctionsSuite.scala b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenDataFrameWindowFunctionsSuite.scala
index f5d021fa5..42757cec9 100644
--- a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenDataFrameWindowFunctionsSuite.scala
+++ b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenDataFrameWindowFunctionsSuite.scala
@@ -16,6 +16,7 @@
*/
package org.apache.spark.sql
+import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -26,6 +27,12 @@ class GlutenDataFrameWindowFunctionsSuite
import testImplicits._
+ override def sparkConf: SparkConf = {
+ super.sparkConf
+ // avoid single partition
+ .set("spark.sql.shuffle.partitions", "2")
+ }
+
testGluten("covar_samp, var_samp (variance), stddev_samp (stddev) functions
in specific window") {
withSQLConf(SQLConf.LEGACY_STATISTICAL_AGGREGATE.key -> "true") {
val df = Seq(
diff --git a/gluten-ut/spark33/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark33/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala
index 0e7949356..97dd91ded 100644
--- a/gluten-ut/spark33/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala
+++ b/gluten-ut/spark33/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala
@@ -983,6 +983,16 @@ class VeloxTestSettings extends BackendTestSettings {
.excludeGlutenTest("describe")
enableSuite[GlutenDataFrameTimeWindowingSuite]
enableSuite[GlutenDataFrameTungstenSuite]
+  enableSuite[GlutenDataFrameWindowFunctionsSuite]
+    // does not support `spark.sql.legacy.statisticalAggregate=true` (null -> NaN)
+    .exclude("corr, covar_pop, stddev_pop functions in specific window")
+    .exclude("covar_samp, var_samp (variance), stddev_samp (stddev) functions in specific window")
+    // does not support spill
+    .exclude("Window spill with more than the inMemoryThreshold and spillThreshold")
+    .exclude("SPARK-21258: complex object in combination with spilling")
+    // rewrite `WindowExec -> WindowExecTransformer`
+    .exclude(
+      "SPARK-38237: require all cluster keys for child required distribution for window query")
enableSuite[GlutenDataFrameWindowFramesSuite]
// Local window fixes are not added.
.exclude("range between should accept int/long values as boundary")
diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameWindowFunctionsSuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameWindowFunctionsSuite.scala
index f5d021fa5..0e32f46da 100644
--- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameWindowFunctionsSuite.scala
+++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameWindowFunctionsSuite.scala
@@ -16,6 +16,13 @@
*/
package org.apache.spark.sql
+import io.glutenproject.execution.WindowExecTransformer
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
+import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
+import org.apache.spark.sql.execution.ColumnarShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.ENSURE_REQUIREMENTS
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -26,6 +33,12 @@ class GlutenDataFrameWindowFunctionsSuite
import testImplicits._
+ override def sparkConf: SparkConf = {
+ super.sparkConf
+ // avoid single partition
+ .set("spark.sql.shuffle.partitions", "2")
+ }
+
testGluten("covar_samp, var_samp (variance), stddev_samp (stddev) functions
in specific window") {
withSQLConf(SQLConf.LEGACY_STATISTICAL_AGGREGATE.key -> "true") {
val df = Seq(
@@ -151,4 +164,46 @@ class GlutenDataFrameWindowFunctionsSuite
)
}
}
+
+  testGluten(
+    "SPARK-38237: require all cluster keys for child required distribution for window query") {
+    def partitionExpressionsColumns(expressions: Seq[Expression]): Seq[String] = {
+      expressions.flatMap { case ref: AttributeReference => Some(ref.name) }
+    }
+
+    def isShuffleExecByRequirement(
+        plan: ColumnarShuffleExchangeExec,
+        desiredClusterColumns: Seq[String]): Boolean = plan match {
+      case ColumnarShuffleExchangeExec(op: HashPartitioning, _, ENSURE_REQUIREMENTS, _, _) =>
+        partitionExpressionsColumns(op.expressions) === desiredClusterColumns
+      case _ => false
+    }
+
+    val df = Seq(("a", 1, 1), ("a", 2, 2), ("b", 1, 3), ("b", 1, 4)).toDF("key1", "key2", "value")
+    val windowSpec = Window.partitionBy("key1", "key2").orderBy("value")
+
+    withSQLConf(
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
+      SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION.key -> "true") {
+
+      val windowed = df
+        // repartition by a subset of the window partitionBy keys, which satisfies ClusteredDistribution
+        .repartition($"key1")
+        .select(lead($"key1", 1).over(windowSpec), lead($"value", 1).over(windowSpec))
+
+      checkAnswer(windowed, Seq(Row("b", 4), Row(null, null), Row(null, null), Row(null, null)))
+
+      val shuffleByRequirement = windowed.queryExecution.executedPlan.exists {
+        case w: WindowExecTransformer =>
+          w.child.exists {
+            case s: ColumnarShuffleExchangeExec =>
+              isShuffleExecByRequirement(s, Seq("key1", "key2"))
+            case _ => false
+          }
+        case _ => false
+      }
+
+      assert(shuffleByRequirement, "Can't find desired shuffle node from the query plan")
+    }
+  }
 }
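The assertion at the end of this test walks the executed plan for a shuffle, inserted by `EnsureRequirements`, whose hash partitioning covers both window keys. The same traversal pattern can be written as an explicit recursion; a sketch over a toy plan tree (an illustrative `PlanNode`, not Spark's `SparkPlan`):

    // Toy physical plan tree; Spark's SparkPlan exposes `children` the same way.
    sealed trait PlanNode { def children: Seq[PlanNode] }
    final case class Shuffle(keys: Seq[String], children: Seq[PlanNode]) extends PlanNode
    final case class WindowNode(children: Seq[PlanNode]) extends PlanNode
    case object Scan extends PlanNode { val children: Seq[PlanNode] = Nil }

    object PlanSearch {
      // Depth-first search for any node satisfying `p`, like TreeNode.exists.
      def exists(plan: PlanNode)(p: PlanNode => Boolean): Boolean =
        p(plan) || plan.children.exists(c => exists(c)(p))

      def main(args: Array[String]): Unit = {
        val plan = WindowNode(Seq(Shuffle(Seq("key1", "key2"), Seq(Scan))))
        val found = exists(plan) {
          case Shuffle(keys, _) => keys == Seq("key1", "key2")
          case _ => false
        }
        assert(found)
        println("found a shuffle clustered by both window keys")
      }
    }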
diff --git a/gluten-ut/spark34/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark34/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala
index 45de795de..1fc6a3099 100644
--- a/gluten-ut/spark34/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala
+++ b/gluten-ut/spark34/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala
@@ -989,6 +989,16 @@ class VeloxTestSettings extends BackendTestSettings {
     .exclude("SPARK-41048: Improve output partitioning and ordering with AQE cache")
   enableSuite[GlutenDataFrameTimeWindowingSuite]
   enableSuite[GlutenDataFrameTungstenSuite]
+  enableSuite[GlutenDataFrameWindowFunctionsSuite]
+    // does not support `spark.sql.legacy.statisticalAggregate=true` (null -> NaN)
+    .exclude("corr, covar_pop, stddev_pop functions in specific window")
+    .exclude("covar_samp, var_samp (variance), stddev_samp (stddev) functions in specific window")
+    // does not support spill
+    .exclude("Window spill with more than the inMemoryThreshold and spillThreshold")
+    .exclude("SPARK-21258: complex object in combination with spilling")
+    // rewrite `WindowExec -> WindowExecTransformer`
+    .exclude(
+      "SPARK-38237: require all cluster keys for child required distribution for window query")
enableSuite[GlutenDataFrameWindowFramesSuite]
// Local window fixes are not added.
.exclude("range between should accept int/long values as boundary")
diff --git a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/GlutenDataFrameWindowFunctionsSuite.scala b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/GlutenDataFrameWindowFunctionsSuite.scala
index f5d021fa5..0e32f46da 100644
--- a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/GlutenDataFrameWindowFunctionsSuite.scala
+++ b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/GlutenDataFrameWindowFunctionsSuite.scala
@@ -16,6 +16,13 @@
*/
package org.apache.spark.sql
+import io.glutenproject.execution.WindowExecTransformer
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
+import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
+import org.apache.spark.sql.execution.ColumnarShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.ENSURE_REQUIREMENTS
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -26,6 +33,12 @@ class GlutenDataFrameWindowFunctionsSuite
import testImplicits._
+ override def sparkConf: SparkConf = {
+ super.sparkConf
+ // avoid single partition
+ .set("spark.sql.shuffle.partitions", "2")
+ }
+
testGluten("covar_samp, var_samp (variance), stddev_samp (stddev) functions
in specific window") {
withSQLConf(SQLConf.LEGACY_STATISTICAL_AGGREGATE.key -> "true") {
val df = Seq(
@@ -151,4 +164,46 @@ class GlutenDataFrameWindowFunctionsSuite
)
}
}
+
+  testGluten(
+    "SPARK-38237: require all cluster keys for child required distribution for window query") {
+    def partitionExpressionsColumns(expressions: Seq[Expression]): Seq[String] = {
+      expressions.flatMap { case ref: AttributeReference => Some(ref.name) }
+    }
+
+    def isShuffleExecByRequirement(
+        plan: ColumnarShuffleExchangeExec,
+        desiredClusterColumns: Seq[String]): Boolean = plan match {
+      case ColumnarShuffleExchangeExec(op: HashPartitioning, _, ENSURE_REQUIREMENTS, _, _) =>
+        partitionExpressionsColumns(op.expressions) === desiredClusterColumns
+      case _ => false
+    }
+
+    val df = Seq(("a", 1, 1), ("a", 2, 2), ("b", 1, 3), ("b", 1, 4)).toDF("key1", "key2", "value")
+    val windowSpec = Window.partitionBy("key1", "key2").orderBy("value")
+
+    withSQLConf(
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
+      SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION.key -> "true") {
+
+      val windowed = df
+        // repartition by a subset of the window partitionBy keys, which satisfies ClusteredDistribution
+        .repartition($"key1")
+        .select(lead($"key1", 1).over(windowSpec), lead($"value", 1).over(windowSpec))
+
+      checkAnswer(windowed, Seq(Row("b", 4), Row(null, null), Row(null, null), Row(null, null)))
+
+      val shuffleByRequirement = windowed.queryExecution.executedPlan.exists {
+        case w: WindowExecTransformer =>
+          w.child.exists {
+            case s: ColumnarShuffleExchangeExec =>
+              isShuffleExecByRequirement(s, Seq("key1", "key2"))
+            case _ => false
+          }
+        case _ => false
+      }
+
+      assert(shuffleByRequirement, "Can't find desired shuffle node from the query plan")
+    }
+  }
 }
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]