This is an automated email from the ASF dual-hosted git repository.

marin-ma pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new a2cdb98ea8 [VL] Support native scala udaf in window (#12117)
a2cdb98ea8 is described below

commit a2cdb98ea85f05563df478cdf12fd60fd046c438
Author: Rong Ma <[email protected]>
AuthorDate: Thu Jun 4 11:09:50 2026 +0100

    [VL] Support native scala udaf in window (#12117)
---
 .../backendsapi/velox/VeloxSparkPlanExecApi.scala  | 145 +++++++++++++++++++-
 .../apache/spark/sql/expression/UDFResolver.scala  |  45 ++++---
 .../apache/gluten/expression/VeloxUdfSuite.scala   | 146 +++++++++++++++------
 cpp/velox/substrait/SubstraitToVeloxPlan.cc        |   3 +-
 cpp/velox/udf/examples/UdfCommon.h                 |   4 +-
 docs/developers/VeloxUDF.md                        |   1 +
 .../gluten/backendsapi/SparkPlanExecApi.scala      | 137 +------------------
 7 files changed, 282 insertions(+), 199 deletions(-)

diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index ef25171e6d..be21337d99 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -26,6 +26,8 @@ import org.apache.gluten.extension.JoinKeysTag
 import org.apache.gluten.extension.columnar.FallbackTags
 import org.apache.gluten.shuffle.NeedCustomColumnarBatchSerializer
 import org.apache.gluten.sql.shims.SparkShimLoader
+import org.apache.gluten.substrait.SubstraitContext
+import org.apache.gluten.substrait.expression.{ExpressionBuilder, 
ExpressionNode, WindowFunctionNode}
 import org.apache.gluten.vectorized.{ColumnarBatchSerializer, 
ColumnarBatchSerializeResult}
 
 import org.apache.spark.{ShuffleDependency, SparkEnv, SparkException}
@@ -53,7 +55,8 @@ import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
 import org.apache.spark.sql.execution.unsafe.UnsafeColumnarBuildSideRelation
 import org.apache.spark.sql.execution.utils.ExecUtil
-import org.apache.spark.sql.expression.{UDFExpression, 
UserDefinedAggregateFunction}
+import org.apache.spark.sql.expression.{UDFExpression, UDFResolver, 
UserDefinedAggregateFunction}
+import org.apache.spark.sql.hive.HiveUDAFInspector
 import org.apache.spark.sql.hive.VeloxHiveUDFTransformer
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -64,6 +67,7 @@ import org.apache.commons.lang3.ClassUtils
 
 import javax.ws.rs.core.UriBuilder
 
+import java.util.{ArrayList => JArrayList, List => JList}
 import java.util.Locale
 
 import scala.collection.JavaConverters._
@@ -1271,4 +1275,157 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi with Logging {
           RaiseErrorRestrictions.ONLY_SUPPORT_ERROR_MESSAGE)
     }
   }
+
+  /**
+   * Translates each aliased window expression into a Substrait [[WindowFunctionNode]] and
+   * appends it to `windowExpressionNodes`. An aggregate function without a built-in Substrait
+   * mapping is still offloaded when it is a Hive UDAF whose class is registered in
+   * `UDFResolver.UDAFNames`; unsupported window functions raise [[GlutenNotSupportException]].
+   */
+  override def genWindowFunctionsNode(
+      windowExpression: Seq[NamedExpression],
+      windowExpressionNodes: JList[WindowFunctionNode],
+      originalInputAttributes: Seq[Attribute],
+      context: SubstraitContext): Unit = {
+    windowExpression.foreach {
+      windowExpr =>
+        val aliasExpr = windowExpr.asInstanceOf[Alias]
+        val columnName = s"${aliasExpr.name}_${aliasExpr.exprId.id}"
+        val wExpression = aliasExpr.child.asInstanceOf[WindowExpression]
+        wExpression.windowFunction match {
+          case wf @ (RowNumber() | Rank(_) | DenseRank(_) | CumeDist() | PercentRank(_)) =>
+            val aggWindowFunc = wf.asInstanceOf[AggregateWindowFunction]
+            val frame = aggWindowFunc.frame.asInstanceOf[SpecifiedWindowFrame]
+            val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
+              WindowFunctionsBuilder.create(context, aggWindowFunc).toInt,
+              new JArrayList[ExpressionNode](),
+              columnName,
+              ConverterUtils.getTypeNode(aggWindowFunc.dataType, aggWindowFunc.nullable),
+              frame.upper,
+              frame.lower,
+              frame.frameType.sql,
+              originalInputAttributes.asJava
+            )
+            windowExpressionNodes.add(windowFunctionNode)
+          case aggExpression: AggregateExpression =>
+            val frame = wExpression.windowSpec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
+            val originalAggFunc = aggExpression.aggregateFunction
+            // getSubstraitFunctionName is only used as a support probe (its result is
+            // discarded). If it throws GlutenNotSupportException, substitute the registered
+            // native Hive UDAF for this class, or rethrow when there is none.
+            val aggregateFunc =
+              try {
+                AggregateFunctionsBuilder.getSubstraitFunctionName(originalAggFunc)
+                originalAggFunc
+              } catch {
+                case e: GlutenNotSupportException =>
+                  HiveUDAFInspector.getUDAFClassName(originalAggFunc) match {
+                    case Some(udafClass) if UDFResolver.UDAFNames.contains(udafClass) =>
+                      UDFResolver.getUdafExpression(udafClass)(originalAggFunc.children)
+                    case _ => throw e
+                  }
+              }
+
+            val childrenNodeList = aggregateFunc.children
+              .map(
+                ExpressionConverter
+                  .replaceWithExpressionTransformer(_, originalInputAttributes)
+                  .doTransform(context))
+              .asJava
+
+            val functionId = VeloxAggregateFunctionsBuilder
+              .create(context, aggregateFunc, aggExpression.mode)
+              .toInt
+            val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
+              functionId,
+              childrenNodeList,
+              columnName,
+              ConverterUtils.getTypeNode(aggExpression.dataType, aggExpression.nullable),
+              frame.upper,
+              frame.lower,
+              frame.frameType.sql,
+              originalInputAttributes.asJava
+            )
+            windowExpressionNodes.add(windowFunctionNode)
+          case wf @ (_: Lead | _: Lag) =>
+            val offsetWf = wf.asInstanceOf[FrameLessOffsetWindowFunction]
+            val frame = offsetWf.frame.asInstanceOf[SpecifiedWindowFrame]
+            val childrenNodeList = new JArrayList[ExpressionNode]()
+            childrenNodeList.add(
+              ExpressionConverter
+                .replaceWithExpressionTransformer(
+                  offsetWf.input,
+                  attributeSeq = originalInputAttributes)
+                .doTransform(context))
+            // Spark only accepts foldable offset. Converts it to LongType literal.
+            val offset = offsetWf.offset.eval(EmptyRow).asInstanceOf[Int]
+            // Velox does not accept negative offsets. WindowFunctionsBuilder#create converts
+            // lag/lead with a negative offset to the function with a positive offset, so
+            // offsetNode always stores the absolute value.
+            val offsetNode = ExpressionBuilder.makeLiteral(Math.abs(offset.toLong), LongType, false)
+            childrenNodeList.add(offsetNode)
+            // NullType means Null is the default value. Don't pass it to native.
+            if (offsetWf.default.dataType != NullType) {
+              childrenNodeList.add(
+                ExpressionConverter
+                  .replaceWithExpressionTransformer(
+                    offsetWf.default,
+                    attributeSeq = originalInputAttributes)
+                  .doTransform(context))
+            }
+            val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
+              WindowFunctionsBuilder.create(context, offsetWf).toInt,
+              childrenNodeList,
+              columnName,
+              ConverterUtils.getTypeNode(offsetWf.dataType, offsetWf.nullable),
+              frame.upper,
+              frame.lower,
+              frame.frameType.sql,
+              offsetWf.ignoreNulls,
+              originalInputAttributes.asJava
+            )
+            windowExpressionNodes.add(windowFunctionNode)
+          case wf @ NthValue(input, offset: Literal, ignoreNulls: Boolean) =>
+            val frame = wExpression.windowSpec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
+            val childrenNodeList = new JArrayList[ExpressionNode]()
+            childrenNodeList.add(
+              ExpressionConverter
+                .replaceWithExpressionTransformer(input, attributeSeq = originalInputAttributes)
+                .doTransform(context))
+            childrenNodeList.add(LiteralTransformer(offset).doTransform(context))
+            val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
+              WindowFunctionsBuilder.create(context, wf).toInt,
+              childrenNodeList,
+              columnName,
+              ConverterUtils.getTypeNode(wf.dataType, wf.nullable),
+              frame.upper,
+              frame.lower,
+              frame.frameType.sql,
+              ignoreNulls,
+              originalInputAttributes.asJava
+            )
+            windowExpressionNodes.add(windowFunctionNode)
+          case wf @ NTile(buckets: Expression) =>
+            val frame = wExpression.windowSpec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
+            val childrenNodeList = new JArrayList[ExpressionNode]()
+            val literal = buckets.asInstanceOf[Literal]
+            childrenNodeList.add(LiteralTransformer(literal).doTransform(context))
+            val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
+              WindowFunctionsBuilder.create(context, wf).toInt,
+              childrenNodeList,
+              columnName,
+              ConverterUtils.getTypeNode(wf.dataType, wf.nullable),
+              frame.upper,
+              frame.lower,
+              frame.frameType.sql,
+              originalInputAttributes.asJava
+            )
+            windowExpressionNodes.add(windowFunctionNode)
+          case _ =>
+            throw new GlutenNotSupportException(
+              "unsupported window function type: " +
+                wExpression.windowFunction)
+        }
+    }
+  }
 }
diff --git 
a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
 
b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
index d2405f9e93..43ee7b4f7d 100644
--- 
a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
+++ 
b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
@@ -344,15 +344,15 @@ object UDFResolver extends Logging {
 
     val allowTypeConversion = checkAllowTypeConversion
     val signatures =
-      UDFMap.getOrElse(name, throw new GlutenNotSupportException(errorMessage))
-    signatures.find(sig => tryBind(sig, children.map(_.dataType), 
allowTypeConversion)) match {
-      case Some(sig) =>
+      UDFMap.getOrElse(name, throw new 
GlutenNotSupportException(errorMessage)).toSeq
+    tryBind(signatures, children.map(_.dataType), allowTypeConversion) match {
+      case Some((sig, withTypeConversion)) =>
         UDFExpression(
           name,
           alias,
           sig.expressionType.dataType,
           sig.expressionType.nullable,
-          if (!allowTypeConversion && !sig.allowTypeConversion) children
+          if (!withTypeConversion) children
           else applyCast(children, sig)
         )
       case None =>
@@ -366,17 +366,15 @@ object UDFResolver extends Logging {
 
     val allowTypeConversion = checkAllowTypeConversion
     val signatures =
-      UDAFMap.getOrElse(
-        name,
-        throw new GlutenNotSupportException(errorMessage)
-      )
-    signatures.find(sig => tryBind(sig, children.map(_.dataType), 
allowTypeConversion)) match {
-      case Some(sig) =>
+      UDAFMap.getOrElse(name, throw new 
GlutenNotSupportException(errorMessage)).toSeq
+
+    tryBind(signatures, children.map(_.dataType), allowTypeConversion) match {
+      case Some((sig, withTypeConversion)) =>
         UserDefinedAggregateFunction(
           name,
           sig.expressionType.dataType,
           sig.expressionType.nullable,
-          if (!allowTypeConversion && !sig.allowTypeConversion) children
+          if (!withTypeConversion) children
           else applyCast(children, sig),
           sig.intermediateAttrs
         )
@@ -385,16 +383,27 @@ object UDFResolver extends Logging {
     }
   }
 
-  private def tryBind(
-      sig: UDFSignatureBase,
+  /**
+   * Selects the signature that arguments of `requiredDataTypes` bind to.
+   *
+   * A strict (exact-type) match against any signature takes precedence. Only when none matches
+   * strictly are implicit casts considered: against every signature if `allowTypeConversion` is
+   * set, otherwise only against signatures whose own `allowTypeConversion` flag is set.
+   *
+   * @return the selected signature paired with `true` when the arguments must be cast to it, or
+   *         `None` when no signature binds.
+   */
+  private def tryBind[U <: UDFSignatureBase](
+      signatures: Seq[U],
       requiredDataTypes: Seq[DataType],
-      allowTypeConversion): Boolean = {
-    if (
-      !tryBindStrict(sig, requiredDataTypes) && (allowTypeConversion || sig.allowTypeConversion)
-    ) {
-      tryBindWithTypeConversion(sig, requiredDataTypes)
-    } else {
-      true
+      allowTypeConversion: Boolean): Option[(U, Boolean)] = {
+    val strictMatch = signatures.find(sig => tryBindStrict(sig, requiredDataTypes))
+    strictMatch.map(sig => (sig, false)).orElse {
+      val castableSignatures =
+        if (allowTypeConversion) signatures else signatures.filter(_.allowTypeConversion)
+      castableSignatures
+        .find(sig => tryBindWithTypeConversion(sig, requiredDataTypes))
+        .map(sig => (sig, true))
     }
   }
 
diff --git 
a/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala
 
b/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala
index 7eb61144a9..a5128f62d9 100644
--- 
a/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala
+++ 
b/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala
@@ -18,12 +18,14 @@ package org.apache.gluten.expression
 
 import org.apache.gluten.backendsapi.velox.VeloxBackendSettings
 import org.apache.gluten.execution.ProjectExecTransformer
+import org.apache.gluten.execution.WindowExecTransformer
 import org.apache.gluten.tags.{SkipTest, UDFTest}
 
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.{GlutenQueryTest, Row, SparkSession}
 import org.apache.spark.sql.catalyst.plans.SQLHelper
 import org.apache.spark.sql.execution.ProjectExec
+import org.apache.spark.sql.execution.window.WindowExec
 import org.apache.spark.sql.expression.UDFResolver
 
 import java.nio.file.Paths
@@ -92,16 +94,75 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with SQLHelper {
       .set("spark.memory.offHeap.enabled", "true")
       .set("spark.memory.offHeap.size", "1024MB")
       .set("spark.ui.enabled", "false")
+      .set("spark.sql.adaptive.enabled", "false")
   }
 
-  // Aggregate result can be flaky.
-  ignore("test native hive udaf") {
+  test("test native hive udf") {
+    val tbl = "test_hive_udf_replacement"
+    withTempPath {
+      dir =>
+        try {
+          spark.sql(s"""
+                       |CREATE EXTERNAL TABLE $tbl
+                       |LOCATION 'file://$dir'
+                       |AS select * from values (1, '1'), (2, '2'), (3, '3')
+                       |""".stripMargin)
+
+          // Check native hive udf has been registered.
+          assert(
+            UDFResolver.UDFNames.contains("org.apache.spark.sql.hive.execution.UDFStringString"))
+
+          spark.sql("""
+                      |CREATE TEMPORARY FUNCTION hive_string_string
+                      |AS 'org.apache.spark.sql.hive.execution.UDFStringString'
+                      |""".stripMargin)
+
+          val offloadWithImplicitConversionDF =
+            spark.sql(s"""SELECT hive_string_string(col1, 'a') FROM $tbl""")
+          checkGlutenPlan[ProjectExecTransformer](offloadWithImplicitConversionDF)
+          val offloadWithImplicitConversionResult = offloadWithImplicitConversionDF.collect()
+
+          val offloadDF =
+            spark.sql(s"""SELECT hive_string_string(col2, 'a') FROM $tbl""")
+          checkGlutenPlan[ProjectExecTransformer](offloadDF)
+          val offloadResult = offloadDF.collect()
+
+          // Unregister native hive udf to fallback.
+          UDFResolver.UDFNames.remove("org.apache.spark.sql.hive.execution.UDFStringString")
+          val fallbackDF =
+            spark.sql(s"""SELECT hive_string_string(col2, 'a') FROM $tbl""")
+          checkSparkPlan[ProjectExec](fallbackDF)
+          val fallbackResult = fallbackDF.collect()
+          assert(offloadWithImplicitConversionResult.sameElements(fallbackResult))
+          assert(offloadResult.sameElements(fallbackResult))
+
+          // Add an unimplemented udf to the map to test fallback of registered native hive udf.
+          UDFResolver.UDFNames.add("org.apache.spark.sql.hive.execution.UDFIntegerToString")
+          spark.sql("""
+                      |CREATE TEMPORARY FUNCTION hive_int_to_string
+                      |AS 'org.apache.spark.sql.hive.execution.UDFIntegerToString'
+                      |""".stripMargin)
+          val df = spark.sql(s"""select hive_int_to_string(col1) from $tbl""")
+          checkSparkPlan[ProjectExec](df)
+          checkAnswer(df, Seq(Row("1"), Row("2"), Row("3")))
+        } finally {
+          // Restore the shared UDFNames registry mutated above so other tests are unaffected.
+          UDFResolver.UDFNames.add("org.apache.spark.sql.hive.execution.UDFStringString")
+          UDFResolver.UDFNames.remove("org.apache.spark.sql.hive.execution.UDFIntegerToString")
+          spark.sql(s"DROP TABLE IF EXISTS $tbl")
+          spark.sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_string_string")
+          spark.sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_int_to_string")
+        }
+    }
+  }
+
+  test("test native hive udaf") {
     val tbl = "test_hive_udaf_replacement"
+    val udafClass = "test.org.apache.spark.sql.MyDoubleAvg"
     withTempPath {
       dir =>
         try {
           // Check native hive udaf has been registered.
-          val udafClass = "test.org.apache.spark.sql.MyDoubleAvg"
           assert(UDFResolver.UDAFNames.contains(udafClass))
 
           spark.sql(s"""
@@ -136,64 +194,69 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with SQLHelper {
           assert(nativeResult.sameElements(fallbackResult))
           assert(nativeImplicitConversionResult.sameElements(fallbackResult))
         } finally {
+          UDFResolver.UDAFNames.add(udafClass)
           spark.sql(s"DROP TABLE IF EXISTS $tbl")
           spark.sql(s"DROP TEMPORARY FUNCTION IF EXISTS my_double_avg")
         }
     }
   }
 
-  test("test native hive udf") {
-    val tbl = "test_hive_udf_replacement"
+  test("test native hive udaf in window") {
+    val tbl = "test_hive_udaf_window"
+    val udafClass = "test.org.apache.spark.sql.MyDoubleAvg"
+    val query =
+      s"""SELECT
+         |  col1,
+         |  my_double_avg(col1) OVER (
+         |    PARTITION BY col1 % 2
+         |    ORDER BY col1
+         |    ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS my_avg_window
+         |FROM $tbl
+         |ORDER BY col1
+         |""".stripMargin
+
     withTempPath {
       dir =>
         try {
+          assert(UDFResolver.UDAFNames.contains(udafClass))
+
+          spark.sql(s"""
+                       |CREATE TEMPORARY FUNCTION my_double_avg
+                       |AS '$udafClass'
+                       |""".stripMargin)
+          spark.sql(s"DROP TABLE IF EXISTS $tbl")
           spark.sql(s"""
                        |CREATE EXTERNAL TABLE $tbl
                        |LOCATION 'file://$dir'
-                       |AS select * from values (1, '1'), (2, '2'), (3, '3')
+                       |AS SELECT CAST(v AS FLOAT) AS col1
+                       |FROM VALUES (1.0), (2.0), (3.0), (4.0) AS t(v)
                        |""".stripMargin)
 
-          // Check native hive udf has been registered.
-          assert(
-            UDFResolver.UDFNames.contains("org.apache.spark.sql.hive.execution.UDFStringString"))
-
-          spark.sql("""
-                      |CREATE TEMPORARY FUNCTION hive_string_string
-                      |AS 'org.apache.spark.sql.hive.execution.UDFStringString'
-                      |""".stripMargin)
-
-          val offloadWithImplicitConversionDF =
-            spark.sql(s"""SELECT hive_string_string(col1, 'a') FROM $tbl""")
-          checkGlutenPlan[ProjectExecTransformer](offloadWithImplicitConversionDF)
-          val offloadWithImplicitConversionResult = offloadWithImplicitConversionDF.collect()
+          val offloadDF = spark.sql(query)
+          checkGlutenPlan[WindowExecTransformer](offloadDF)
+          // Running averages per `col1 % 2` partition are 1, 2 (odd) and 2, 3 (even); the
+          // expected values are those plus 100, presumably added by MyDoubleAvg itself.
+          checkAnswer(
+            offloadDF,
+            Seq(
+              Row(1.0f, 101.0),
+              Row(2.0f, 102.0),
+              Row(3.0f, 102.0),
+              Row(4.0f, 103.0)
+            ))
+          val offloadResult = offloadDF.collect()
 
-          val offloadDF =
-            spark.sql(s"""SELECT hive_string_string(col2, 'a') FROM $tbl""")
-          checkGlutenPlan[ProjectExecTransformer](offloadDF)
-          val offloadResult = offloadWithImplicitConversionDF.collect()
-
-          // Unregister native hive udf to fallback.
-          UDFResolver.UDFNames.remove("org.apache.spark.sql.hive.execution.UDFStringString")
-          val fallbackDF =
-            spark.sql(s"""SELECT hive_string_string(col2, 'a') FROM $tbl""")
-          checkSparkPlan[ProjectExec](fallbackDF)
+          // Unregister the native UDAF so the same query falls back to vanilla WindowExec.
+          UDFResolver.UDAFNames.remove(udafClass)
+          val fallbackDF = spark.sql(query)
+          checkSparkPlan[WindowExec](fallbackDF)
           val fallbackResult = fallbackDF.collect()
-          assert(offloadWithImplicitConversionResult.sameElements(fallbackResult))
-          assert(offloadResult.sameElements(fallbackResult))
 
-          // Add an unimplemented udf to the map to test fallback of registered native hive udf.
-          UDFResolver.UDFNames.add("org.apache.spark.sql.hive.execution.UDFIntegerToString")
-          spark.sql("""
-                      |CREATE TEMPORARY FUNCTION hive_int_to_string
-                      |AS 'org.apache.spark.sql.hive.execution.UDFIntegerToString'
-                      |""".stripMargin)
-          val df = spark.sql(s"""select hive_int_to_string(col1) from $tbl""")
-          checkSparkPlan[ProjectExec](df)
-          checkAnswer(df, Seq(Row("1"), Row("2"), Row("3")))
+          assert(offloadResult.sameElements(fallbackResult))
         } finally {
+          UDFResolver.UDAFNames.add(udafClass)
           spark.sql(s"DROP TABLE IF EXISTS $tbl")
-          spark.sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_string_string")
-          spark.sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_int_to_string")
+          spark.sql(s"DROP TEMPORARY FUNCTION IF EXISTS my_double_avg")
         }
     }
   }
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.cc 
b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
index 5477176ce8..b0fc0fc4a3 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlan.cc
+++ b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
@@ -1165,7 +1165,8 @@ core::PlanNodePtr 
SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
       windowParams.emplace_back(exprConverter_->toVeloxExpr(arg.value(), 
inputType));
     }
     auto windowVeloxType = 
SubstraitParser::parseType(windowFunction.output_type());
-    auto windowCall = std::make_shared<const 
core::CallTypedExpr>(windowVeloxType, std::move(windowParams), funcName);
+    auto windowCall = std::make_shared<const core::CallTypedExpr>(
+        windowVeloxType, std::move(windowParams), 
exec::sanitizeName(funcName));
     auto upperBound = windowFunction.upper_bound();
     auto lowerBound = windowFunction.lower_bound();
     auto type = windowFunction.window_type();
diff --git a/cpp/velox/udf/examples/UdfCommon.h 
b/cpp/velox/udf/examples/UdfCommon.h
index a68c474607..7e73b916a2 100644
--- a/cpp/velox/udf/examples/UdfCommon.h
+++ b/cpp/velox/udf/examples/UdfCommon.h
@@ -24,7 +24,7 @@ namespace gluten {
 
 class UdfRegisterer {
  public:
-  ~UdfRegisterer() = default;
+  virtual ~UdfRegisterer() = default;
 
   // Returns the number of UDFs in populateUdfEntries.
   virtual int getNumUdf() = 0;
@@ -38,7 +38,7 @@ class UdfRegisterer {
 
 class UdafRegisterer {
  public:
-  ~UdafRegisterer() = default;
+  virtual ~UdafRegisterer() = default;
 
   // Returns the number of UDFs in populateUdafEntries.
   virtual int getNumUdaf() = 0;
diff --git a/docs/developers/VeloxUDF.md b/docs/developers/VeloxUDF.md
index b3154c41a7..a38f1a48db 100644
--- a/docs/developers/VeloxUDF.md
+++ b/docs/developers/VeloxUDF.md
@@ -14,6 +14,7 @@ Users can implement custom functions using the UDF interface 
provided by Velox a
 At runtime, these UDFs are registered alongside their Java implementations via 
`CREATE TEMPORARY FUNCTION`.
 Once registered, Gluten can parse and offload these UDFs to Velox during 
execution, 
 meanwhile ensuring proper fallback to Java UDFs when necessary.
+Registered UDAFs can be used both as regular aggregate functions and as 
aggregate window functions.
 
 ## Create and Build UDF/UDAF library
 
diff --git 
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
 
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
index 84e2d86554..79f3d67c0e 100644
--- 
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
+++ 
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
@@ -22,7 +22,7 @@ import org.apache.gluten.execution._
 import org.apache.gluten.expression._
 import org.apache.gluten.sql.shims.SparkShimLoader
 import org.apache.gluten.substrait.SubstraitContext
-import org.apache.gluten.substrait.expression.{ExpressionBuilder, 
ExpressionNode, WindowFunctionNode}
+import org.apache.gluten.substrait.expression.WindowFunctionNode
 
 import org.apache.spark.ShuffleDependency
 import org.apache.spark.rdd.RDD
@@ -44,13 +44,11 @@ import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
 import org.apache.spark.sql.execution.window._
 import org.apache.spark.sql.hive.HiveUDFTransformer
-import org.apache.spark.sql.types.{DecimalType, LongType, NullType, StructType}
+import org.apache.spark.sql.types.{DecimalType, StructType}
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
 import java.io.{ObjectInputStream, ObjectOutputStream}
-import java.util.{ArrayList => JArrayList, List => JList}
-
-import scala.collection.JavaConverters._
+import java.util.{List => JList}
 
 trait SparkPlanExecApi {
 
@@ -555,134 +553,7 @@ trait SparkPlanExecApi {
       windowExpression: Seq[NamedExpression],
       windowExpressionNodes: JList[WindowFunctionNode],
       originalInputAttributes: Seq[Attribute],
-      context: SubstraitContext): Unit = {
-    windowExpression.map {
-      windowExpr =>
-        val aliasExpr = windowExpr.asInstanceOf[Alias]
-        val columnName = s"${aliasExpr.name}_${aliasExpr.exprId.id}"
-        val wExpression = aliasExpr.child.asInstanceOf[WindowExpression]
-        wExpression.windowFunction match {
-          case wf @ (RowNumber() | Rank(_) | DenseRank(_) | CumeDist() | 
PercentRank(_)) =>
-            val aggWindowFunc = wf.asInstanceOf[AggregateWindowFunction]
-            val frame = aggWindowFunc.frame.asInstanceOf[SpecifiedWindowFrame]
-            val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
-              WindowFunctionsBuilder.create(context, aggWindowFunc).toInt,
-              new JArrayList[ExpressionNode](),
-              columnName,
-              ConverterUtils.getTypeNode(aggWindowFunc.dataType, 
aggWindowFunc.nullable),
-              frame.upper,
-              frame.lower,
-              frame.frameType.sql,
-              originalInputAttributes.asJava
-            )
-            windowExpressionNodes.add(windowFunctionNode)
-          case aggExpression: AggregateExpression =>
-            val frame = 
wExpression.windowSpec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
-            val aggregateFunc = aggExpression.aggregateFunction
-            val substraitAggFuncName = 
ExpressionMappings.expressionsMap.get(aggregateFunc.getClass)
-            if (substraitAggFuncName.isEmpty) {
-              throw new GlutenNotSupportException(s"Not currently supported: 
$aggregateFunc.")
-            }
-
-            val childrenNodeList = aggregateFunc.children
-              .map(
-                ExpressionConverter
-                  .replaceWithExpressionTransformer(_, originalInputAttributes)
-                  .doTransform(context))
-              .asJava
-
-            val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
-              AggregateFunctionsBuilder.create(context, 
aggExpression.aggregateFunction).toInt,
-              childrenNodeList,
-              columnName,
-              ConverterUtils.getTypeNode(aggExpression.dataType, 
aggExpression.nullable),
-              frame.upper,
-              frame.lower,
-              frame.frameType.sql,
-              originalInputAttributes.asJava
-            )
-            windowExpressionNodes.add(windowFunctionNode)
-          case wf @ (_: Lead | _: Lag) =>
-            val offsetWf = wf.asInstanceOf[FrameLessOffsetWindowFunction]
-            val frame = offsetWf.frame.asInstanceOf[SpecifiedWindowFrame]
-            val childrenNodeList = new JArrayList[ExpressionNode]()
-            childrenNodeList.add(
-              ExpressionConverter
-                .replaceWithExpressionTransformer(
-                  offsetWf.input,
-                  attributeSeq = originalInputAttributes)
-                .doTransform(context))
-            // Spark only accepts foldable offset. Converts it to LongType 
literal.
-            val offset = offsetWf.offset.eval(EmptyRow).asInstanceOf[Int]
-            // Velox only allows negative offset. 
WindowFunctionsBuilder#create converts
-            // lag/lead with negative offset to the function with positive 
offset. So just
-            // makes offsetNode store positive value.
-            val offsetNode = 
ExpressionBuilder.makeLiteral(Math.abs(offset.toLong), LongType, false)
-            childrenNodeList.add(offsetNode)
-            // NullType means Null is the default value. Don't pass it to 
native.
-            if (offsetWf.default.dataType != NullType) {
-              childrenNodeList.add(
-                ExpressionConverter
-                  .replaceWithExpressionTransformer(
-                    offsetWf.default,
-                    attributeSeq = originalInputAttributes)
-                  .doTransform(context))
-            }
-            val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
-              WindowFunctionsBuilder.create(context, offsetWf).toInt,
-              childrenNodeList,
-              columnName,
-              ConverterUtils.getTypeNode(offsetWf.dataType, offsetWf.nullable),
-              frame.upper,
-              frame.lower,
-              frame.frameType.sql,
-              offsetWf.ignoreNulls,
-              originalInputAttributes.asJava
-            )
-            windowExpressionNodes.add(windowFunctionNode)
-          case wf @ NthValue(input, offset: Literal, ignoreNulls: Boolean) =>
-            val frame = 
wExpression.windowSpec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
-            val childrenNodeList = new JArrayList[ExpressionNode]()
-            childrenNodeList.add(
-              ExpressionConverter
-                .replaceWithExpressionTransformer(input, attributeSeq = 
originalInputAttributes)
-                .doTransform(context))
-            
childrenNodeList.add(LiteralTransformer(offset).doTransform(context))
-            val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
-              WindowFunctionsBuilder.create(context, wf).toInt,
-              childrenNodeList,
-              columnName,
-              ConverterUtils.getTypeNode(wf.dataType, wf.nullable),
-              frame.upper,
-              frame.lower,
-              frame.frameType.sql,
-              ignoreNulls,
-              originalInputAttributes.asJava
-            )
-            windowExpressionNodes.add(windowFunctionNode)
-          case wf @ NTile(buckets: Expression) =>
-            val frame = 
wExpression.windowSpec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
-            val childrenNodeList = new JArrayList[ExpressionNode]()
-            val literal = buckets.asInstanceOf[Literal]
-            
childrenNodeList.add(LiteralTransformer(literal).doTransform(context))
-            val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
-              WindowFunctionsBuilder.create(context, wf).toInt,
-              childrenNodeList,
-              columnName,
-              ConverterUtils.getTypeNode(wf.dataType, wf.nullable),
-              frame.upper,
-              frame.lower,
-              frame.frameType.sql,
-              originalInputAttributes.asJava
-            )
-            windowExpressionNodes.add(windowFunctionNode)
-          case _ =>
-            throw new GlutenNotSupportException(
-              "unsupported window function type: " +
-                wExpression.windowFunction)
-        }
-    }
-  }
+      context: SubstraitContext): Unit
 
   def rewriteSpillPath(path: String): String = path
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to