This is an automated email from the ASF dual-hosted git repository.

philo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 92ca76d22 [VL] Support lead window function (#4902)
92ca76d22 is described below

commit 92ca76d22c8e446f790701a2f252de60c6633273
Author: Xiduo You <[email protected]>
AuthorDate: Mon Mar 11 20:50:01 2024 +0800

    [VL] Support lead window function (#4902)
---
 .../backendsapi/velox/VeloxBackend.scala           |  4 +-
 .../io/glutenproject/execution/TestOperator.scala  |  6 +++
 cpp/velox/substrait/SubstraitToVeloxPlan.cc        | 27 ++++-------
 cpp/velox/substrait/SubstraitToVeloxPlan.h         |  3 --
 .../substrait/expression/ExpressionBuilder.java    | 29 +++++++++++-
 .../substrait/expression/WindowFunctionNode.java   | 12 ++++-
 .../backendsapi/SparkPlanExecApi.scala             | 50 ++++++++++++++------
 .../sql/GlutenDataFrameWindowFunctionsSuite.scala  |  7 +++
 .../utils/velox/VeloxTestSettings.scala            | 10 ++++
 .../sql/GlutenDataFrameWindowFunctionsSuite.scala  | 55 ++++++++++++++++++++++
 .../utils/velox/VeloxTestSettings.scala            | 10 ++++
 .../sql/GlutenDataFrameWindowFunctionsSuite.scala  | 55 ++++++++++++++++++++++
 12 files changed, 228 insertions(+), 40 deletions(-)

diff --git a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/VeloxBackend.scala b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/VeloxBackend.scala
index c316d50f0..bab5e68ec 100644
--- a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/VeloxBackend.scala
+++ b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/VeloxBackend.scala
@@ -26,7 +26,7 @@ import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat
 import 
io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat.{DwrfReadFormat, 
OrcReadFormat, ParquetReadFormat}
 
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
-import org.apache.spark.sql.catalyst.expressions.{Alias, CumeDist, DenseRank, 
Descending, Expression, Lag, Literal, NamedExpression, NthValue, NTile, 
PercentRank, Rand, RangeFrame, Rank, RowNumber, SortOrder, 
SpecialFrameBoundary, SpecifiedWindowFrame}
+import org.apache.spark.sql.catalyst.expressions.{Alias, CumeDist, DenseRank, 
Descending, Expression, Lag, Lead, Literal, NamedExpression, NthValue, NTile, 
PercentRank, Rand, RangeFrame, Rank, RowNumber, SortOrder, 
SpecialFrameBoundary, SpecifiedWindowFrame}
 import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, 
Count, Sum}
 import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
@@ -317,7 +317,7 @@ object BackendSettings extends BackendSettingsApi {
           }
           windowExpression.windowFunction match {
             case _: RowNumber | _: AggregateExpression | _: Rank | _: CumeDist 
| _: DenseRank |
-                _: PercentRank | _: NthValue | _: NTile | _: Lag =>
+                _: PercentRank | _: NthValue | _: NTile | _: Lag | _: Lead =>
             case _ =>
               allSupported = false
           }
diff --git a/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala b/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala
index f2b6fe8b3..961a0c3de 100644
--- a/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala
+++ b/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala
@@ -283,6 +283,12 @@ class TestOperator extends VeloxWholeStageTransformerSuite {
             assertWindowOffloaded
           }
 
+          runQueryAndCompare(
+            "select lead(l_orderkey) over" +
+              " (partition by l_suppkey order by l_orderkey) from lineitem ") {
+            assertWindowOffloaded
+          }
+
           // Test same partition/ordering keys.
           runQueryAndCompare(
             "select avg(l_partkey) over" +
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.cc b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
index 4f0bdfa25..100dcb9eb 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlan.cc
+++ b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
@@ -826,7 +826,6 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
 
   // Parse measures and get the window expressions.
   // Each measure represents one window expression.
-  bool ignoreNulls = false;
   std::vector<core::WindowNode::Function> windowNodeFunctions;
   std::vector<std::string> windowColumnNames;
 
@@ -837,23 +836,15 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
     std::vector<core::TypedExprPtr> windowParams;
     auto& argumentList = windowFunction.arguments();
     windowParams.reserve(argumentList.size());
+    const auto& options = windowFunction.options();
     // For functions in kOffsetWindowFunctions (see Spark 
OffsetWindowFunctions),
-    // we expect the last arg is passed for setting ignoreNulls.
-    if (kOffsetWindowFunctions.find(funcName) != kOffsetWindowFunctions.end()) 
{
-      int i = 0;
-      for (; i < argumentList.size() - 1; i++) {
-        
windowParams.emplace_back(exprConverter_->toVeloxExpr(argumentList[i].value(), 
inputType));
-      }
-      auto constantTypedExpr = 
exprConverter_->toVeloxExpr(argumentList[i].value().literal());
-      auto variant = constantTypedExpr->value();
-      if (!variant.hasValue()) {
-        VELOX_FAIL("Value is expected in variant for setting ignoreNulls.");
-      }
-      ignoreNulls = variant.value<bool>();
-    } else {
-      for (const auto& arg : argumentList) {
-        windowParams.emplace_back(exprConverter_->toVeloxExpr(arg.value(), 
inputType));
-      }
+    // we expect the first option name is `ignoreNulls` if ignoreNulls is true.
+    bool ignoreNulls = false;
+    if (!options.empty() && options.at(0).name() == "ignoreNulls") {
+      ignoreNulls = true;
+    }
+    for (const auto& arg : argumentList) {
+      windowParams.emplace_back(exprConverter_->toVeloxExpr(arg.value(), 
inputType));
     }
     auto windowVeloxType = 
SubstraitParser::parseType(windowFunction.output_type());
     auto windowCall = std::make_shared<const 
core::CallTypedExpr>(windowVeloxType, std::move(windowParams), funcName);
@@ -864,7 +855,7 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
     windowColumnNames.push_back(windowFunction.column_name());
 
     windowNodeFunctions.push_back(
-        {std::move(windowCall), createWindowFrame(lowerBound, upperBound, 
type), ignoreNulls});
+        {std::move(windowCall), std::move(createWindowFrame(lowerBound, 
upperBound, type)), ignoreNulls});
   }
 
   // Construct partitionKeys
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.h b/cpp/velox/substrait/SubstraitToVeloxPlan.h
index adc3b5ec0..21a318b91 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlan.h
+++ b/cpp/velox/substrait/SubstraitToVeloxPlan.h
@@ -26,9 +26,6 @@
 namespace gluten {
 class ResultIterator;
 
-// Holds names of Spark OffsetWindowFunctions.
-static const std::unordered_set<std::string> kOffsetWindowFunctions = 
{"nth_value", "lag"};
-
 struct SplitInfo {
   /// Whether the split comes from arrow array stream node.
   bool isStream = false;
diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/expression/ExpressionBuilder.java b/gluten-core/src/main/java/io/glutenproject/substrait/expression/ExpressionBuilder.java
index 81194fedc..6197ded9e 100644
--- a/gluten-core/src/main/java/io/glutenproject/substrait/expression/ExpressionBuilder.java
+++ b/gluten-core/src/main/java/io/glutenproject/substrait/expression/ExpressionBuilder.java
@@ -266,7 +266,34 @@ public class ExpressionBuilder {
       String upperBound,
       String lowerBound,
       String frameType) {
+    return makeWindowFunction(
+        functionId,
+        expressionNodes,
+        columnName,
+        outputTypeNode,
+        upperBound,
+        lowerBound,
+        frameType,
+        false);
+  }
+
+  public static WindowFunctionNode makeWindowFunction(
+      Integer functionId,
+      List<ExpressionNode> expressionNodes,
+      String columnName,
+      TypeNode outputTypeNode,
+      String upperBound,
+      String lowerBound,
+      String frameType,
+      boolean ignoreNulls) {
     return new WindowFunctionNode(
-        functionId, expressionNodes, columnName, outputTypeNode, upperBound, 
lowerBound, frameType);
+        functionId,
+        expressionNodes,
+        columnName,
+        outputTypeNode,
+        upperBound,
+        lowerBound,
+        frameType,
+        ignoreNulls);
   }
 }
diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/expression/WindowFunctionNode.java b/gluten-core/src/main/java/io/glutenproject/substrait/expression/WindowFunctionNode.java
index 3d4054196..8d69e31ad 100644
--- a/gluten-core/src/main/java/io/glutenproject/substrait/expression/WindowFunctionNode.java
+++ b/gluten-core/src/main/java/io/glutenproject/substrait/expression/WindowFunctionNode.java
@@ -20,6 +20,7 @@ import io.glutenproject.substrait.type.TypeNode;
 
 import io.substrait.proto.Expression;
 import io.substrait.proto.FunctionArgument;
+import io.substrait.proto.FunctionOption;
 import io.substrait.proto.WindowType;
 
 import java.io.Serializable;
@@ -39,6 +40,8 @@ public class WindowFunctionNode implements Serializable {
 
   private final String frameType;
 
+  private final boolean ignoreNulls;
+
   WindowFunctionNode(
       Integer functionId,
       List<ExpressionNode> expressionNodes,
@@ -46,7 +49,8 @@ public class WindowFunctionNode implements Serializable {
       TypeNode outputTypeNode,
       String upperBound,
       String lowerBound,
-      String frameType) {
+      String frameType,
+      boolean ignoreNulls) {
     this.functionId = functionId;
     this.expressionNodes.addAll(expressionNodes);
     this.columnName = columnName;
@@ -54,6 +58,7 @@ public class WindowFunctionNode implements Serializable {
     this.upperBound = upperBound;
     this.lowerBound = lowerBound;
     this.frameType = frameType;
+    this.ignoreNulls = ignoreNulls;
   }
 
   private Expression.WindowFunction.Bound.Builder setBound(
@@ -114,7 +119,10 @@ public class WindowFunctionNode implements Serializable {
   public Expression.WindowFunction toProtobuf() {
     Expression.WindowFunction.Builder windowBuilder = 
Expression.WindowFunction.newBuilder();
     windowBuilder.setFunctionReference(functionId);
-
+    if (ignoreNulls) {
+      FunctionOption option = 
FunctionOption.newBuilder().setName("ignoreNulls").build();
+      windowBuilder.addOptions(option);
+    }
     for (ExpressionNode expressionNode : expressionNodes) {
       FunctionArgument.Builder functionArgument = 
FunctionArgument.newBuilder();
       functionArgument.setValue(expressionNode.toProtobuf());
diff --git a/gluten-core/src/main/scala/io/glutenproject/backendsapi/SparkPlanExecApi.scala b/gluten-core/src/main/scala/io/glutenproject/backendsapi/SparkPlanExecApi.scala
index f34e784b3..4379745cc 100644
--- a/gluten-core/src/main/scala/io/glutenproject/backendsapi/SparkPlanExecApi.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/backendsapi/SparkPlanExecApi.scala
@@ -43,7 +43,7 @@ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.joins.BuildSideRelation
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.hive.HiveTableScanExecTransformer
-import org.apache.spark.sql.types.{BooleanType, LongType, NullType, StructType}
+import org.apache.spark.sql.types.{LongType, NullType, StructType}
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
 import java.lang.{Long => JLong}
@@ -508,7 +508,7 @@ trait SparkPlanExecApi {
               frame.frameType.sql
             )
             windowExpressionNodes.add(windowFunctionNode)
-          case wf @ (Lead(_, _, _, _) | Lag(_, _, _, _)) =>
+          case wf @ (_: Lead | _: Lag) =>
             val offsetWf = wf.asInstanceOf[FrameLessOffsetWindowFunction]
             val frame = offsetWf.frame.asInstanceOf[SpecifiedWindowFrame]
             val childrenNodeList = new JArrayList[ExpressionNode]()
@@ -519,11 +519,26 @@ trait SparkPlanExecApi {
                   attributeSeq = originalInputAttributes)
                 .doTransform(args))
             // Spark only accepts foldable offset. Converts it to LongType 
literal.
-            val offsetNode = ExpressionBuilder.makeLiteral(
-              // Velox always expects positive offset.
-              
Math.abs(offsetWf.offset.eval(EmptyRow).asInstanceOf[Int].toLong),
-              LongType,
-              false)
+            var offset = offsetWf.offset.eval(EmptyRow).asInstanceOf[Int]
+            if (wf.isInstanceOf[Lead]) {
+              if (offset < 0) {
+                // Velox always expects non-negative offset.
+                throw new UnsupportedOperationException(
+                  s"${wf.nodeName} does not support negative offset: $offset")
+              }
+            } else {
+              // For Lag
+              // Spark would use `-inputOffset` as offset, so here we forbid 
positive offset.
+              // Which means the inputOffset is negative.
+              if (offset > 0) {
+                // Velox always expects non-negative offset.
+                throw new UnsupportedOperationException(
+                  s"${wf.nodeName} does not support negative offset: $offset")
+              }
+              // Revert the Spark change and use the original input offset
+              offset = -offset
+            }
+            val offsetNode = ExpressionBuilder.makeLiteral(offset.toLong, 
LongType, false)
             childrenNodeList.add(offsetNode)
             // NullType means Null is the default value. Don't pass it to 
native.
             if (offsetWf.default.dataType != NullType) {
@@ -534,9 +549,16 @@ trait SparkPlanExecApi {
                     attributeSeq = originalInputAttributes)
                   .doTransform(args))
             }
-            // Always adds ignoreNulls at the end.
-            childrenNodeList.add(
-              ExpressionBuilder.makeLiteral(offsetWf.ignoreNulls, BooleanType, 
false))
+            val ignoreNulls = if (offset == 0) {
+              // This is a workaround for Velox backend, because velox has bug 
if the
+              // ignoreNulls is true and offset is 0.
+              // Logically, if offset is 0 the ignoreNulls is always 
meaningless, so
+              // this workaround is safe.
+              // TODO, remove this once Velox has fixed it
+              false
+            } else {
+              offsetWf.ignoreNulls
+            }
             val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
               WindowFunctionsBuilder.create(args, offsetWf).toInt,
               childrenNodeList,
@@ -544,7 +566,8 @@ trait SparkPlanExecApi {
               ConverterUtils.getTypeNode(offsetWf.dataType, offsetWf.nullable),
               WindowExecTransformer.getFrameBound(frame.upper),
               WindowExecTransformer.getFrameBound(frame.lower),
-              frame.frameType.sql
+              frame.frameType.sql,
+              ignoreNulls
             )
             windowExpressionNodes.add(windowFunctionNode)
           case wf @ NthValue(input, offset: Literal, ignoreNulls: Boolean) =>
@@ -555,8 +578,6 @@ trait SparkPlanExecApi {
                 .replaceWithExpressionTransformer(input, attributeSeq = 
originalInputAttributes)
                 .doTransform(args))
             childrenNodeList.add(LiteralTransformer(offset).doTransform(args))
-            // Always adds ignoreNulls at the end.
-            childrenNodeList.add(ExpressionBuilder.makeLiteral(ignoreNulls, 
BooleanType, false))
             val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
               WindowFunctionsBuilder.create(args, wf).toInt,
               childrenNodeList,
@@ -564,7 +585,8 @@ trait SparkPlanExecApi {
               ConverterUtils.getTypeNode(wf.dataType, wf.nullable),
               frame.upper.sql,
               frame.lower.sql,
-              frame.frameType.sql
+              frame.frameType.sql,
+              ignoreNulls
             )
             windowExpressionNodes.add(windowFunctionNode)
           case wf @ NTile(buckets: Expression) =>
diff --git a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenDataFrameWindowFunctionsSuite.scala b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenDataFrameWindowFunctionsSuite.scala
index f5d021fa5..42757cec9 100644
--- a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenDataFrameWindowFunctionsSuite.scala
+++ b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenDataFrameWindowFunctionsSuite.scala
@@ -16,6 +16,7 @@
  */
 package org.apache.spark.sql
 
+import org.apache.spark.SparkConf
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -26,6 +27,12 @@ class GlutenDataFrameWindowFunctionsSuite
 
   import testImplicits._
 
+  override def sparkConf: SparkConf = {
+    super.sparkConf
+      // avoid single partition
+      .set("spark.sql.shuffle.partitions", "2")
+  }
+
   testGluten("covar_samp, var_samp (variance), stddev_samp (stddev) functions 
in specific window") {
     withSQLConf(SQLConf.LEGACY_STATISTICAL_AGGREGATE.key -> "true") {
       val df = Seq(
diff --git a/gluten-ut/spark33/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark33/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala
index 0e7949356..97dd91ded 100644
--- a/gluten-ut/spark33/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala
+++ b/gluten-ut/spark33/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala
@@ -983,6 +983,16 @@ class VeloxTestSettings extends BackendTestSettings {
     .excludeGlutenTest("describe")
   enableSuite[GlutenDataFrameTimeWindowingSuite]
   enableSuite[GlutenDataFrameTungstenSuite]
+  enableSuite[GlutenDataFrameWindowFunctionsSuite]
+    // does not support `spark.sql.legacy.statisticalAggregate=true` (null -> 
NAN)
+    .exclude("corr, covar_pop, stddev_pop functions in specific window")
+    .exclude("covar_samp, var_samp (variance), stddev_samp (stddev) functions 
in specific window")
+    // does not support spill
+    .exclude("Window spill with more than the inMemoryThreshold and 
spillThreshold")
+    .exclude("SPARK-21258: complex object in combination with spilling")
+    // rewrite `WindowExec -> WindowExecTransformer`
+    .exclude(
+      "SPARK-38237: require all cluster keys for child required distribution 
for window query")
   enableSuite[GlutenDataFrameWindowFramesSuite]
     // Local window fixes are not added.
     .exclude("range between should accept int/long values as boundary")
diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameWindowFunctionsSuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameWindowFunctionsSuite.scala
index f5d021fa5..0e32f46da 100644
--- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameWindowFunctionsSuite.scala
+++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameWindowFunctionsSuite.scala
@@ -16,6 +16,13 @@
  */
 package org.apache.spark.sql
 
+import io.glutenproject.execution.WindowExecTransformer
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, 
Expression}
+import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
+import org.apache.spark.sql.execution.ColumnarShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.ENSURE_REQUIREMENTS
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -26,6 +33,12 @@ class GlutenDataFrameWindowFunctionsSuite
 
   import testImplicits._
 
+  override def sparkConf: SparkConf = {
+    super.sparkConf
+      // avoid single partition
+      .set("spark.sql.shuffle.partitions", "2")
+  }
+
   testGluten("covar_samp, var_samp (variance), stddev_samp (stddev) functions 
in specific window") {
     withSQLConf(SQLConf.LEGACY_STATISTICAL_AGGREGATE.key -> "true") {
       val df = Seq(
@@ -151,4 +164,46 @@ class GlutenDataFrameWindowFunctionsSuite
       )
     }
   }
+
+  testGluten(
+    "SPARK-38237: require all cluster keys for child required distribution for 
window query") {
+    def partitionExpressionsColumns(expressions: Seq[Expression]): Seq[String] 
= {
+      expressions.flatMap { case ref: AttributeReference => Some(ref.name) }
+    }
+
+    def isShuffleExecByRequirement(
+        plan: ColumnarShuffleExchangeExec,
+        desiredClusterColumns: Seq[String]): Boolean = plan match {
+      case ColumnarShuffleExchangeExec(op: HashPartitioning, _, 
ENSURE_REQUIREMENTS, _, _) =>
+        partitionExpressionsColumns(op.expressions) === desiredClusterColumns
+      case _ => false
+    }
+
+    val df = Seq(("a", 1, 1), ("a", 2, 2), ("b", 1, 3), ("b", 1, 
4)).toDF("key1", "key2", "value")
+    val windowSpec = Window.partitionBy("key1", "key2").orderBy("value")
+
+    withSQLConf(
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
+      SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION.key -> "true") {
+
+      val windowed = df
+        // repartition by subset of window partitionBy keys which satisfies 
ClusteredDistribution
+        .repartition($"key1")
+        .select(lead($"key1", 1).over(windowSpec), lead($"value", 
1).over(windowSpec))
+
+      checkAnswer(windowed, Seq(Row("b", 4), Row(null, null), Row(null, null), 
Row(null, null)))
+
+      val shuffleByRequirement = windowed.queryExecution.executedPlan.exists {
+        case w: WindowExecTransformer =>
+          w.child.exists {
+            case s: ColumnarShuffleExchangeExec =>
+              isShuffleExecByRequirement(s, Seq("key1", "key2"))
+            case _ => false
+          }
+        case _ => false
+      }
+
+      assert(shuffleByRequirement, "Can't find desired shuffle node from the 
query plan")
+    }
+  }
 }
diff --git a/gluten-ut/spark34/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark34/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala
index 45de795de..1fc6a3099 100644
--- a/gluten-ut/spark34/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala
+++ b/gluten-ut/spark34/src/test/scala/io/glutenproject/utils/velox/VeloxTestSettings.scala
@@ -989,6 +989,16 @@ class VeloxTestSettings extends BackendTestSettings {
     .exclude("SPARK-41048: Improve output partitioning and ordering with AQE 
cache")
   enableSuite[GlutenDataFrameTimeWindowingSuite]
   enableSuite[GlutenDataFrameTungstenSuite]
+  enableSuite[GlutenDataFrameWindowFunctionsSuite]
+    // does not support `spark.sql.legacy.statisticalAggregate=true` (null -> 
NAN)
+    .exclude("corr, covar_pop, stddev_pop functions in specific window")
+    .exclude("covar_samp, var_samp (variance), stddev_samp (stddev) functions 
in specific window")
+    // does not support spill
+    .exclude("Window spill with more than the inMemoryThreshold and 
spillThreshold")
+    .exclude("SPARK-21258: complex object in combination with spilling")
+    // rewrite `WindowExec -> WindowExecTransformer`
+    .exclude(
+      "SPARK-38237: require all cluster keys for child required distribution 
for window query")
   enableSuite[GlutenDataFrameWindowFramesSuite]
     // Local window fixes are not added.
     .exclude("range between should accept int/long values as boundary")
diff --git a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/GlutenDataFrameWindowFunctionsSuite.scala b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/GlutenDataFrameWindowFunctionsSuite.scala
index f5d021fa5..0e32f46da 100644
--- a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/GlutenDataFrameWindowFunctionsSuite.scala
+++ b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/GlutenDataFrameWindowFunctionsSuite.scala
@@ -16,6 +16,13 @@
  */
 package org.apache.spark.sql
 
+import io.glutenproject.execution.WindowExecTransformer
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, 
Expression}
+import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
+import org.apache.spark.sql.execution.ColumnarShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.ENSURE_REQUIREMENTS
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -26,6 +33,12 @@ class GlutenDataFrameWindowFunctionsSuite
 
   import testImplicits._
 
+  override def sparkConf: SparkConf = {
+    super.sparkConf
+      // avoid single partition
+      .set("spark.sql.shuffle.partitions", "2")
+  }
+
   testGluten("covar_samp, var_samp (variance), stddev_samp (stddev) functions 
in specific window") {
     withSQLConf(SQLConf.LEGACY_STATISTICAL_AGGREGATE.key -> "true") {
       val df = Seq(
@@ -151,4 +164,46 @@ class GlutenDataFrameWindowFunctionsSuite
       )
     }
   }
+
+  testGluten(
+    "SPARK-38237: require all cluster keys for child required distribution for 
window query") {
+    def partitionExpressionsColumns(expressions: Seq[Expression]): Seq[String] 
= {
+      expressions.flatMap { case ref: AttributeReference => Some(ref.name) }
+    }
+
+    def isShuffleExecByRequirement(
+        plan: ColumnarShuffleExchangeExec,
+        desiredClusterColumns: Seq[String]): Boolean = plan match {
+      case ColumnarShuffleExchangeExec(op: HashPartitioning, _, 
ENSURE_REQUIREMENTS, _, _) =>
+        partitionExpressionsColumns(op.expressions) === desiredClusterColumns
+      case _ => false
+    }
+
+    val df = Seq(("a", 1, 1), ("a", 2, 2), ("b", 1, 3), ("b", 1, 
4)).toDF("key1", "key2", "value")
+    val windowSpec = Window.partitionBy("key1", "key2").orderBy("value")
+
+    withSQLConf(
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
+      SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION.key -> "true") {
+
+      val windowed = df
+        // repartition by subset of window partitionBy keys which satisfies 
ClusteredDistribution
+        .repartition($"key1")
+        .select(lead($"key1", 1).over(windowSpec), lead($"value", 
1).over(windowSpec))
+
+      checkAnswer(windowed, Seq(Row("b", 4), Row(null, null), Row(null, null), 
Row(null, null)))
+
+      val shuffleByRequirement = windowed.queryExecution.executedPlan.exists {
+        case w: WindowExecTransformer =>
+          w.child.exists {
+            case s: ColumnarShuffleExchangeExec =>
+              isShuffleExecByRequirement(s, Seq("key1", "key2"))
+            case _ => false
+          }
+        case _ => false
+      }
+
+      assert(shuffleByRequirement, "Can't find desired shuffle node from the 
query plan")
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to