This is an automated email from the ASF dual-hosted git repository.
marin-ma pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/gluten.git
The following commit(s) were added to refs/heads/main by this push:
new a2cdb98ea8 [VL] Support native scala udaf in window (#12117)
a2cdb98ea8 is described below
commit a2cdb98ea85f05563df478cdf12fd60fd046c438
Author: Rong Ma <[email protected]>
AuthorDate: Thu Jun 4 11:09:50 2026 +0100
[VL] Support native scala udaf in window (#12117)
---
.../backendsapi/velox/VeloxSparkPlanExecApi.scala | 145 +++++++++++++++++++-
.../apache/spark/sql/expression/UDFResolver.scala | 45 ++++---
.../apache/gluten/expression/VeloxUdfSuite.scala | 146 +++++++++++++++------
cpp/velox/substrait/SubstraitToVeloxPlan.cc | 3 +-
cpp/velox/udf/examples/UdfCommon.h | 4 +-
docs/developers/VeloxUDF.md | 1 +
.../gluten/backendsapi/SparkPlanExecApi.scala | 137 +------------------
7 files changed, 282 insertions(+), 199 deletions(-)
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index ef25171e6d..be21337d99 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -26,6 +26,8 @@ import org.apache.gluten.extension.JoinKeysTag
import org.apache.gluten.extension.columnar.FallbackTags
import org.apache.gluten.shuffle.NeedCustomColumnarBatchSerializer
import org.apache.gluten.sql.shims.SparkShimLoader
+import org.apache.gluten.substrait.SubstraitContext
+import org.apache.gluten.substrait.expression.{ExpressionBuilder,
ExpressionNode, WindowFunctionNode}
import org.apache.gluten.vectorized.{ColumnarBatchSerializer,
ColumnarBatchSerializeResult}
import org.apache.spark.{ShuffleDependency, SparkEnv, SparkException}
@@ -53,7 +55,8 @@ import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
import org.apache.spark.sql.execution.unsafe.UnsafeColumnarBuildSideRelation
import org.apache.spark.sql.execution.utils.ExecUtil
-import org.apache.spark.sql.expression.{UDFExpression,
UserDefinedAggregateFunction}
+import org.apache.spark.sql.expression.{UDFExpression, UDFResolver,
UserDefinedAggregateFunction}
+import org.apache.spark.sql.hive.HiveUDAFInspector
import org.apache.spark.sql.hive.VeloxHiveUDFTransformer
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -64,6 +67,7 @@ import org.apache.commons.lang3.ClassUtils
import javax.ws.rs.core.UriBuilder
+import java.util.{ArrayList => JArrayList, List => JList}
import java.util.Locale
import scala.collection.JavaConverters._
@@ -1271,4 +1275,143 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi
with Logging {
RaiseErrorRestrictions.ONLY_SUPPORT_ERROR_MESSAGE)
}
}
+
+ /**
+  * Converts each window expression into a Substrait [[WindowFunctionNode]] and appends it to
+  * `windowExpressionNodes`.
+  *
+  * Every element of `windowExpression` is cast to an `Alias` whose child is a
+  * `WindowExpression`; the output column is named `<alias name>_<exprId>`. For aggregate
+  * window functions, if `AggregateFunctionsBuilder.getSubstraitFunctionName` throws
+  * `GlutenNotSupportException`, the function is retried as a registered native UDAF (looked
+  * up by its Hive UDAF class name) before the exception is rethrown.
+  *
+  * @param windowExpression aliased window expressions to convert
+  * @param windowExpressionNodes output list that receives one node per converted expression
+  * @param originalInputAttributes input attributes used when transforming child expressions
+  * @param context Substrait context passed to the function builders and transformers
+  * @throws GlutenNotSupportException for an unsupported window function type, or for an
+  *   aggregate that is neither supported natively nor registered in `UDFResolver.UDAFNames`
+  */
+ override def genWindowFunctionsNode(
+ windowExpression: Seq[NamedExpression],
+ windowExpressionNodes: JList[WindowFunctionNode],
+ originalInputAttributes: Seq[Attribute],
+ context: SubstraitContext): Unit = {
+ windowExpression.foreach {
+ windowExpr =>
+ val aliasExpr = windowExpr.asInstanceOf[Alias]
+ val columnName = s"${aliasExpr.name}_${aliasExpr.exprId.id}"
+ val wExpression = aliasExpr.child.asInstanceOf[WindowExpression]
+ wExpression.windowFunction match {
+ // These functions are emitted with an empty argument list; the frame is read from the
+ // function itself rather than from the window spec.
+ case wf @ (RowNumber() | Rank(_) | DenseRank(_) | CumeDist() |
PercentRank(_)) =>
+ val aggWindowFunc = wf.asInstanceOf[AggregateWindowFunction]
+ val frame = aggWindowFunc.frame.asInstanceOf[SpecifiedWindowFrame]
+ val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
+ WindowFunctionsBuilder.create(context, aggWindowFunc).toInt,
+ new JArrayList[ExpressionNode](),
+ columnName,
+ ConverterUtils.getTypeNode(aggWindowFunc.dataType,
aggWindowFunc.nullable),
+ frame.upper,
+ frame.lower,
+ frame.frameType.sql,
+ originalInputAttributes.asJava
+ )
+ windowExpressionNodes.add(windowFunctionNode)
+ case aggExpression: AggregateExpression =>
+ val frame =
wExpression.windowSpec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
+ val originalAggFunc = aggExpression.aggregateFunction
+ // getSubstraitFunctionName is called only as a support probe (its result is
+ // discarded). On GlutenNotSupportException, a Hive UDAF whose class name is
+ // registered in UDFResolver.UDAFNames is replaced by the native UDAF expression
+ // built from the same children; anything else rethrows the original exception.
+ val aggregateFunc =
+ try {
+
AggregateFunctionsBuilder.getSubstraitFunctionName(originalAggFunc)
+ originalAggFunc
+ } catch {
+ case e: GlutenNotSupportException =>
+ HiveUDAFInspector.getUDAFClassName(originalAggFunc) match {
+ case Some(udafClass) if
UDFResolver.UDAFNames.contains(udafClass) =>
+
UDFResolver.getUdafExpression(udafClass)(originalAggFunc.children)
+ case _ => throw e
+ }
+ }
+
+ val childrenNodeList = aggregateFunc.children
+ .map(
+ ExpressionConverter
+ .replaceWithExpressionTransformer(_, originalInputAttributes)
+ .doTransform(context))
+ .asJava
+
+ // The function id is created from the (possibly replaced) aggregate function and
+ // the aggregate expression's mode.
+ val functionId = VeloxAggregateFunctionsBuilder
+ .create(context, aggregateFunc, aggExpression.mode)
+ .toInt
+ val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
+ functionId,
+ childrenNodeList,
+ columnName,
+ ConverterUtils.getTypeNode(aggExpression.dataType,
aggExpression.nullable),
+ frame.upper,
+ frame.lower,
+ frame.frameType.sql,
+ originalInputAttributes.asJava
+ )
+ windowExpressionNodes.add(windowFunctionNode)
+ case wf @ (_: Lead | _: Lag) =>
+ val offsetWf = wf.asInstanceOf[FrameLessOffsetWindowFunction]
+ val frame = offsetWf.frame.asInstanceOf[SpecifiedWindowFrame]
+ val childrenNodeList = new JArrayList[ExpressionNode]()
+ childrenNodeList.add(
+ ExpressionConverter
+ .replaceWithExpressionTransformer(
+ offsetWf.input,
+ attributeSeq = originalInputAttributes)
+ .doTransform(context))
+ // Spark only accepts foldable offset. Converts it to LongType literal.
+ val offset = offsetWf.offset.eval(EmptyRow).asInstanceOf[Int]
+ // Velox only allows negative offset. WindowFunctionsBuilder#create converts
+ // lag/lead with negative offset to the function with positive offset. So just
+ // makes offsetNode store positive value.
+ val offsetNode =
ExpressionBuilder.makeLiteral(Math.abs(offset.toLong), LongType, false)
+ childrenNodeList.add(offsetNode)
+ // NullType means Null is the default value. Don't pass it to native.
+ if (offsetWf.default.dataType != NullType) {
+ childrenNodeList.add(
+ ExpressionConverter
+ .replaceWithExpressionTransformer(
+ offsetWf.default,
+ attributeSeq = originalInputAttributes)
+ .doTransform(context))
+ }
+ val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
+ WindowFunctionsBuilder.create(context, offsetWf).toInt,
+ childrenNodeList,
+ columnName,
+ ConverterUtils.getTypeNode(offsetWf.dataType, offsetWf.nullable),
+ frame.upper,
+ frame.lower,
+ frame.frameType.sql,
+ offsetWf.ignoreNulls,
+ originalInputAttributes.asJava
+ )
+ windowExpressionNodes.add(windowFunctionNode)
+ // Only matches when the offset is a Literal; other NthValue shapes fall through to the
+ // unsupported case below.
+ case wf @ NthValue(input, offset: Literal, ignoreNulls: Boolean) =>
+ val frame =
wExpression.windowSpec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
+ val childrenNodeList = new JArrayList[ExpressionNode]()
+ childrenNodeList.add(
+ ExpressionConverter
+ .replaceWithExpressionTransformer(input, attributeSeq =
originalInputAttributes)
+ .doTransform(context))
+
childrenNodeList.add(LiteralTransformer(offset).doTransform(context))
+ val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
+ WindowFunctionsBuilder.create(context, wf).toInt,
+ childrenNodeList,
+ columnName,
+ ConverterUtils.getTypeNode(wf.dataType, wf.nullable),
+ frame.upper,
+ frame.lower,
+ frame.frameType.sql,
+ ignoreNulls,
+ originalInputAttributes.asJava
+ )
+ windowExpressionNodes.add(windowFunctionNode)
+ case wf @ NTile(buckets: Expression) =>
+ val frame =
wExpression.windowSpec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
+ val childrenNodeList = new JArrayList[ExpressionNode]()
+ // NOTE(review): the bucket count is cast to Literal unconditionally — presumably it
+ // is always a literal by the time it reaches here; confirm against callers.
+ val literal = buckets.asInstanceOf[Literal]
+
childrenNodeList.add(LiteralTransformer(literal).doTransform(context))
+ val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
+ WindowFunctionsBuilder.create(context, wf).toInt,
+ childrenNodeList,
+ columnName,
+ ConverterUtils.getTypeNode(wf.dataType, wf.nullable),
+ frame.upper,
+ frame.lower,
+ frame.frameType.sql,
+ originalInputAttributes.asJava
+ )
+ windowExpressionNodes.add(windowFunctionNode)
+ case _ =>
+ throw new GlutenNotSupportException(
+ "unsupported window function type: " +
+ wExpression.windowFunction)
+ }
+ }
+ }
}
diff --git
a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
index d2405f9e93..43ee7b4f7d 100644
---
a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
+++
b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
@@ -344,15 +344,15 @@ object UDFResolver extends Logging {
val allowTypeConversion = checkAllowTypeConversion
val signatures =
- UDFMap.getOrElse(name, throw new GlutenNotSupportException(errorMessage))
- signatures.find(sig => tryBind(sig, children.map(_.dataType),
allowTypeConversion)) match {
- case Some(sig) =>
+ UDFMap.getOrElse(name, throw new
GlutenNotSupportException(errorMessage)).toSeq
+ tryBind(signatures, children.map(_.dataType), allowTypeConversion) match {
+ case Some((sig, withTypeConversion)) =>
UDFExpression(
name,
alias,
sig.expressionType.dataType,
sig.expressionType.nullable,
- if (!allowTypeConversion && !sig.allowTypeConversion) children
+ if (!withTypeConversion) children
else applyCast(children, sig)
)
case None =>
@@ -366,17 +366,15 @@ object UDFResolver extends Logging {
val allowTypeConversion = checkAllowTypeConversion
val signatures =
- UDAFMap.getOrElse(
- name,
- throw new GlutenNotSupportException(errorMessage)
- )
- signatures.find(sig => tryBind(sig, children.map(_.dataType),
allowTypeConversion)) match {
- case Some(sig) =>
+ UDAFMap.getOrElse(name, throw new
GlutenNotSupportException(errorMessage)).toSeq
+
+ tryBind(signatures, children.map(_.dataType), allowTypeConversion) match {
+ case Some((sig, withTypeConversion)) =>
UserDefinedAggregateFunction(
name,
sig.expressionType.dataType,
sig.expressionType.nullable,
- if (!allowTypeConversion && !sig.allowTypeConversion) children
+ if (!withTypeConversion) children
else applyCast(children, sig),
sig.intermediateAttrs
)
@@ -385,16 +383,23 @@ object UDFResolver extends Logging {
}
}
- private def tryBind(
- sig: UDFSignatureBase,
+ /**
+  * Picks the signature to bind for the given argument types.
+  *
+  * A signature that binds strictly is preferred and is returned with `false` (no casts
+  * needed). Otherwise the first signature that binds with type conversion is returned with
+  * `true`; candidates are all signatures when `allowTypeConversion` is set, else only those
+  * whose own `allowTypeConversion` flag is set. Returns `None` when nothing binds.
+  */
+ private def tryBind[U <: UDFSignatureBase](
+ signatures: Seq[U],
requiredDataTypes: Seq[DataType],
- allowTypeConversion: Boolean): Boolean = {
- if (
- !tryBindStrict(sig, requiredDataTypes) && (allowTypeConversion ||
sig.allowTypeConversion)
- ) {
- tryBindWithTypeConversion(sig, requiredDataTypes)
- } else {
- true
+ allowTypeConversion: Boolean): Option[(U, Boolean)] = {
+ val strictMatch =
+ signatures.find(tryBindStrict(_, requiredDataTypes)).map(sig => (sig, false))
+ // orElse takes its argument by name, so the conversion pass only runs when no strict
+ // match exists.
+ strictMatch.orElse {
+ val convertible =
+ if (allowTypeConversion) signatures else signatures.filter(_.allowTypeConversion)
+ convertible
+ .find(tryBindWithTypeConversion(_, requiredDataTypes))
+ .map(sig => (sig, true))
}
}
diff --git
a/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala
b/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala
index 7eb61144a9..a5128f62d9 100644
---
a/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala
+++
b/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala
@@ -18,12 +18,14 @@ package org.apache.gluten.expression
import org.apache.gluten.backendsapi.velox.VeloxBackendSettings
import org.apache.gluten.execution.ProjectExecTransformer
+import org.apache.gluten.execution.WindowExecTransformer
import org.apache.gluten.tags.{SkipTest, UDFTest}
import org.apache.spark.SparkConf
import org.apache.spark.sql.{GlutenQueryTest, Row, SparkSession}
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.execution.ProjectExec
+import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.expression.UDFResolver
import java.nio.file.Paths
@@ -92,16 +94,72 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with
SQLHelper {
.set("spark.memory.offHeap.enabled", "true")
.set("spark.memory.offHeap.size", "1024MB")
.set("spark.ui.enabled", "false")
+ .set("spark.sql.adaptive.enabled", "false")
}
- // Aggregate result can be flaky.
- ignore("test native hive udaf") {
+ // Verifies that a registered native Hive UDF is offloaded to ProjectExecTransformer (both
+ // with and without implicit argument conversion), that unregistering it falls back to
+ // Spark's ProjectExec with identical results, and that a UDF name present in the registry
+ // but not implemented natively also falls back.
+ test("test native hive udf") {
+   val tbl = "test_hive_udf_replacement"
+   val stringStringUdf = "org.apache.spark.sql.hive.execution.UDFStringString"
+   val intToStringUdf = "org.apache.spark.sql.hive.execution.UDFIntegerToString"
+   withTempPath {
+     dir =>
+       try {
+         spark.sql(s"""
+                      |CREATE EXTERNAL TABLE $tbl
+                      |LOCATION 'file://$dir'
+                      |AS select * from values (1, '1'), (2, '2'), (3, '3')
+                      |""".stripMargin)
+
+         // Check native hive udf has been registered.
+         assert(UDFResolver.UDFNames.contains(stringStringUdf))
+
+         spark.sql(s"""
+                      |CREATE TEMPORARY FUNCTION hive_string_string
+                      |AS '$stringStringUdf'
+                      |""".stripMargin)
+
+         // col1 holds integers, so this call only binds through implicit conversion.
+         val offloadWithImplicitConversionDF =
+           spark.sql(s"""SELECT hive_string_string(col1, 'a') FROM $tbl""")
+         checkGlutenPlan[ProjectExecTransformer](offloadWithImplicitConversionDF)
+         val offloadWithImplicitConversionResult = offloadWithImplicitConversionDF.collect()
+
+         val offloadDF =
+           spark.sql(s"""SELECT hive_string_string(col2, 'a') FROM $tbl""")
+         checkGlutenPlan[ProjectExecTransformer](offloadDF)
+         // Collect from offloadDF itself; collecting the implicit-conversion DataFrame here
+         // would leave the strictly-bound offload path unverified.
+         val offloadResult = offloadDF.collect()
+
+         // Unregister native hive udf to fallback.
+         UDFResolver.UDFNames.remove(stringStringUdf)
+         val fallbackDF =
+           spark.sql(s"""SELECT hive_string_string(col2, 'a') FROM $tbl""")
+         checkSparkPlan[ProjectExec](fallbackDF)
+         val fallbackResult = fallbackDF.collect()
+         assert(offloadWithImplicitConversionResult.sameElements(fallbackResult))
+         assert(offloadResult.sameElements(fallbackResult))
+
+         // Add an unimplemented udf to the map to test fallback of registered native hive udf.
+         UDFResolver.UDFNames.add(intToStringUdf)
+         spark.sql(s"""
+                      |CREATE TEMPORARY FUNCTION hive_int_to_string
+                      |AS '$intToStringUdf'
+                      |""".stripMargin)
+         val df = spark.sql(s"""select hive_int_to_string(col1) from $tbl""")
+         checkSparkPlan[ProjectExec](df)
+         checkAnswer(df, Seq(Row("1"), Row("2"), Row("3")))
+       } finally {
+         // UDFResolver is a shared object: undo the registry changes made above so later
+         // tests start from the original set of native UDF names (same pattern as the
+         // UDAF tests, which re-add their class name in finally).
+         UDFResolver.UDFNames.remove(intToStringUdf)
+         UDFResolver.UDFNames.add(stringStringUdf)
+         spark.sql(s"DROP TABLE IF EXISTS $tbl")
+         spark.sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_string_string")
+         spark.sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_int_to_string")
+       }
+   }
+ }
+
+ test("test native hive udaf") {
val tbl = "test_hive_udaf_replacement"
+ val udafClass = "test.org.apache.spark.sql.MyDoubleAvg"
withTempPath {
dir =>
try {
// Check native hive udaf has been registered.
- val udafClass = "test.org.apache.spark.sql.MyDoubleAvg"
assert(UDFResolver.UDAFNames.contains(udafClass))
spark.sql(s"""
@@ -136,64 +194,68 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with
SQLHelper {
assert(nativeResult.sameElements(fallbackResult))
assert(nativeImplicitConversionResult.sameElements(fallbackResult))
} finally {
+ UDFResolver.UDAFNames.add(udafClass)
spark.sql(s"DROP TABLE IF EXISTS $tbl")
spark.sql(s"DROP TEMPORARY FUNCTION IF EXISTS my_double_avg")
}
}
}
- test("test native hive udf") {
- val tbl = "test_hive_udf_replacement"
+ // Checks that a registered native UDAF used as a window aggregate is planned as
+ // WindowExecTransformer, and that its result matches Spark's WindowExec once the UDAF
+ // name is removed from UDFResolver.UDAFNames.
+ test("test native hive udaf in window") {
+ val tbl = "test_hive_udaf_window"
+ val udafClass = "test.org.apache.spark.sql.MyDoubleAvg"
+ val query =
+ s"""SELECT
+ | col1,
+ | my_double_avg(col1) OVER (
+ | PARTITION BY col1 % 2
+ | ORDER BY col1
+ | ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS
my_avg_window
+ |FROM $tbl
+ |ORDER BY col1
+ |""".stripMargin
+
withTempPath {
dir =>
try {
+ assert(UDFResolver.UDAFNames.contains(udafClass))
+
+ spark.sql(s"""
+ |CREATE TEMPORARY FUNCTION my_double_avg
+ |AS '$udafClass'
+ |""".stripMargin)
+ spark.sql(s"""
+ |DROP TABLE IF EXISTS $tbl;
+ |""".stripMargin)
spark.sql(s"""
|CREATE EXTERNAL TABLE $tbl
|LOCATION 'file://$dir'
- |AS select * from values (1, '1'), (2, '2'), (3, '3')
+ |AS SELECT CAST(v AS FLOAT) AS col1
+ |FROM VALUES (1.0), (2.0), (3.0), (4.0) AS t(v)
|""".stripMargin)
- // Check native hive udf has been registered.
- assert(
-
UDFResolver.UDFNames.contains("org.apache.spark.sql.hive.execution.UDFStringString"))
-
- spark.sql("""
- |CREATE TEMPORARY FUNCTION hive_string_string
- |AS 'org.apache.spark.sql.hive.execution.UDFStringString'
- |""".stripMargin)
-
- val offloadWithImplicitConversionDF =
- spark.sql(s"""SELECT hive_string_string(col1, 'a') FROM $tbl""")
-
checkGlutenPlan[ProjectExecTransformer](offloadWithImplicitConversionDF)
- val offloadWithImplicitConversionResult =
offloadWithImplicitConversionDF.collect()
+ val offloadDF = spark.sql(query)
+ checkGlutenPlan[WindowExecTransformer](offloadDF)
+ // NOTE(review): each expected value is the running average within its col1 % 2
+ // partition plus 100 — presumably MyDoubleAvg adds 100 to the average; confirm against
+ // the UDAF implementation.
+ checkAnswer(
+ offloadDF,
+ Seq(
+ Row(1.0f, 101.0),
+ Row(2.0f, 102.0),
+ Row(3.0f, 102.0),
+ Row(4.0f, 103.0)
+ ))
+ val offloadResult = offloadDF.collect()
- val offloadDF =
- spark.sql(s"""SELECT hive_string_string(col2, 'a') FROM $tbl""")
- checkGlutenPlan[ProjectExecTransformer](offloadDF)
- val offloadResult = offloadWithImplicitConversionDF.collect()
-
- // Unregister native hive udf to fallback.
-
UDFResolver.UDFNames.remove("org.apache.spark.sql.hive.execution.UDFStringString")
- val fallbackDF =
- spark.sql(s"""SELECT hive_string_string(col2, 'a') FROM $tbl""")
- checkSparkPlan[ProjectExec](fallbackDF)
+ // Unregister the native UDAF so the same query falls back to Spark.
+ UDFResolver.UDAFNames.remove(udafClass)
+ val fallbackDF = spark.sql(query)
+ checkSparkPlan[WindowExec](fallbackDF)
val fallbackResult = fallbackDF.collect()
-
assert(offloadWithImplicitConversionResult.sameElements(fallbackResult))
- assert(offloadResult.sameElements(fallbackResult))
- // Add an unimplemented udf to the map to test fallback of
registered native hive udf.
-
UDFResolver.UDFNames.add("org.apache.spark.sql.hive.execution.UDFIntegerToString")
- spark.sql("""
- |CREATE TEMPORARY FUNCTION hive_int_to_string
- |AS
'org.apache.spark.sql.hive.execution.UDFIntegerToString'
- |""".stripMargin)
- val df = spark.sql(s"""select hive_int_to_string(col1) from $tbl""")
- checkSparkPlan[ProjectExec](df)
- checkAnswer(df, Seq(Row("1"), Row("2"), Row("3")))
+ assert(offloadResult.sameElements(fallbackResult))
} finally {
+ // Re-add the UDAF name removed above so the shared UDFResolver state is restored.
+ UDFResolver.UDAFNames.add(udafClass)
spark.sql(s"DROP TABLE IF EXISTS $tbl")
- spark.sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_string_string")
- spark.sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_int_to_string")
+ spark.sql(s"DROP TEMPORARY FUNCTION IF EXISTS my_double_avg")
}
}
}
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.cc
b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
index 5477176ce8..b0fc0fc4a3 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlan.cc
+++ b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
@@ -1165,7 +1165,8 @@ core::PlanNodePtr
SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
windowParams.emplace_back(exprConverter_->toVeloxExpr(arg.value(),
inputType));
}
auto windowVeloxType =
SubstraitParser::parseType(windowFunction.output_type());
- auto windowCall = std::make_shared<const
core::CallTypedExpr>(windowVeloxType, std::move(windowParams), funcName);
+ auto windowCall = std::make_shared<const core::CallTypedExpr>(
+ windowVeloxType, std::move(windowParams),
exec::sanitizeName(funcName));
auto upperBound = windowFunction.upper_bound();
auto lowerBound = windowFunction.lower_bound();
auto type = windowFunction.window_type();
diff --git a/cpp/velox/udf/examples/UdfCommon.h
b/cpp/velox/udf/examples/UdfCommon.h
index a68c474607..7e73b916a2 100644
--- a/cpp/velox/udf/examples/UdfCommon.h
+++ b/cpp/velox/udf/examples/UdfCommon.h
@@ -24,7 +24,7 @@ namespace gluten {
class UdfRegisterer {
public:
- ~UdfRegisterer() = default;
+ virtual ~UdfRegisterer() = default;
// Returns the number of UDFs in populateUdfEntries.
virtual int getNumUdf() = 0;
@@ -38,7 +38,7 @@ class UdfRegisterer {
class UdafRegisterer {
public:
- ~UdafRegisterer() = default;
+ virtual ~UdafRegisterer() = default;
// Returns the number of UDFs in populateUdafEntries.
virtual int getNumUdaf() = 0;
diff --git a/docs/developers/VeloxUDF.md b/docs/developers/VeloxUDF.md
index b3154c41a7..a38f1a48db 100644
--- a/docs/developers/VeloxUDF.md
+++ b/docs/developers/VeloxUDF.md
@@ -14,6 +14,7 @@ Users can implement custom functions using the UDF interface
provided by Velox a
At runtime, these UDFs are registered alongside their Java implementations via
`CREATE TEMPORARY FUNCTION`.
Once registered, Gluten can parse and offload these UDFs to Velox during
execution,
meanwhile ensuring proper fallback to Java UDFs when necessary.
+Registered UDAFs can be used both as regular aggregate functions and as
aggregate window functions.
## Create and Build UDF/UDAF library
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
index 84e2d86554..79f3d67c0e 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
@@ -22,7 +22,7 @@ import org.apache.gluten.execution._
import org.apache.gluten.expression._
import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.gluten.substrait.SubstraitContext
-import org.apache.gluten.substrait.expression.{ExpressionBuilder,
ExpressionNode, WindowFunctionNode}
+import org.apache.gluten.substrait.expression.WindowFunctionNode
import org.apache.spark.ShuffleDependency
import org.apache.spark.rdd.RDD
@@ -44,13 +44,11 @@ import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
import org.apache.spark.sql.execution.window._
import org.apache.spark.sql.hive.HiveUDFTransformer
-import org.apache.spark.sql.types.{DecimalType, LongType, NullType, StructType}
+import org.apache.spark.sql.types.{DecimalType, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch
import java.io.{ObjectInputStream, ObjectOutputStream}
-import java.util.{ArrayList => JArrayList, List => JList}
-
-import scala.collection.JavaConverters._
+import java.util.{List => JList}
trait SparkPlanExecApi {
@@ -555,134 +553,7 @@ trait SparkPlanExecApi {
windowExpression: Seq[NamedExpression],
windowExpressionNodes: JList[WindowFunctionNode],
originalInputAttributes: Seq[Attribute],
- context: SubstraitContext): Unit = {
- windowExpression.map {
- windowExpr =>
- val aliasExpr = windowExpr.asInstanceOf[Alias]
- val columnName = s"${aliasExpr.name}_${aliasExpr.exprId.id}"
- val wExpression = aliasExpr.child.asInstanceOf[WindowExpression]
- wExpression.windowFunction match {
- case wf @ (RowNumber() | Rank(_) | DenseRank(_) | CumeDist() |
PercentRank(_)) =>
- val aggWindowFunc = wf.asInstanceOf[AggregateWindowFunction]
- val frame = aggWindowFunc.frame.asInstanceOf[SpecifiedWindowFrame]
- val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
- WindowFunctionsBuilder.create(context, aggWindowFunc).toInt,
- new JArrayList[ExpressionNode](),
- columnName,
- ConverterUtils.getTypeNode(aggWindowFunc.dataType,
aggWindowFunc.nullable),
- frame.upper,
- frame.lower,
- frame.frameType.sql,
- originalInputAttributes.asJava
- )
- windowExpressionNodes.add(windowFunctionNode)
- case aggExpression: AggregateExpression =>
- val frame =
wExpression.windowSpec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
- val aggregateFunc = aggExpression.aggregateFunction
- val substraitAggFuncName =
ExpressionMappings.expressionsMap.get(aggregateFunc.getClass)
- if (substraitAggFuncName.isEmpty) {
- throw new GlutenNotSupportException(s"Not currently supported:
$aggregateFunc.")
- }
-
- val childrenNodeList = aggregateFunc.children
- .map(
- ExpressionConverter
- .replaceWithExpressionTransformer(_, originalInputAttributes)
- .doTransform(context))
- .asJava
-
- val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
- AggregateFunctionsBuilder.create(context,
aggExpression.aggregateFunction).toInt,
- childrenNodeList,
- columnName,
- ConverterUtils.getTypeNode(aggExpression.dataType,
aggExpression.nullable),
- frame.upper,
- frame.lower,
- frame.frameType.sql,
- originalInputAttributes.asJava
- )
- windowExpressionNodes.add(windowFunctionNode)
- case wf @ (_: Lead | _: Lag) =>
- val offsetWf = wf.asInstanceOf[FrameLessOffsetWindowFunction]
- val frame = offsetWf.frame.asInstanceOf[SpecifiedWindowFrame]
- val childrenNodeList = new JArrayList[ExpressionNode]()
- childrenNodeList.add(
- ExpressionConverter
- .replaceWithExpressionTransformer(
- offsetWf.input,
- attributeSeq = originalInputAttributes)
- .doTransform(context))
- // Spark only accepts foldable offset. Converts it to LongType
literal.
- val offset = offsetWf.offset.eval(EmptyRow).asInstanceOf[Int]
- // Velox only allows negative offset.
WindowFunctionsBuilder#create converts
- // lag/lead with negative offset to the function with positive
offset. So just
- // makes offsetNode store positive value.
- val offsetNode =
ExpressionBuilder.makeLiteral(Math.abs(offset.toLong), LongType, false)
- childrenNodeList.add(offsetNode)
- // NullType means Null is the default value. Don't pass it to
native.
- if (offsetWf.default.dataType != NullType) {
- childrenNodeList.add(
- ExpressionConverter
- .replaceWithExpressionTransformer(
- offsetWf.default,
- attributeSeq = originalInputAttributes)
- .doTransform(context))
- }
- val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
- WindowFunctionsBuilder.create(context, offsetWf).toInt,
- childrenNodeList,
- columnName,
- ConverterUtils.getTypeNode(offsetWf.dataType, offsetWf.nullable),
- frame.upper,
- frame.lower,
- frame.frameType.sql,
- offsetWf.ignoreNulls,
- originalInputAttributes.asJava
- )
- windowExpressionNodes.add(windowFunctionNode)
- case wf @ NthValue(input, offset: Literal, ignoreNulls: Boolean) =>
- val frame =
wExpression.windowSpec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
- val childrenNodeList = new JArrayList[ExpressionNode]()
- childrenNodeList.add(
- ExpressionConverter
- .replaceWithExpressionTransformer(input, attributeSeq =
originalInputAttributes)
- .doTransform(context))
-
childrenNodeList.add(LiteralTransformer(offset).doTransform(context))
- val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
- WindowFunctionsBuilder.create(context, wf).toInt,
- childrenNodeList,
- columnName,
- ConverterUtils.getTypeNode(wf.dataType, wf.nullable),
- frame.upper,
- frame.lower,
- frame.frameType.sql,
- ignoreNulls,
- originalInputAttributes.asJava
- )
- windowExpressionNodes.add(windowFunctionNode)
- case wf @ NTile(buckets: Expression) =>
- val frame =
wExpression.windowSpec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
- val childrenNodeList = new JArrayList[ExpressionNode]()
- val literal = buckets.asInstanceOf[Literal]
-
childrenNodeList.add(LiteralTransformer(literal).doTransform(context))
- val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
- WindowFunctionsBuilder.create(context, wf).toInt,
- childrenNodeList,
- columnName,
- ConverterUtils.getTypeNode(wf.dataType, wf.nullable),
- frame.upper,
- frame.lower,
- frame.frameType.sql,
- originalInputAttributes.asJava
- )
- windowExpressionNodes.add(windowFunctionNode)
- case _ =>
- throw new GlutenNotSupportException(
- "unsupported window function type: " +
- wExpression.windowFunction)
- }
- }
- }
+ context: SubstraitContext): Unit
def rewriteSpillPath(path: String): String = path
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]