This is an automated email from the ASF dual-hosted git repository.
yangzy pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 55b2e92f54 [GLUTEN-8772][CORE] refactor: Refactoring the use of SubstraitContext#functionMap (#8775)
55b2e92f54 is described below
commit 55b2e92f54f95f980587d08f0beb430e0a922b15
Author: wypb <[email protected]>
AuthorDate: Thu Apr 17 22:01:24 2025 +0800
[GLUTEN-8772][CORE] refactor: Refactoring the use of SubstraitContext#functionMap (#8775)
---
.../clickhouse/CHSparkPlanExecApi.scala | 24 ++++-----
.../backendsapi/clickhouse/CHTransformerApi.scala | 7 ++-
.../backendsapi/clickhouse/CHValidatorApi.scala | 2 +-
.../CHAggregateGroupLimitExecTransformer.scala | 5 +-
.../execution/CHHashAggregateExecTransformer.scala | 15 +++---
.../CHWindowGroupLimitExecTransformer.scala | 5 +-
.../expression/CHExpressionTransformer.scala | 60 ++++++++++------------
.../apache/gluten/expression/CHExpressions.scala | 10 ++--
.../org/apache/gluten/utils/PlanNodesUtil.scala | 3 +-
.../utils/RangePartitionerBoundsGenerator.scala | 6 +--
.../utils/MergeTreePartsPartitionsUtil.scala | 2 +-
.../backendsapi/velox/VeloxTransformerApi.scala | 3 +-
.../execution/HashAggregateExecTransformer.scala | 52 +++++++++----------
.../apache/gluten/execution/TopNTransformer.scala | 3 +-
.../gluten/expression/ExpressionTransformer.scala | 22 ++++----
.../execution/DeltaFilterExecTransformer.scala | 5 +-
.../execution/DeltaProjectExecTransformer.scala | 3 +-
.../substrait/expression/ExpressionBuilder.java | 10 ----
.../substrait/expression/WindowFunctionNode.java | 4 +-
.../gluten/backendsapi/SparkPlanExecApi.scala | 28 +++++-----
.../apache/gluten/backendsapi/TransformerApi.scala | 3 +-
.../BasicPhysicalOperatorTransformer.scala | 5 +-
.../execution/BasicScanExecTransformer.scala | 2 +-
.../CartesianProductExecTransformer.scala | 2 +-
.../gluten/execution/ExpandExecTransformer.scala | 3 +-
.../execution/GenerateExecTransformerBase.scala | 2 +-
.../gluten/execution/JoinExecTransformer.scala | 13 ++---
.../org/apache/gluten/execution/JoinUtils.scala | 12 ++---
.../gluten/execution/SampleExecTransformer.scala | 2 +-
.../gluten/execution/SortExecTransformer.scala | 3 +-
.../gluten/execution/WindowExecTransformer.scala | 7 ++-
.../WindowGroupLimitExecTransformer.scala | 5 +-
.../expression/AggregateFunctionsBuilder.scala | 9 ++--
.../expression/ArrayExpressionTransformer.scala | 5 +-
.../gluten/expression/ConditionalTransformer.scala | 17 +++---
.../gluten/expression/ExpressionTransformer.scala | 10 ++--
.../JsonTupleExpressionTransformer.scala | 10 ++--
.../expression/LambdaFunctionTransformer.scala | 5 +-
.../expression/MapExpressionTransformer.scala | 9 ++--
.../expression/NamedExpressionsTransformer.scala | 5 +-
.../PredicateExpressionTransformer.scala | 9 ++--
.../expression/ScalarSubqueryTransformer.scala | 3 +-
.../expression/UnaryExpressionTransformer.scala | 29 +++++------
.../gluten/expression/WindowFunctionsBuilder.scala | 7 ++-
.../org/apache/gluten/utils/SubstraitUtil.scala | 2 +-
.../python/EvalPythonExecTransformer.scala | 10 ++--
46 files changed, 208 insertions(+), 250 deletions(-)
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index eecf0588d4..8136488b6a 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -25,6 +25,7 @@ import org.apache.gluten.expression.ExpressionNames.MONOTONICALLY_INCREASING_ID
import org.apache.gluten.extension.ExpressionExtensionTrait
import org.apache.gluten.extension.columnar.heuristic.HeuristicTransform
import org.apache.gluten.sql.shims.SparkShimLoader
+import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.expression.{ExpressionBuilder,
ExpressionNode, WindowFunctionNode}
import org.apache.gluten.utils.{CHJoinValidateUtil, UnknownJoinStrategy}
import org.apache.gluten.vectorized.{BlockOutputStream,
CHColumnarBatchSerializer, CHNativeBlock, CHStreamReader}
@@ -59,8 +60,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.commons.lang3.ClassUtils
import java.io.{ObjectInputStream, ObjectOutputStream}
-import java.lang.{Long => JLong}
-import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
+import java.util.{ArrayList => JArrayList, List => JList}
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
@@ -709,7 +709,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
windowExpression: Seq[NamedExpression],
windowExpressionNodes: JList[WindowFunctionNode],
originalInputAttributes: Seq[Attribute],
- args: JMap[String, JLong]): Unit = {
+ context: SubstraitContext): Unit = {
windowExpression.map {
windowExpr =>
@@ -721,7 +721,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
val aggWindowFunc = wf.asInstanceOf[AggregateWindowFunction]
val frame = aggWindowFunc.frame.asInstanceOf[SpecifiedWindowFrame]
val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
- WindowFunctionsBuilder.create(args, aggWindowFunc).toInt,
+ WindowFunctionsBuilder.create(context, aggWindowFunc).toInt,
new JArrayList[ExpressionNode](),
columnName,
ConverterUtils.getTypeNode(aggWindowFunc.dataType,
aggWindowFunc.nullable),
@@ -745,10 +745,10 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
childrenNodeList.add(
ExpressionConverter
.replaceWithExpressionTransformer(expr,
originalInputAttributes)
- .doTransform(args)))
+ .doTransform(context)))
val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
- CHExpressions.createAggregateFunction(args,
aggExpression.aggregateFunction).toInt,
+ CHExpressions.createAggregateFunction(context,
aggExpression.aggregateFunction).toInt,
childrenNodeList,
columnName,
ConverterUtils.getTypeNode(aggExpression.dataType,
aggExpression.nullable),
@@ -784,21 +784,21 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
.replaceWithExpressionTransformer(
offsetWf.input,
attributeSeq = originalInputAttributes)
- .doTransform(args))
+ .doTransform(context))
childrenNodeList.add(
ExpressionConverter
.replaceWithExpressionTransformer(
offsetWf.offset,
attributeSeq = originalInputAttributes)
- .doTransform(args))
+ .doTransform(context))
childrenNodeList.add(
ExpressionConverter
.replaceWithExpressionTransformer(
offsetWf.default,
attributeSeq = originalInputAttributes)
- .doTransform(args))
+ .doTransform(context))
val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
- WindowFunctionsBuilder.create(args, offsetWf).toInt,
+ WindowFunctionsBuilder.create(context, offsetWf).toInt,
childrenNodeList,
columnName,
ConverterUtils.getTypeNode(offsetWf.dataType, offsetWf.nullable),
@@ -812,9 +812,9 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
val frame =
wExpression.windowSpec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
val childrenNodeList = new JArrayList[ExpressionNode]()
val literal = buckets.asInstanceOf[Literal]
- childrenNodeList.add(LiteralTransformer(literal).doTransform(args))
+
childrenNodeList.add(LiteralTransformer(literal).doTransform(context))
val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
- WindowFunctionsBuilder.create(args, wf).toInt,
+ WindowFunctionsBuilder.create(context, wf).toInt,
childrenNodeList,
columnName,
ConverterUtils.getTypeNode(wf.dataType, wf.nullable),
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHTransformerApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHTransformerApi.scala
index bec88d13d6..906d6d9ef7 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHTransformerApi.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHTransformerApi.scala
@@ -19,6 +19,7 @@ package org.apache.gluten.backendsapi.clickhouse
import org.apache.gluten.backendsapi.TransformerApi
import org.apache.gluten.execution.{CHHashAggregateExecTransformer,
WriteFilesExecTransformer}
import org.apache.gluten.expression.ConverterUtils
+import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.expression.{BooleanLiteralNode,
ExpressionBuilder, ExpressionNode}
import org.apache.gluten.utils.{CHInputPartitionsUtil, ExpressionDocUtil}
@@ -211,16 +212,14 @@ class CHTransformerApi extends TransformerApi with Logging {
}
override def createCheckOverflowExprNode(
- args: java.lang.Object,
+ context: SubstraitContext,
substraitExprName: String,
childNode: ExpressionNode,
childResultType: DataType,
dataType: DecimalType,
nullable: Boolean,
nullOnOverflow: Boolean): ExpressionNode = {
- val functionMap = args.asInstanceOf[java.util.HashMap[String,
java.lang.Long]]
- val functionId = ExpressionBuilder.newScalarFunction(
- functionMap,
+ val functionId = context.registerFunction(
ConverterUtils.makeFuncName(
substraitExprName,
Seq(dataType, BooleanType),
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHValidatorApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHValidatorApi.scala
index c2b52d5919..49efc676c3 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHValidatorApi.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHValidatorApi.scala
@@ -85,7 +85,7 @@ class CHValidatorApi extends ValidatorApi with AdaptiveSparkPlanHelper with Logg
expr =>
val node = ExpressionConverter
.replaceWithExpressionTransformer(expr, outputAttributes)
- .doTransform(substraitContext.registeredFunction)
+ .doTransform(substraitContext)
node.isInstanceOf[SelectionNode]
}
if (allSelectionNodes || supportShuffleWithProject(outputPartitioning,
child)) {
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHAggregateGroupLimitExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHAggregateGroupLimitExecTransformer.scala
index 83bb33bfa2..fab7f41e01 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHAggregateGroupLimitExecTransformer.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHAggregateGroupLimitExecTransformer.scala
@@ -86,13 +86,12 @@ case class CHAggregateGroupLimitExecTransformer(
operatorId: Long,
input: RelNode,
validation: Boolean): RelNode = {
- val args = context.registeredFunction
// Partition By Expressions
val partitionsExpressions = partitionSpec
.map(
ExpressionConverter
.replaceWithExpressionTransformer(_, attributeSeq = child.output)
- .doTransform(args))
+ .doTransform(context))
.asJava
// Sort By Expressions
@@ -102,7 +101,7 @@ case class CHAggregateGroupLimitExecTransformer(
val builder = SortField.newBuilder()
val exprNode = ExpressionConverter
.replaceWithExpressionTransformer(order.child, attributeSeq =
child.output)
- .doTransform(args)
+ .doTransform(context)
builder.setExpr(exprNode.toProtobuf)
builder.setDirectionValue(SortExecTransformer.transformSortDirection(order))
builder.build()
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
index 1cabb5bc75..f941df18eb 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
@@ -238,7 +238,6 @@ case class CHHashAggregateExecTransformer(
operatorId: Long,
input: RelNode = null,
validation: Boolean): RelNode = {
- val args = context.registeredFunction
// Get the grouping nodes.
val groupingList = new util.ArrayList[ExpressionNode]()
groupingExpressions.foreach(
@@ -247,7 +246,7 @@ case class CHHashAggregateExecTransformer(
// may be different for each backend.
val exprNode = ExpressionConverter
.replaceWithExpressionTransformer(expr, childOutput)
- .doTransform(args)
+ .doTransform(context)
groupingList.add(exprNode)
})
// Get the aggregate function nodes.
@@ -267,7 +266,7 @@ case class CHHashAggregateExecTransformer(
if (aggExpr.filter.isDefined) {
val exprNode = ExpressionConverter
.replaceWithExpressionTransformer(aggExpr.filter.get, childOutput)
- .doTransform(args)
+ .doTransform(context)
aggFilterList.add(exprNode)
} else {
aggFilterList.add(null)
@@ -281,7 +280,7 @@ case class CHHashAggregateExecTransformer(
expr => {
ExpressionConverter
.replaceWithExpressionTransformer(expr, childOutput)
- .doTransform(args)
+ .doTransform(context)
})
val extraNodes = aggregateFunc match {
@@ -290,7 +289,7 @@ case class CHHashAggregateExecTransformer(
Seq(
ExpressionConverter
.replaceWithExpressionTransformer(relativeSDLiteral,
child.output)
- .doTransform(args))
+ .doTransform(context))
case _ => Seq.empty
}
@@ -311,12 +310,12 @@ case class CHHashAggregateExecTransformer(
child.asInstanceOf[BaseAggregateExec].groupingExpressions,
child.asInstanceOf[BaseAggregateExec].aggregateExpressions)
)
- Seq(aggTypesExpr.doTransform(args))
+ Seq(aggTypesExpr.doTransform(context))
case Final | PartialMerge =>
Seq(
ExpressionConverter
.replaceWithExpressionTransformer(aggExpr.resultAttribute,
originalInputAttributes)
- .doTransform(args))
+ .doTransform(context))
case other =>
throw new GlutenNotSupportException(s"$other not supported.")
}
@@ -324,7 +323,7 @@ case class CHHashAggregateExecTransformer(
childrenNodeList.add(node)
}
val aggFunctionNode = ExpressionBuilder.makeAggregateFunction(
- CHExpressions.createAggregateFunction(args, aggregateFunc),
+ CHExpressions.createAggregateFunction(context, aggregateFunc),
childrenNodeList,
modeToKeyWord(aggExpr.mode),
ConverterUtils.getTypeNode(aggregateFunc.dataType,
aggregateFunc.nullable)
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHWindowGroupLimitExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHWindowGroupLimitExecTransformer.scala
index 793d733abf..1111102e89 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHWindowGroupLimitExecTransformer.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHWindowGroupLimitExecTransformer.scala
@@ -96,13 +96,12 @@ case class CHWindowGroupLimitExecTransformer(
operatorId: Long,
input: RelNode,
validation: Boolean): RelNode = {
- val args = context.registeredFunction
// Partition By Expressions
val partitionsExpressions = partitionSpec
.map(
ExpressionConverter
.replaceWithExpressionTransformer(_, attributeSeq = child.output)
- .doTransform(args))
+ .doTransform(context))
.asJava
// Sort By Expressions
@@ -112,7 +111,7 @@ case class CHWindowGroupLimitExecTransformer(
val builder = SortField.newBuilder()
val exprNode = ExpressionConverter
.replaceWithExpressionTransformer(order.child, attributeSeq =
child.output)
- .doTransform(args)
+ .doTransform(context)
builder.setExpr(exprNode.toProtobuf)
builder.setDirectionValue(SortExecTransformer.transformSortDirection(order))
builder.build()
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressionTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressionTransformer.scala
index 0f6c4e05a0..c5111fef83 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressionTransformer.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressionTransformer.scala
@@ -19,6 +19,7 @@ package org.apache.gluten.expression
import org.apache.gluten.backendsapi.clickhouse.CHConfig
import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.expression.ConverterUtils.FunctionConfig
+import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.expression._
import org.apache.spark.sql.catalyst.expressions._
@@ -39,7 +40,7 @@ case class CHTruncTimestampTransformer(
extends ExpressionTransformer {
override def children: Seq[ExpressionTransformer] = format :: timestamp ::
Nil
- override def doTransform(args: java.lang.Object): ExpressionNode = {
+ override def doTransform(context: SubstraitContext): ExpressionNode = {
// The format must be constant string in the function date_trunc of ch.
if (!original.format.foldable) {
throw new GlutenNotSupportException(s"The format ${original.format} must
be constant string.")
@@ -78,20 +79,17 @@ case class CHTruncTimestampTransformer(
s"${timeZoneId.get}.")
}
- val timestampNode = timestamp.doTransform(args)
+ val timestampNode = timestamp.doTransform(context)
val lowerFormatNode = ExpressionBuilder.makeStringLiteral(newFormatStr)
- val functionMap = args.asInstanceOf[java.util.HashMap[String,
java.lang.Long]]
-
val dataTypes = if (timeZoneId.nonEmpty) {
Seq(original.format.dataType, original.timestamp.dataType, StringType)
} else {
Seq(original.format.dataType, original.timestamp.dataType)
}
- val functionId = ExpressionBuilder.newScalarFunction(
- functionMap,
- ConverterUtils.makeFuncName(substraitExprName, dataTypes))
+ val functionId =
+ context.registerFunction(ConverterUtils.makeFuncName(substraitExprName,
dataTypes))
val expressionNodes = new java.util.ArrayList[ExpressionNode]()
expressionNodes.add(lowerFormatNode)
@@ -114,10 +112,10 @@ case class CHStringTranslateTransformer(
extends ExpressionTransformer {
override def children: Seq[ExpressionTransformer] = srcExpr :: matchingExpr
:: replaceExpr :: Nil
- override def doTransform(args: java.lang.Object): ExpressionNode = {
+ override def doTransform(context: SubstraitContext): ExpressionNode = {
// In CH, translateUTF8 requires matchingExpr and replaceExpr argument
have the same length
- val matchingNode = matchingExpr.doTransform(args)
- val replaceNode = replaceExpr.doTransform(args)
+ val matchingNode = matchingExpr.doTransform(context)
+ val replaceNode = replaceExpr.doTransform(context)
if (
!matchingNode.isInstanceOf[StringLiteralNode] ||
!replaceNode.isInstanceOf[StringLiteralNode]
@@ -125,7 +123,7 @@ case class CHStringTranslateTransformer(
throw new GlutenNotSupportException(s"$original not supported yet.")
}
- super.doTransform(args)
+ super.doTransform(context)
}
}
@@ -136,15 +134,11 @@ case class CHPosExplodeTransformer(
attributeSeq: Seq[Attribute])
extends UnaryExpressionTransformer {
- override def doTransform(args: java.lang.Object): ExpressionNode = {
- val childNode: ExpressionNode = child.doTransform(args)
- val funcMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]]
- val funcId = ExpressionBuilder.newScalarFunction(
- funcMap,
- ConverterUtils.makeFuncName(
- ExpressionNames.POSEXPLODE,
- Seq(original.child.dataType),
- FunctionConfig.OPT))
+ override def doTransform(context: SubstraitContext): ExpressionNode = {
+ val childNode: ExpressionNode = child.doTransform(context)
+ val funcId = context.registerFunction(
+ ConverterUtils
+ .makeFuncName(ExpressionNames.POSEXPLODE,
Seq(original.child.dataType), FunctionConfig.OPT))
val childType = original.child.dataType
childType match {
case a: ArrayType =>
@@ -181,10 +175,10 @@ case class CHRegExpReplaceTransformer(
extends ExpressionTransformer {
override def children: Seq[ExpressionTransformer] =
childrenWithPos.dropRight(1)
- override def doTransform(args: java.lang.Object): ExpressionNode = {
+ override def doTransform(context: SubstraitContext): ExpressionNode = {
// In CH: replaceRegexpAll(subject, regexp, rep), which is equivalent
// In Spark: regexp_replace(subject, regexp, rep, pos=1)
- val posNode = childrenWithPos(3).doTransform(args)
+ val posNode = childrenWithPos(3).doTransform(context)
if (
!posNode.isInstanceOf[IntLiteralNode] ||
posNode.asInstanceOf[IntLiteralNode].getValue != 1
@@ -192,7 +186,7 @@ case class CHRegExpReplaceTransformer(
throw new UnsupportedOperationException(s"$original dose not supported
position yet.")
}
// Replace $num in rep with \num used in CH
- val repNode = childrenWithPos(2).doTransform(args)
+ val repNode = childrenWithPos(2).doTransform(context)
repNode match {
case node: StringLiteralNode =>
val strValue = node.getValue
@@ -204,19 +198,18 @@ case class CHRegExpReplaceTransformer(
FunctionConfig.OPT)
val replacedRepNode = ExpressionBuilder.makeLiteral(replacedValue,
StringType, false)
val exprNodes = Lists.newArrayList(
- childrenWithPos(0).doTransform(args),
- childrenWithPos(1).doTransform(args),
+ childrenWithPos(0).doTransform(context),
+ childrenWithPos(1).doTransform(context),
replacedRepNode)
- val functionMap = args.asInstanceOf[java.util.HashMap[String,
java.lang.Long]]
return ExpressionBuilder.makeScalarFunction(
- ExpressionBuilder.newScalarFunction(functionMap, functionName),
+ context.registerFunction(functionName),
exprNodes,
ConverterUtils.getTypeNode(original.dataType, original.nullable))
}
case _ =>
}
- super.doTransform(args)
+ super.doTransform(context)
}
}
@@ -227,11 +220,10 @@ case class GetArrayItemTransformer(
original: Expression)
extends BinaryExpressionTransformer {
- override def doTransform(args: java.lang.Object): ExpressionNode = {
+ override def doTransform(context: SubstraitContext): ExpressionNode = {
// Ignore failOnError for clickhouse backend
- val functionMap = args.asInstanceOf[java.util.HashMap[String,
java.lang.Long]]
- val leftNode = left.doTransform(args)
- var rightNode = right.doTransform(args)
+ val leftNode = left.doTransform(context)
+ var rightNode = right.doTransform(context)
val getArrayItem = original.asInstanceOf[GetArrayItem]
@@ -242,7 +234,7 @@ case class GetArrayItemTransformer(
ExpressionNames.ADD,
Seq(IntegerType, getArrayItem.right.dataType),
FunctionConfig.OPT)
- val addFunctionId = ExpressionBuilder.newScalarFunction(functionMap,
addFunctionName)
+ val addFunctionId = context.registerFunction(addFunctionName)
val literalNode = ExpressionBuilder.makeLiteral(1, IntegerType, false)
rightNode = ExpressionBuilder.makeScalarFunction(
addFunctionId,
@@ -255,7 +247,7 @@ case class GetArrayItemTransformer(
FunctionConfig.OPT)
val exprNodes = Lists.newArrayList(leftNode, rightNode)
ExpressionBuilder.makeScalarFunction(
- ExpressionBuilder.newScalarFunction(functionMap, functionName),
+ context.registerFunction(functionName),
exprNodes,
ConverterUtils.getTypeNode(getArrayItem.dataType, getArrayItem.nullable))
}
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressions.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressions.scala
index fa8a5763a6..70d45a4e52 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressions.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressions.scala
@@ -18,26 +18,24 @@ package org.apache.gluten.expression
import org.apache.gluten.expression.ConverterUtils.FunctionConfig
import org.apache.gluten.extension.ExpressionExtensionTrait
-import org.apache.gluten.substrait.expression.ExpressionBuilder
+import org.apache.gluten.substrait.SubstraitContext
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
// Static helper object for handling expressions that are specifically used in
CH backend.
object CHExpressions {
// Since https://github.com/apache/incubator-gluten/pull/1937.
- def createAggregateFunction(args: java.lang.Object, aggregateFunc:
AggregateFunction): Long = {
- val functionMap = args.asInstanceOf[java.util.HashMap[String,
java.lang.Long]]
+ def createAggregateFunction(context: SubstraitContext, aggregateFunc:
AggregateFunction): Long = {
val expressionExtensionTransformer =
ExpressionExtensionTrait.findExpressionExtension(aggregateFunc.getClass)
if (expressionExtensionTransformer.nonEmpty) {
val (substraitAggFuncName, inputTypes) =
expressionExtensionTransformer.get.buildCustomAggregateFunction(aggregateFunc)
assert(substraitAggFuncName.isDefined)
- return ExpressionBuilder.newScalarFunction(
- functionMap,
+ return context.registerFunction(
ConverterUtils.makeFuncName(substraitAggFuncName.get, inputTypes,
FunctionConfig.REQ))
}
- AggregateFunctionsBuilder.create(args, aggregateFunc)
+ AggregateFunctionsBuilder.create(context, aggregateFunc)
}
}
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/PlanNodesUtil.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/PlanNodesUtil.scala
index d6511f7a4a..5026cfa4f2 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/PlanNodesUtil.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/PlanNodesUtil.scala
@@ -42,13 +42,12 @@ object PlanNodesUtil {
// project
operatorId =
context.nextOperatorId("ClickHouseBuildSideRelationProjection")
- val args = context.registeredFunction
val columnarProjExpr = ExpressionConverter
.replaceWithExpressionTransformer(key, attributeSeq = output)
val projExprNodeList = new java.util.ArrayList[ExpressionNode]()
- columnarProjExpr.foreach(e => projExprNodeList.add(e.doTransform(args)))
+ columnarProjExpr.foreach(e => projExprNodeList.add(e.doTransform(context)))
PlanBuilder.makePlan(
context,
diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/utils/RangePartitionerBoundsGenerator.scala b/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/utils/RangePartitionerBoundsGenerator.scala
index 694035b878..adba216b43 100644
--- a/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/utils/RangePartitionerBoundsGenerator.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/utils/RangePartitionerBoundsGenerator.scala
@@ -120,10 +120,9 @@ class RangePartitionerBoundsGenerator[K: Ordering: ClassTag, V](
context: SubstraitContext,
ordering: SortOrder,
attributes: Seq[Attribute]): Int = {
- val funcs = context.registeredFunction
val projExprNode = ExpressionConverter
.replaceWithExpressionTransformer(ordering.child, attributes)
- .doTransform(funcs)
+ .doTransform(context)
val pb = projExprNode.toProtobuf
if (!pb.hasSelection) {
throw new IllegalArgumentException(s"A sorting field should be an
attribute")
@@ -135,7 +134,6 @@ class RangePartitionerBoundsGenerator[K: Ordering: ClassTag, V](
private def buildProjectionPlan(
context: SubstraitContext,
sortExpressions: Seq[NamedExpression]): PlanNode = {
- val args = context.registeredFunction
val columnarProjExprs = sortExpressions.map(
expr => {
ExpressionConverter
@@ -143,7 +141,7 @@ class RangePartitionerBoundsGenerator[K: Ordering: ClassTag, V](
})
val projExprNodeList = new java.util.ArrayList[ExpressionNode]()
for (expr <- columnarProjExprs) {
- projExprNodeList.add(expr.doTransform(args))
+ projExprNodeList.add(expr.doTransform(context))
}
val projectRel = RelBuilder.makeProjectRel(null, projExprNodeList,
context, 0)
val outNames = new util.ArrayList[String]
diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/clickhouse/utils/MergeTreePartsPartitionsUtil.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/clickhouse/utils/MergeTreePartsPartitionsUtil.scala
index 096212b80b..c6511a1d70 100644
--- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/clickhouse/utils/MergeTreePartsPartitionsUtil.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/clickhouse/utils/MergeTreePartsPartitionsUtil.scala
@@ -628,7 +628,7 @@ object MergeTreePartsPartitionsUtil extends Logging {
typeNodes,
nameList,
columnTypeNodes,
-
transformer.map(_.doTransform(substraitContext.registeredFunction)).orNull,
+ transformer.map(_.doTransform(substraitContext)).orNull,
extensionNode,
substraitContext,
substraitContext.nextOperatorId("readRel")
diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxTransformerApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxTransformerApi.scala
index 9949f8822a..bfc4a27b51 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxTransformerApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxTransformerApi.scala
@@ -23,6 +23,7 @@ import org.apache.gluten.execution.datasource.GlutenFormatFactory
import org.apache.gluten.expression.ConverterUtils
import org.apache.gluten.proto.ConfigMap
import org.apache.gluten.runtime.Runtimes
+import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.expression.{ExpressionBuilder,
ExpressionNode}
import org.apache.gluten.utils.InputPartitionsUtil
import org.apache.gluten.vectorized.PlanEvaluatorJniWrapper
@@ -73,7 +74,7 @@ class VeloxTransformerApi extends TransformerApi with Logging {
}
override def createCheckOverflowExprNode(
- args: java.lang.Object,
+ context: SubstraitContext,
substraitExprName: String,
childNode: ExpressionNode,
childResultType: DataType,
diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
index 70a81e1cdb..fbf32cc45e 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
@@ -37,8 +37,7 @@ import org.apache.spark.sql.types._
import com.google.protobuf.StringValue
-import java.lang.{Long => JLong}
-import java.util.{ArrayList => JArrayList, HashMap => JHashMap, List => JList}
+import java.util.{ArrayList => JArrayList, List => JList}
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
@@ -201,7 +200,7 @@ abstract class HashAggregateExecTransformer(
// Create aggregate function node and add to list.
private def addFunctionNode(
- args: java.lang.Object,
+ context: SubstraitContext,
aggregateFunction: AggregateFunction,
childrenNodeList: JList[ExpressionNode],
aggregateMode: AggregateMode,
@@ -212,7 +211,7 @@ abstract class HashAggregateExecTransformer(
aggregateMode match {
case Partial | PartialMerge =>
val aggFunctionNode = ExpressionBuilder.makeAggregateFunction(
- VeloxAggregateFunctionsBuilder.create(args, aggregateFunction,
aggregateMode),
+ VeloxAggregateFunctionsBuilder.create(context, aggregateFunction,
aggregateMode),
childrenNodeList,
modeKeyWord,
VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction)
@@ -220,7 +219,7 @@ abstract class HashAggregateExecTransformer(
aggregateNodeList.add(aggFunctionNode)
case Final | Complete =>
val aggFunctionNode = ExpressionBuilder.makeAggregateFunction(
- VeloxAggregateFunctionsBuilder.create(args, aggregateFunction,
aggregateMode),
+ VeloxAggregateFunctionsBuilder.create(context, aggregateFunction,
aggregateMode),
childrenNodeList,
modeKeyWord,
ConverterUtils.getTypeNode(aggregateFunction.dataType,
aggregateFunction.nullable)
@@ -238,7 +237,7 @@ abstract class HashAggregateExecTransformer(
aggregateMode match {
case Partial | PartialMerge =>
val partialNode = ExpressionBuilder.makeAggregateFunction(
- VeloxAggregateFunctionsBuilder.create(args, aggregateFunction,
aggregateMode),
+ VeloxAggregateFunctionsBuilder.create(context,
aggregateFunction, aggregateMode),
childrenNodeList,
modeKeyWord,
ConverterUtils.getTypeNode(
@@ -248,7 +247,7 @@ abstract class HashAggregateExecTransformer(
aggregateNodeList.add(partialNode)
case Final | Complete =>
val aggFunctionNode = ExpressionBuilder.makeAggregateFunction(
- VeloxAggregateFunctionsBuilder.create(args, aggregateFunction,
aggregateMode),
+ VeloxAggregateFunctionsBuilder.create(context,
aggregateFunction, aggregateMode),
childrenNodeList,
modeKeyWord,
ConverterUtils.getTypeNode(aggregateFunction.dataType,
aggregateFunction.nullable)
@@ -296,15 +295,14 @@ abstract class HashAggregateExecTransformer(
// Return a scalar function node representing row construct function in
Velox.
private def getRowConstructNode(
- args: java.lang.Object,
+ context: SubstraitContext,
childNodes: JList[ExpressionNode],
rowConstructAttributes: Seq[Attribute],
aggFunc: AggregateFunction): ScalarFunctionNode = {
- val functionMap = args.asInstanceOf[JHashMap[String, JLong]]
val functionName = ConverterUtils.makeFuncName(
VeloxIntermediateData.getRowConstructFuncName(aggFunc),
rowConstructAttributes.map(attr => attr.dataType))
- val functionId = ExpressionBuilder.newScalarFunction(functionMap,
functionName)
+ val functionId = context.registerFunction(functionName)
// Use struct type to represent Velox RowType.
val structTypeNodes = rowConstructAttributes
@@ -326,7 +324,6 @@ abstract class HashAggregateExecTransformer(
operatorId: Long,
inputRel: RelNode,
validation: Boolean): RelNode = {
- val args = context.registeredFunction
// Create a projection for row construct.
val exprNodes = new JArrayList[ExpressionNode]()
groupingExpressions.foreach(
@@ -334,7 +331,7 @@ abstract class HashAggregateExecTransformer(
exprNodes.add(
ExpressionConverter
.replaceWithExpressionTransformer(expr, originalInputAttributes)
- .doTransform(args))
+ .doTransform(context))
})
for (aggregateExpression <- aggregateExpressions) {
@@ -346,7 +343,7 @@ abstract class HashAggregateExecTransformer(
.map(
ExpressionConverter
.replaceWithExpressionTransformer(_, originalInputAttributes)
- .doTransform(args)
+ .doTransform(context)
)
.asJava
exprNodes.addAll(childNodes)
@@ -387,7 +384,7 @@ abstract class HashAggregateExecTransformer(
val attr = rewrittenInputAttributes(adjustedIdx)
val aggFuncInputAttrNode = ExpressionConverter
.replaceWithExpressionTransformer(attr,
originalInputAttributes)
- .doTransform(args)
+ .doTransform(context)
val expressionNode = if (sparkType != veloxType) {
newInputAttributes +=
attr.copy(dataType = veloxType)(attr.exprId,
attr.qualifier)
@@ -403,7 +400,7 @@ abstract class HashAggregateExecTransformer(
}
}
exprNodes.add(
- getRowConstructNode(args, childNodes,
newInputAttributes.toSeq, aggFunc))
+ getRowConstructNode(context, childNodes,
newInputAttributes.toSeq, aggFunc))
case other =>
throw new GlutenNotSupportException(s"$other is not supported.")
}
@@ -415,7 +412,7 @@ abstract class HashAggregateExecTransformer(
.map(
ExpressionConverter
.replaceWithExpressionTransformer(_, originalInputAttributes)
- .doTransform(args)
+ .doTransform(context)
)
.asJava
exprNodes.addAll(childNodes)
@@ -469,7 +466,7 @@ abstract class HashAggregateExecTransformer(
throw new GlutenNotSupportException(
s"$aggFunc of ${aggExpr.mode.toString} is not supported.")
}
- addFunctionNode(args, aggFunc, childrenNodes, aggExpr.mode,
aggregateFunctionList)
+ addFunctionNode(context, aggFunc, childrenNodes, aggExpr.mode,
aggregateFunctionList)
})
val extensionNode = getAdvancedExtension()
@@ -566,7 +563,6 @@ abstract class HashAggregateExecTransformer(
operatorId: Long,
input: RelNode = null,
validation: Boolean): RelNode = {
- val args = context.registeredFunction
// Get the grouping nodes.
// Use 'child.output' as based Seq[Attribute], the originalInputAttributes
// may be different for each backend.
@@ -574,7 +570,7 @@ abstract class HashAggregateExecTransformer(
.map(
ExpressionConverter
.replaceWithExpressionTransformer(_, child.output)
- .doTransform(args))
+ .doTransform(context))
.asJava
// Get the aggregate function nodes.
val aggFilterList = new JArrayList[ExpressionNode]()
@@ -584,7 +580,7 @@ abstract class HashAggregateExecTransformer(
if (aggExpr.filter.isDefined) {
val exprNode = ExpressionConverter
.replaceWithExpressionTransformer(aggExpr.filter.get, child.output)
- .doTransform(args)
+ .doTransform(context)
aggFilterList.add(exprNode)
} else {
// The number of filters should be aligned with that of aggregate
functions.
@@ -597,7 +593,7 @@ abstract class HashAggregateExecTransformer(
expr => {
ExpressionConverter
.replaceWithExpressionTransformer(expr,
originalInputAttributes)
- .doTransform(args)
+ .doTransform(context)
})
case PartialMerge | Final =>
rewriteAggBufferAttributes(
@@ -606,13 +602,13 @@ abstract class HashAggregateExecTransformer(
attr =>
ExpressionConverter
.replaceWithExpressionTransformer(attr,
originalInputAttributes)
- .doTransform(args)
+ .doTransform(context)
}
case other =>
throw new GlutenNotSupportException(s"$other not supported.")
}
addFunctionNode(
- args,
+ context,
aggregateFunc,
childrenNodes.asJava,
aggExpr.mode,
@@ -662,8 +658,8 @@ object VeloxAggregateFunctionsBuilder {
/**
* Create a scalar function for the input aggregate function.
- * @param args:
- * the function map.
+ * @param context:
+ * the SubstraitContext.
* @param aggregateFunc:
* the input aggregate function.
* @param mode:
@@ -671,10 +667,9 @@ object VeloxAggregateFunctionsBuilder {
* @return
*/
def create(
- args: java.lang.Object,
+ context: SubstraitContext,
aggregateFunc: AggregateFunction,
mode: AggregateMode): Long = {
- val functionMap = args.asInstanceOf[JHashMap[String, JLong]]
val (sigName, aggFunc) =
try {
(AggregateFunctionsBuilder.getSubstraitFunctionName(aggregateFunc),
aggregateFunc)
@@ -688,8 +683,7 @@ object VeloxAggregateFunctionsBuilder {
case e: Throwable => throw e
}
- ExpressionBuilder.newScalarFunction(
- functionMap,
+ context.registerFunction(
ConverterUtils.makeFuncName(
// Substrait-to-Velox procedure will choose appropriate companion
function if needed.
sigName,
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/execution/TopNTransformer.scala
b/backends-velox/src/main/scala/org/apache/gluten/execution/TopNTransformer.scala
index f3bc929d7e..50e7cf9c51 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/execution/TopNTransformer.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/execution/TopNTransformer.scala
@@ -87,13 +87,12 @@ case class TopNTransformer(
inputAttributes: Seq[Attribute],
input: RelNode,
validation: Boolean): RelNode = {
- val args = context.registeredFunction
val sortFieldList = sortOrder.map {
order =>
val builder = SortField.newBuilder()
val exprNode = ExpressionConverter
.replaceWithExpressionTransformer(order.child, attributeSeq =
child.output)
- .doTransform(args)
+ .doTransform(context)
builder.setExpr(exprNode.toProtobuf)
builder.setDirectionValue(SortExecTransformer.transformSortDirection(order))
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala
b/backends-velox/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala
index 1c47ff2a1d..a5e77920e4 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala
@@ -19,13 +19,14 @@ package org.apache.gluten.expression
import org.apache.gluten.expression.ConverterUtils.FunctionConfig
import
org.apache.gluten.expression.ExpressionConverter.replaceWithExpressionTransformer
import org.apache.gluten.substrait.`type`.StructNode
+import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.expression._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{IntegerType, LongType}
-import java.lang.{Integer => JInteger, Long => JLong}
-import java.util.{ArrayList => JArrayList, HashMap => JHashMap}
+import java.lang.{Integer => JInteger}
+import java.util.{ArrayList => JArrayList}
import scala.language.existentials
@@ -35,8 +36,8 @@ case class VeloxAliasTransformer(
original: Expression)
extends UnaryExpressionTransformer {
- override def doTransform(args: java.lang.Object): ExpressionNode = {
- child.doTransform(args)
+ override def doTransform(context: SubstraitContext): ExpressionNode = {
+ child.doTransform(context)
}
}
@@ -58,8 +59,8 @@ case class VeloxGetStructFieldTransformer(
extends BinaryExpressionTransformer {
override def left: ExpressionTransformer = child
override def right: ExpressionTransformer = LiteralTransformer(ordinal)
- override def doTransform(args: Object): ExpressionNode = {
- val childNode = child.doTransform(args)
+ override def doTransform(context: SubstraitContext): ExpressionNode = {
+ val childNode = child.doTransform(context)
childNode match {
case node: StructLiteralNode =>
node.getFieldLiteral(ordinal)
@@ -71,7 +72,7 @@ case class VeloxGetStructFieldTransformer(
node.getTypeNode.asInstanceOf[StructNode].getFieldTypes.get(ordinal)
ExpressionBuilder.makeNullLiteral(nodeType)
case _ =>
- super.doTransform(args)
+ super.doTransform(context)
}
}
}
@@ -82,7 +83,7 @@ case class VeloxHashExpressionTransformer(
original: HashExpression[_])
extends ExpressionTransformer {
- override def doTransform(args: java.lang.Object): ExpressionNode = {
+ override def doTransform(context: SubstraitContext): ExpressionNode = {
// As of Spark 3.3, there are 3 kinds of HashExpression.
// HiveHash is not supported in native backend and will fail native
validation.
val (seedNode, seedType) = original match {
@@ -98,13 +99,12 @@ case class VeloxHashExpressionTransformer(
nodes.add(seedNode)
children.foreach(
expression => {
- nodes.add(expression.doTransform(args))
+ nodes.add(expression.doTransform(context))
})
val childrenTypes = seedType +: original.children.map(child =>
child.dataType)
- val functionMap = args.asInstanceOf[JHashMap[String, JLong]]
val functionName =
ConverterUtils.makeFuncName(substraitExprName, childrenTypes,
FunctionConfig.OPT)
- val functionId = ExpressionBuilder.newScalarFunction(functionMap,
functionName)
+ val functionId = context.registerFunction(functionName)
val typeNode = ConverterUtils.getTypeNode(original.dataType,
original.nullable)
ExpressionBuilder.makeScalarFunction(functionId, nodes, typeNode)
}
diff --git
a/gluten-delta/src-delta-32/main/scala/org/apache/gluten/execution/DeltaFilterExecTransformer.scala
b/gluten-delta/src-delta-32/main/scala/org/apache/gluten/execution/DeltaFilterExecTransformer.scala
index 0c8cd54902..b71ff4ca4b 100644
---
a/gluten-delta/src-delta-32/main/scala/org/apache/gluten/execution/DeltaFilterExecTransformer.scala
+++
b/gluten-delta/src-delta-32/main/scala/org/apache/gluten/execution/DeltaFilterExecTransformer.scala
@@ -49,17 +49,16 @@ case class DeltaFilterExecTransformer(condition:
Expression, child: SparkPlan)
input: RelNode,
validation: Boolean): RelNode = {
assert(condExpr != null)
- val args = context.registeredFunction
val condExprNode = condExpr match {
case IncrementMetric(child, metric) =>
extraMetrics :+= (condExpr.prettyName, metric)
ExpressionConverter
.replaceWithExpressionTransformer(child, attributeSeq =
originalInputAttributes)
- .doTransform(args)
+ .doTransform(context)
case _ =>
ExpressionConverter
.replaceWithExpressionTransformer(condExpr, attributeSeq =
originalInputAttributes)
- .doTransform(args)
+ .doTransform(context)
}
if (!validation) {
diff --git
a/gluten-delta/src-delta-32/main/scala/org/apache/gluten/execution/DeltaProjectExecTransformer.scala
b/gluten-delta/src-delta-32/main/scala/org/apache/gluten/execution/DeltaProjectExecTransformer.scala
index a2be01a1f0..39e8d5bfa9 100644
---
a/gluten-delta/src-delta-32/main/scala/org/apache/gluten/execution/DeltaProjectExecTransformer.scala
+++
b/gluten-delta/src-delta-32/main/scala/org/apache/gluten/execution/DeltaProjectExecTransformer.scala
@@ -49,11 +49,10 @@ case class DeltaProjectExecTransformer(projectList:
Seq[NamedExpression], child:
operatorId: Long,
input: RelNode,
validation: Boolean): RelNode = {
- val args = context.registeredFunction
val newProjectList = genNewProjectList(projectList)
val columnarProjExprs: Seq[ExpressionTransformer] = ExpressionConverter
.replaceWithExpressionTransformer(newProjectList, attributeSeq =
originalInputAttributes)
- val projExprNodeList = columnarProjExprs.map(_.doTransform(args)).asJava
+ val projExprNodeList = columnarProjExprs.map(_.doTransform(context)).asJava
val emitStartIndex = originalInputAttributes.size
if (!validation) {
RelBuilder.makeProjectRel(input, projExprNodeList, context, operatorId,
emitStartIndex)
diff --git
a/gluten-substrait/src/main/java/org/apache/gluten/substrait/expression/ExpressionBuilder.java
b/gluten-substrait/src/main/java/org/apache/gluten/substrait/expression/ExpressionBuilder.java
index 1e6c58f682..07c73452bd 100644
---
a/gluten-substrait/src/main/java/org/apache/gluten/substrait/expression/ExpressionBuilder.java
+++
b/gluten-substrait/src/main/java/org/apache/gluten/substrait/expression/ExpressionBuilder.java
@@ -36,16 +36,6 @@ import java.util.Map;
public class ExpressionBuilder {
private ExpressionBuilder() {}
- public static Long newScalarFunction(Map<String, Long> functionMap, String
functionName) {
- if (!functionMap.containsKey(functionName)) {
- Long functionId = (long) functionMap.size();
- functionMap.put(functionName, functionId);
- return functionId;
- } else {
- return functionMap.get(functionName);
- }
- }
-
public static NullLiteralNode makeNullLiteral(TypeNode typeNode) {
return new NullLiteralNode(typeNode);
}
diff --git
a/gluten-substrait/src/main/java/org/apache/gluten/substrait/expression/WindowFunctionNode.java
b/gluten-substrait/src/main/java/org/apache/gluten/substrait/expression/WindowFunctionNode.java
index b9f1fbc126..a114c6050a 100644
---
a/gluten-substrait/src/main/java/org/apache/gluten/substrait/expression/WindowFunctionNode.java
+++
b/gluten-substrait/src/main/java/org/apache/gluten/substrait/expression/WindowFunctionNode.java
@@ -18,6 +18,7 @@ package org.apache.gluten.substrait.expression;
import org.apache.gluten.exception.GlutenException;
import org.apache.gluten.expression.ExpressionConverter;
+import org.apache.gluten.substrait.SubstraitContext;
import org.apache.gluten.substrait.type.TypeNode;
import io.substrait.proto.Expression;
@@ -29,7 +30,6 @@ import
org.apache.spark.sql.catalyst.expressions.PreComputeRangeFrameBound;
import java.io.Serializable;
import java.util.ArrayList;
-import java.util.HashMap;
import java.util.List;
import scala.collection.JavaConverters;
@@ -104,7 +104,7 @@ public class WindowFunctionNode implements Serializable {
JavaConverters.asScalaIteratorConverter(originalInputAttributes.iterator())
.asScala()
.toSeq())
- .doTransform(new HashMap<String, Long>());
+ .doTransform(new SubstraitContext());
Long offset = Long.valueOf(boundType.eval(null).toString());
if (offset < 0) {
Expression.WindowFunction.Bound.Preceding.Builder
refPrecedingBuilder =
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
index 1bb5a255f5..a798053f6f 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
@@ -20,6 +20,7 @@ import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.execution._
import org.apache.gluten.expression._
import org.apache.gluten.sql.shims.SparkShimLoader
+import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.expression.{ExpressionBuilder,
ExpressionNode, WindowFunctionNode}
import org.apache.spark.ShuffleDependency
@@ -46,8 +47,7 @@ import org.apache.spark.sql.types.{DecimalType, LongType,
NullType, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch
import java.io.{ObjectInputStream, ObjectOutputStream}
-import java.lang.{Long => JLong}
-import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
+import java.util.{ArrayList => JArrayList, List => JList}
import scala.collection.JavaConverters._
@@ -483,7 +483,7 @@ trait SparkPlanExecApi {
windowExpression: Seq[NamedExpression],
windowExpressionNodes: JList[WindowFunctionNode],
originalInputAttributes: Seq[Attribute],
- args: JMap[String, JLong]): Unit = {
+ context: SubstraitContext): Unit = {
windowExpression.map {
windowExpr =>
val aliasExpr = windowExpr.asInstanceOf[Alias]
@@ -494,7 +494,7 @@ trait SparkPlanExecApi {
val aggWindowFunc = wf.asInstanceOf[AggregateWindowFunction]
val frame = aggWindowFunc.frame.asInstanceOf[SpecifiedWindowFrame]
val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
- WindowFunctionsBuilder.create(args, aggWindowFunc).toInt,
+ WindowFunctionsBuilder.create(context, aggWindowFunc).toInt,
new JArrayList[ExpressionNode](),
columnName,
ConverterUtils.getTypeNode(aggWindowFunc.dataType,
aggWindowFunc.nullable),
@@ -516,11 +516,11 @@ trait SparkPlanExecApi {
.map(
ExpressionConverter
.replaceWithExpressionTransformer(_, originalInputAttributes)
- .doTransform(args))
+ .doTransform(context))
.asJava
val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
- AggregateFunctionsBuilder.create(args,
aggExpression.aggregateFunction).toInt,
+ AggregateFunctionsBuilder.create(context,
aggExpression.aggregateFunction).toInt,
childrenNodeList,
columnName,
ConverterUtils.getTypeNode(aggExpression.dataType,
aggExpression.nullable),
@@ -539,7 +539,7 @@ trait SparkPlanExecApi {
.replaceWithExpressionTransformer(
offsetWf.input,
attributeSeq = originalInputAttributes)
- .doTransform(args))
+ .doTransform(context))
// Spark only accepts foldable offset. Converts it to LongType
literal.
val offset = offsetWf.offset.eval(EmptyRow).asInstanceOf[Int]
// Velox only allows negative offset.
WindowFunctionsBuilder#create converts
@@ -554,10 +554,10 @@ trait SparkPlanExecApi {
.replaceWithExpressionTransformer(
offsetWf.default,
attributeSeq = originalInputAttributes)
- .doTransform(args))
+ .doTransform(context))
}
val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
- WindowFunctionsBuilder.create(args, offsetWf).toInt,
+ WindowFunctionsBuilder.create(context, offsetWf).toInt,
childrenNodeList,
columnName,
ConverterUtils.getTypeNode(offsetWf.dataType, offsetWf.nullable),
@@ -574,10 +574,10 @@ trait SparkPlanExecApi {
childrenNodeList.add(
ExpressionConverter
.replaceWithExpressionTransformer(input, attributeSeq =
originalInputAttributes)
- .doTransform(args))
- childrenNodeList.add(LiteralTransformer(offset).doTransform(args))
+ .doTransform(context))
+
childrenNodeList.add(LiteralTransformer(offset).doTransform(context))
val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
- WindowFunctionsBuilder.create(args, wf).toInt,
+ WindowFunctionsBuilder.create(context, wf).toInt,
childrenNodeList,
columnName,
ConverterUtils.getTypeNode(wf.dataType, wf.nullable),
@@ -592,9 +592,9 @@ trait SparkPlanExecApi {
val frame =
wExpression.windowSpec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
val childrenNodeList = new JArrayList[ExpressionNode]()
val literal = buckets.asInstanceOf[Literal]
- childrenNodeList.add(LiteralTransformer(literal).doTransform(args))
+
childrenNodeList.add(LiteralTransformer(literal).doTransform(context))
val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
- WindowFunctionsBuilder.create(args, wf).toInt,
+ WindowFunctionsBuilder.create(context, wf).toInt,
childrenNodeList,
columnName,
ConverterUtils.getTypeNode(wf.dataType, wf.nullable),
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/TransformerApi.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/TransformerApi.scala
index 984450bf16..92d6ebd325 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/TransformerApi.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/TransformerApi.scala
@@ -17,6 +17,7 @@
package org.apache.gluten.backendsapi
import org.apache.gluten.execution.WriteFilesExecTransformer
+import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.expression.ExpressionNode
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
@@ -61,7 +62,7 @@ trait TransformerApi {
}
def createCheckOverflowExprNode(
- args: java.lang.Object,
+ context: SubstraitContext,
substraitExprName: String,
childNode: ExpressionNode,
childResultType: DataType,
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala
index d1f3462564..b9a7ac2e83 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala
@@ -79,7 +79,7 @@ abstract class FilterExecTransformerBase(val cond:
Expression, val input: SparkP
assert(condExpr != null)
val condExprNode = ExpressionConverter
.replaceWithExpressionTransformer(condExpr, originalInputAttributes)
- .doTransform(context.registeredFunction)
+ .doTransform(context)
RelBuilder.makeFilterRel(
context,
condExprNode,
@@ -222,10 +222,9 @@ abstract class ProjectExecTransformerBase(val list:
Seq[NamedExpression], val in
operatorId: Long,
input: RelNode,
validation: Boolean): RelNode = {
- val args = context.registeredFunction
val columnarProjExprs: Seq[ExpressionTransformer] = ExpressionConverter
.replaceWithExpressionTransformer(projectList, originalInputAttributes)
- val projExprNodeList = columnarProjExprs.map(_.doTransform(args)).asJava
+ val projExprNodeList = columnarProjExprs.map(_.doTransform(context)).asJava
RelBuilder.makeProjectRel(
originalInputAttributes.asJava,
input,
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala
index 056c35a527..b1ba7820d5 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala
@@ -139,7 +139,7 @@ trait BasicScanExecTransformer extends LeafTransformSupport
with BaseDataSource
.map(ExpressionConverter.replaceAttributeReference)
.reduceLeftOption(And)
.map(ExpressionConverter.replaceWithExpressionTransformer(_, output))
- val filterNodes =
transformer.map(_.doTransform(context.registeredFunction))
+ val filterNodes = transformer.map(_.doTransform(context))
val exprNode = filterNodes.orNull
// used by CH backend
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/CartesianProductExecTransformer.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/CartesianProductExecTransformer.scala
index 9e2f12bcf8..ec71afe03c 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/CartesianProductExecTransformer.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/CartesianProductExecTransformer.scala
@@ -127,7 +127,7 @@ case class CartesianProductExecTransformer(
expr =>
ExpressionConverter
.replaceWithExpressionTransformer(expr, left.output ++ right.output)
- .doTransform(substraitContext.registeredFunction)
+ .doTransform(substraitContext)
}
val extensionNode =
JoinUtils.createExtensionNode(left.output ++ right.output, validation =
true)
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/ExpandExecTransformer.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/ExpandExecTransformer.scala
index c6936daaff..8a40fc0412 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/ExpandExecTransformer.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/ExpandExecTransformer.scala
@@ -67,7 +67,6 @@ case class ExpandExecTransformer(
operatorId: Long,
input: RelNode,
validation: Boolean): RelNode = {
- val args = context.registeredFunction
val projectSetExprNodes = new JArrayList[JList[ExpressionNode]]()
projections.foreach {
projectSet =>
@@ -76,7 +75,7 @@ case class ExpandExecTransformer(
project =>
val projectExprNode = ExpressionConverter
.replaceWithExpressionTransformer(project,
originalInputAttributes)
- .doTransform(args)
+ .doTransform(context)
projectExprNodes.add(projectExprNode)
}
projectSetExprNodes.add(projectExprNodes)
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/GenerateExecTransformerBase.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/GenerateExecTransformerBase.scala
index 698d1f14c5..20c1e088b5 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/GenerateExecTransformerBase.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/GenerateExecTransformerBase.scala
@@ -94,5 +94,5 @@ abstract class GenerateExecTransformerBase(
private def getGeneratorNode(context: SubstraitContext): ExpressionNode =
ExpressionConverter
.replaceWithExpressionTransformer(generator, child.output)
- .doTransform(context.registeredFunction)
+ .doTransform(context)
}
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala
index 86e6c1f412..439b5689e8 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala
@@ -41,9 +41,6 @@ import com.google.common.collect.Lists
import com.google.protobuf.{Any, StringValue}
import io.substrait.proto.JoinRel
-import java.lang.{Long => JLong}
-import java.util.{Map => JMap}
-
trait ColumnarShuffledJoin extends BaseJoinExec {
def isSkewJoin: Boolean
@@ -324,9 +321,8 @@ object HashJoinLikeExecTransformer {
leftType: DataType,
rightNode: ExpressionNode,
rightType: DataType,
- functionMap: JMap[String, JLong]): ExpressionNode = {
- val functionId = ExpressionBuilder.newScalarFunction(
- functionMap,
+ context: SubstraitContext): ExpressionNode = {
+ val functionId = context.registerFunction(
ConverterUtils.makeFuncName(ExpressionNames.EQUAL, Seq(leftType,
rightType)))
val expressionNodes = Lists.newArrayList(leftNode, rightNode)
@@ -338,9 +334,8 @@ object HashJoinLikeExecTransformer {
def makeAndExpression(
leftNode: ExpressionNode,
rightNode: ExpressionNode,
- functionMap: JMap[String, JLong]): ExpressionNode = {
- val functionId = ExpressionBuilder.newScalarFunction(
- functionMap,
+ context: SubstraitContext): ExpressionNode = {
+ val functionId = context.registerFunction(
ConverterUtils.makeFuncName(ExpressionNames.AND, Seq(BooleanType,
BooleanType)))
val expressionNodes = Lists.newArrayList(leftNode, rightNode)
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinUtils.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinUtils.scala
index 303c9e818f..12b544de90 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinUtils.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinUtils.scala
@@ -64,7 +64,7 @@ object JoinUtils {
ExpressionConverter
.replaceWithExpressionTransformer(expr,
partialConstructedJoinOutput)
.asInstanceOf[AttributeReferenceTransformer]
- .doTransform(substraitContext.registeredFunction),
+ .doTransform(substraitContext),
expr.dataType)
}
(keys, inputNode, inputNodeOutput)
@@ -78,7 +78,7 @@ object JoinUtils {
(
ExpressionConverter
.replaceWithExpressionTransformer(expr, inputNodeOutput)
- .doTransform(substraitContext.registeredFunction),
+ .doTransform(substraitContext),
expr.dataType))
}
val preProjectNode = RelBuilder.makeProjectRel(
@@ -100,7 +100,7 @@ object JoinUtils {
ExpressionConverter
.replaceWithExpressionTransformer(a,
partialConstructedJoinOutput)
.asInstanceOf[AttributeReferenceTransformer]
- .doTransform(substraitContext.registeredFunction),
+ .doTransform(substraitContext),
a.dataType)
case _ =>
val (key, idx) = appendedKeysAndIndices.next()
@@ -207,11 +207,9 @@ object JoinUtils {
leftType,
rightKey,
rightType,
- substraitContext.registeredFunction)
+ substraitContext)
}
- .reduce(
- (l, r) =>
- HashJoinLikeExecTransformer.makeAndExpression(l, r,
substraitContext.registeredFunction))
+ .reduce((l, r) => HashJoinLikeExecTransformer.makeAndExpression(l, r,
substraitContext))
// Create post-join filter, which will be computed in hash join.
val postJoinFilter =
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/SampleExecTransformer.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/SampleExecTransformer.scala
index 4fed8b36e9..c3a70bb81a 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/SampleExecTransformer.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/SampleExecTransformer.scala
@@ -78,7 +78,7 @@ case class SampleExecTransformer(
assert(condExpr != null)
val condExprNode = ExpressionConverter
.replaceWithExpressionTransformer(condExpr, originalInputAttributes)
- .doTransform(context.registeredFunction)
+ .doTransform(context)
RelBuilder.makeFilterRel(
context,
condExprNode,
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/SortExecTransformer.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/SortExecTransformer.scala
index 6f9564e6d5..5ee9e3f381 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/SortExecTransformer.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/SortExecTransformer.scala
@@ -68,13 +68,12 @@ case class SortExecTransformer(
operatorId: Long,
input: RelNode,
validation: Boolean): RelNode = {
- val args = context.registeredFunction
val sortFieldList = sortOrder.map {
order =>
val builder = SortField.newBuilder()
val exprNode = ExpressionConverter
.replaceWithExpressionTransformer(order.child, attributeSeq =
child.output)
- .doTransform(args)
+ .doTransform(context)
builder.setExpr(exprNode.toProtobuf)
builder.setDirectionValue(SortExecTransformer.transformSortDirection(order))
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala
index 792885ef2f..2c934b1b5c 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala
@@ -105,14 +105,13 @@ case class WindowExecTransformer(
operatorId: Long,
input: RelNode,
validation: Boolean): RelNode = {
- val args = context.registeredFunction
// WindowFunction Expressions
val windowExpressions = new JArrayList[WindowFunctionNode]()
BackendsApiManager.getSparkPlanExecApiInstance.genWindowFunctionsNode(
windowExpression,
windowExpressions,
originalInputAttributes,
- args
+ context
)
// Partition By Expressions
@@ -120,7 +119,7 @@ case class WindowExecTransformer(
.map(
ExpressionConverter
.replaceWithExpressionTransformer(_, attributeSeq = child.output)
- .doTransform(args))
+ .doTransform(context))
.asJava
// Sort By Expressions
@@ -130,7 +129,7 @@ case class WindowExecTransformer(
val builder = SortField.newBuilder()
val exprNode = ExpressionConverter
.replaceWithExpressionTransformer(order.child, attributeSeq =
child.output)
- .doTransform(args)
+ .doTransform(context)
builder.setExpr(exprNode.toProtobuf)
builder.setDirectionValue(SortExecTransformer.transformSortDirection(order))
builder.build()
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala
index d96d04dfad..5ee7bdd684 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala
@@ -91,13 +91,12 @@ case class WindowGroupLimitExecTransformer(
operatorId: Long,
input: RelNode,
validation: Boolean): RelNode = {
- val args = context.registeredFunction
// Partition By Expressions
val partitionsExpressions = partitionSpec
.map(
ExpressionConverter
.replaceWithExpressionTransformer(_, attributeSeq = child.output)
- .doTransform(args))
+ .doTransform(context))
.asJava
// Sort By Expressions
@@ -107,7 +106,7 @@ case class WindowGroupLimitExecTransformer(
val builder = SortField.newBuilder()
val exprNode = ExpressionConverter
.replaceWithExpressionTransformer(order.child, attributeSeq =
child.output)
- .doTransform(args)
+ .doTransform(context)
builder.setExpr(exprNode.toProtobuf)
builder.setDirectionValue(SortExecTransformer.transformSortDirection(order))
builder.build()
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/expression/AggregateFunctionsBuilder.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/expression/AggregateFunctionsBuilder.scala
index 15de4a734d..a567903eda 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/expression/AggregateFunctionsBuilder.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/expression/AggregateFunctionsBuilder.scala
@@ -19,15 +19,13 @@ package org.apache.gluten.expression
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.expression.ConverterUtils.FunctionConfig
-import org.apache.gluten.substrait.expression.ExpressionBuilder
+import org.apache.gluten.substrait.SubstraitContext
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.types.DataType
object AggregateFunctionsBuilder {
- def create(args: java.lang.Object, aggregateFunc: AggregateFunction): Long =
{
- val functionMap = args.asInstanceOf[java.util.HashMap[String,
java.lang.Long]]
-
+ def create(context: SubstraitContext, aggregateFunc: AggregateFunction):
Long = {
// First handle the custom aggregate functions
val substraitAggFuncName = getSubstraitFunctionName(aggregateFunc)
@@ -42,8 +40,7 @@ object AggregateFunctionsBuilder {
val inputTypes: Seq[DataType] = aggregateFunc.children.map(child =>
child.dataType)
- ExpressionBuilder.newScalarFunction(
- functionMap,
+ context.registerFunction(
ConverterUtils.makeFuncName(substraitAggFuncName, inputTypes,
FunctionConfig.REQ))
}
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ArrayExpressionTransformer.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ArrayExpressionTransformer.scala
index 2a09e039e5..765f2dafea 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ArrayExpressionTransformer.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ArrayExpressionTransformer.scala
@@ -17,6 +17,7 @@
package org.apache.gluten.expression
import org.apache.gluten.exception.GlutenNotSupportException
+import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.expression.ExpressionNode
import org.apache.spark.sql.catalyst.expressions._
@@ -27,7 +28,7 @@ case class CreateArrayTransformer(
original: CreateArray)
extends ExpressionTransformer {
- override def doTransform(args: java.lang.Object): ExpressionNode = {
+ override def doTransform(context: SubstraitContext): ExpressionNode = {
// If children is empty,
// transformation is only supported when useStringTypeWhenEmpty is false
// because ClickHouse and Velox currently doesn't support this config.
@@ -35,6 +36,6 @@ case class CreateArrayTransformer(
throw new GlutenNotSupportException(s"$original not supported yet.")
}
- super.doTransform(args)
+ super.doTransform(context)
}
}
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ConditionalTransformer.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ConditionalTransformer.scala
index 1dffd39063..1e3df12901 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ConditionalTransformer.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ConditionalTransformer.scala
@@ -16,6 +16,7 @@
*/
package org.apache.gluten.expression
+import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.expression.{ExpressionBuilder,
ExpressionNode, IfThenNode}
import org.apache.spark.sql.catalyst.expressions._
@@ -32,19 +33,19 @@ case class CaseWhenTransformer(
override def children: Seq[ExpressionTransformer] =
branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue
- override def doTransform(args: java.lang.Object): ExpressionNode = {
+ override def doTransform(context: SubstraitContext): ExpressionNode = {
// generate branches nodes
val ifNodes = new JArrayList[ExpressionNode]
val thenNodes = new JArrayList[ExpressionNode]
branches.foreach(
branch => {
- ifNodes.add(branch._1.doTransform(args))
- thenNodes.add(branch._2.doTransform(args))
+ ifNodes.add(branch._1.doTransform(context))
+ thenNodes.add(branch._2.doTransform(context))
})
val branchDataType =
original.asInstanceOf[CaseWhen].inputTypesForMerging(0)
// generate else value node, maybe null
val elseValueNode = elseValue
- .map(_.doTransform(args))
+ .map(_.doTransform(context))
.getOrElse(ExpressionBuilder.makeLiteral(null, branchDataType, true))
new IfThenNode(ifNodes, thenNodes, elseValueNode)
}
@@ -59,14 +60,14 @@ case class IfTransformer(
extends ExpressionTransformer {
override def children: Seq[ExpressionTransformer] = predicate :: trueValue
:: falseValue :: Nil
- override def doTransform(args: java.lang.Object): ExpressionNode = {
+ override def doTransform(context: SubstraitContext): ExpressionNode = {
val ifNodes = new JArrayList[ExpressionNode]
- ifNodes.add(predicate.doTransform(args))
+ ifNodes.add(predicate.doTransform(context))
val thenNodes = new JArrayList[ExpressionNode]
- thenNodes.add(trueValue.doTransform(args))
+ thenNodes.add(trueValue.doTransform(context))
- val elseValueNode = falseValue.doTransform(args)
+ val elseValueNode = falseValue.doTransform(context)
new IfThenNode(ifNodes, thenNodes, elseValueNode)
}
}
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala
index ebb9db3e82..cfda2d8782 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala
@@ -16,6 +16,7 @@
*/
package org.apache.gluten.expression
+import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.expression.{ExpressionBuilder,
ExpressionNode}
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
@@ -32,13 +33,12 @@ trait ExpressionTransformer {
def dataType: DataType = original.dataType
def nullable: Boolean = original.nullable
- def doTransform(args: java.lang.Object): ExpressionNode = {
- val functionMap = args.asInstanceOf[java.util.HashMap[String,
java.lang.Long]]
+ def doTransform(context: SubstraitContext): ExpressionNode = {
// TODO: the funcName seems can be simplified to `substraitExprName`
val funcName: String =
ConverterUtils.makeFuncName(substraitExprName,
original.children.map(_.dataType))
- val functionId = ExpressionBuilder.newScalarFunction(functionMap, funcName)
- val childNodes = children.map(_.doTransform(args)).asJava
+ val functionId = context.registerFunction(funcName)
+ val childNodes = children.map(_.doTransform(context)).asJava
val typeNode = ConverterUtils.getTypeNode(dataType, nullable)
ExpressionBuilder.makeScalarFunction(functionId, childNodes, typeNode)
}
@@ -78,7 +78,7 @@ object GenericExpressionTransformer {
case class LiteralTransformer(original: Literal) extends
LeafExpressionTransformer {
override def substraitExprName: String = "literal"
- override def doTransform(args: java.lang.Object): ExpressionNode = {
+ override def doTransform(context: SubstraitContext): ExpressionNode = {
ExpressionBuilder.makeLiteral(original.value, original.dataType,
original.nullable)
}
}
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/expression/JsonTupleExpressionTransformer.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/expression/JsonTupleExpressionTransformer.scala
index 25e3e12a53..c5978f714b 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/expression/JsonTupleExpressionTransformer.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/expression/JsonTupleExpressionTransformer.scala
@@ -18,6 +18,7 @@ package org.apache.gluten.expression
import org.apache.gluten.expression.ConverterUtils.FunctionConfig
import org.apache.gluten.substrait.`type`.ListNode
+import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.expression.{ExpressionBuilder,
ExpressionNode}
import org.apache.spark.sql.catalyst.expressions.Expression
@@ -30,19 +31,18 @@ case class JsonTupleExpressionTransformer(
original: Expression)
extends ExpressionTransformer {
- override def doTransform(args: Object): ExpressionNode = {
+ override def doTransform(context: SubstraitContext): ExpressionNode = {
val jsonExpr = children.head
val fields = children.tail
- val jsonExprNode = jsonExpr.doTransform(args)
+ val jsonExprNode = jsonExpr.doTransform(context)
val expressNodes = Lists.newArrayList(jsonExprNode)
- fields.foreach(f => expressNodes.add(f.doTransform(args)))
- val functionMap = args.asInstanceOf[java.util.HashMap[String,
java.lang.Long]]
+ fields.foreach(f => expressNodes.add(f.doTransform(context)))
val functionName =
ConverterUtils.makeFuncName(
substraitExprName,
original.children.map(_.dataType),
FunctionConfig.REQ)
- val functionId = ExpressionBuilder.newScalarFunction(functionMap,
functionName)
+ val functionId = context.registerFunction(functionName)
val typeNode = ConverterUtils.getTypeNode(original.dataType,
original.nullable)
typeNode match {
case node: ListNode =>
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/expression/LambdaFunctionTransformer.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/expression/LambdaFunctionTransformer.scala
index 9e7285ac3a..ba20e9737d 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/expression/LambdaFunctionTransformer.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/expression/LambdaFunctionTransformer.scala
@@ -17,6 +17,7 @@
package org.apache.gluten.expression
import org.apache.gluten.exception.GlutenNotSupportException
+import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.expression.ExpressionNode
import org.apache.spark.sql.catalyst.expressions.LambdaFunction
@@ -29,11 +30,11 @@ case class LambdaFunctionTransformer(
extends ExpressionTransformer {
override def children: Seq[ExpressionTransformer] = function +: arguments
- override def doTransform(args: Object): ExpressionNode = {
+ override def doTransform(context: SubstraitContext): ExpressionNode = {
// Need to fallback when hidden be true as it's not supported in Velox
if (original.hidden) {
throw new GlutenNotSupportException(s"Unsupported LambdaFunction with
hidden be true.")
}
- super.doTransform(args)
+ super.doTransform(context)
}
}
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/expression/MapExpressionTransformer.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/expression/MapExpressionTransformer.scala
index fe715979b1..c9f0f19c4e 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/expression/MapExpressionTransformer.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/expression/MapExpressionTransformer.scala
@@ -18,6 +18,7 @@ package org.apache.gluten.expression
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.exception.GlutenNotSupportException
+import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.expression.ExpressionNode
import org.apache.spark.sql.catalyst.expressions._
@@ -28,7 +29,7 @@ case class CreateMapTransformer(
original: CreateMap)
extends ExpressionTransformer {
- override def doTransform(args: java.lang.Object): ExpressionNode = {
+ override def doTransform(context: SubstraitContext): ExpressionNode = {
// If children is empty,
// transformation is only supported when useStringTypeWhenEmpty is false
// because ClickHouse and Velox currently doesn't support this config.
@@ -36,7 +37,7 @@ case class CreateMapTransformer(
throw new GlutenNotSupportException(s"$original not supported yet.")
}
- super.doTransform(args)
+ super.doTransform(context)
}
}
@@ -48,7 +49,7 @@ case class GetMapValueTransformer(
original: GetMapValue)
extends BinaryExpressionTransformer {
- override def doTransform(args: java.lang.Object): ExpressionNode = {
+ override def doTransform(context: SubstraitContext): ExpressionNode = {
if (BackendsApiManager.getSettings.alwaysFailOnMapExpression()) {
throw new GlutenNotSupportException(s"$original not supported yet.")
}
@@ -57,6 +58,6 @@ case class GetMapValueTransformer(
throw new GlutenNotSupportException(s"$original not supported yet.")
}
- super.doTransform(args)
+ super.doTransform(context)
}
}
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/expression/NamedExpressionsTransformer.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/expression/NamedExpressionsTransformer.scala
index f4c703d88e..76437d0c3e 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/expression/NamedExpressionsTransformer.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/expression/NamedExpressionsTransformer.scala
@@ -16,6 +16,7 @@
*/
package org.apache.gluten.expression
+import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.expression.{ExpressionBuilder,
ExpressionNode}
import org.apache.spark.sql.catalyst.expressions._
@@ -31,14 +32,14 @@ case class AttributeReferenceTransformer(
original: AttributeReference,
bound: BoundReference)
extends LeafExpressionTransformer {
- override def doTransform(args: java.lang.Object): ExpressionNode = {
+ override def doTransform(context: SubstraitContext): ExpressionNode = {
ExpressionBuilder.makeSelection(bound.ordinal.asInstanceOf[java.lang.Integer])
}
}
case class BoundReferenceTransformer(substraitExprName: String, original:
BoundReference)
extends LeafExpressionTransformer {
- override def doTransform(args: java.lang.Object): ExpressionNode = {
+ override def doTransform(context: SubstraitContext): ExpressionNode = {
ExpressionBuilder.makeSelection(original.ordinal.asInstanceOf[java.lang.Integer])
}
}
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/expression/PredicateExpressionTransformer.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/expression/PredicateExpressionTransformer.scala
index d13c61d64a..9f443973a9 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/expression/PredicateExpressionTransformer.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/expression/PredicateExpressionTransformer.scala
@@ -16,6 +16,7 @@
*/
package org.apache.gluten.expression
+import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.expression.{ExpressionBuilder,
ExpressionNode}
import org.apache.spark.sql.catalyst.expressions._
@@ -25,11 +26,11 @@ import scala.collection.JavaConverters._
case class InTransformer(substraitExprName: String, child:
ExpressionTransformer, original: In)
extends UnaryExpressionTransformer {
- override def doTransform(args: java.lang.Object): ExpressionNode = {
+ override def doTransform(context: SubstraitContext): ExpressionNode = {
assert(original.list.forall(_.foldable))
// Stores the values in a List Literal.
val values: Set[Any] = original.list.map(_.eval()).toSet
- InExpressionTransformer.toTransformer(child.doTransform(args), values,
child.dataType)
+ InExpressionTransformer.toTransformer(child.doTransform(context), values,
child.dataType)
}
}
@@ -38,9 +39,9 @@ case class InSetTransformer(
child: ExpressionTransformer,
original: InSet)
extends UnaryExpressionTransformer {
- override def doTransform(args: java.lang.Object): ExpressionNode = {
+ override def doTransform(context: SubstraitContext): ExpressionNode = {
InExpressionTransformer.toTransformer(
- child.doTransform(args),
+ child.doTransform(context),
original.hset,
original.child.dataType)
}
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ScalarSubqueryTransformer.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ScalarSubqueryTransformer.scala
index 9508d27df7..a1c6e9b715 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ScalarSubqueryTransformer.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ScalarSubqueryTransformer.scala
@@ -16,6 +16,7 @@
*/
package org.apache.gluten.expression
+import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.expression.{ExpressionBuilder,
ExpressionNode}
import org.apache.spark.sql.catalyst.InternalRow
@@ -26,7 +27,7 @@ case class ScalarSubqueryTransformer(substraitExprName:
String, query: ScalarSub
extends LeafExpressionTransformer {
override def original: Expression = query
- override def doTransform(args: java.lang.Object): ExpressionNode = {
+ override def doTransform(context: SubstraitContext): ExpressionNode = {
// don't trigger collect when in validation phase
if (TransformerState.underValidationState) {
return ExpressionBuilder.makeLiteral(null, query.dataType, true)
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/expression/UnaryExpressionTransformer.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/expression/UnaryExpressionTransformer.scala
index f9eb1e8eab..f309621cce 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/expression/UnaryExpressionTransformer.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/expression/UnaryExpressionTransformer.scala
@@ -21,6 +21,7 @@ import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.gluten.substrait.`type`.ListNode
import org.apache.gluten.substrait.`type`.MapNode
+import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.expression.{ExpressionBuilder,
ExpressionNode, StructLiteralNode}
import org.apache.spark.sql.catalyst.expressions._
@@ -35,18 +36,18 @@ case class ChildTransformer(
extends UnaryExpressionTransformer {
override def dataType: DataType = child.dataType
- override def doTransform(args: java.lang.Object): ExpressionNode = {
- child.doTransform(args)
+ override def doTransform(context: SubstraitContext): ExpressionNode = {
+ child.doTransform(context)
}
}
case class CastTransformer(substraitExprName: String, child:
ExpressionTransformer, original: Cast)
extends UnaryExpressionTransformer {
- override def doTransform(args: java.lang.Object): ExpressionNode = {
+ override def doTransform(context: SubstraitContext): ExpressionNode = {
val typeNode = ConverterUtils.getTypeNode(dataType, original.nullable)
ExpressionBuilder.makeCast(
typeNode,
- child.doTransform(args),
+ child.doTransform(context),
SparkShimLoader.getSparkShims.withAnsiEvalMode(original))
}
}
@@ -57,12 +58,10 @@ case class ExplodeTransformer(
original: Explode)
extends UnaryExpressionTransformer {
- override def doTransform(args: java.lang.Object): ExpressionNode = {
- val childNode: ExpressionNode = child.doTransform(args)
+ override def doTransform(context: SubstraitContext): ExpressionNode = {
+ val childNode: ExpressionNode = child.doTransform(context)
- val functionMap = args.asInstanceOf[java.util.HashMap[String,
java.lang.Long]]
- val functionId = ExpressionBuilder.newScalarFunction(
- functionMap,
+ val functionId = context.registerFunction(
ConverterUtils.makeFuncName(substraitExprName,
Seq(original.child.dataType)))
val expressionNodes = Lists.newArrayList(childNode)
@@ -83,11 +82,11 @@ case class CheckOverflowTransformer(
child: ExpressionTransformer,
original: CheckOverflow)
extends UnaryExpressionTransformer {
- override def doTransform(args: java.lang.Object): ExpressionNode = {
+ override def doTransform(context: SubstraitContext): ExpressionNode = {
BackendsApiManager.getTransformerApiInstance.createCheckOverflowExprNode(
- args,
+ context,
substraitExprName,
- child.doTransform(args),
+ child.doTransform(context),
original.child.dataType,
original.dataType,
original.nullable,
@@ -103,13 +102,13 @@ case class GetStructFieldTransformer(
override def left: ExpressionTransformer = child
override def right: ExpressionTransformer =
LiteralTransformer(original.ordinal)
- override def doTransform(args: java.lang.Object): ExpressionNode = {
- val childNode = child.doTransform(args)
+ override def doTransform(context: SubstraitContext): ExpressionNode = {
+ val childNode = child.doTransform(context)
childNode match {
case node: StructLiteralNode =>
node.getFieldLiteral(original.ordinal)
case _ =>
- super.doTransform(args)
+ super.doTransform(context)
}
}
}
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/expression/WindowFunctionsBuilder.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/expression/WindowFunctionsBuilder.scala
index 831e319973..5873c29f20 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/expression/WindowFunctionsBuilder.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/expression/WindowFunctionsBuilder.scala
@@ -19,15 +19,14 @@ package org.apache.gluten.expression
import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.expression.ConverterUtils.FunctionConfig
import org.apache.gluten.expression.ExpressionNames.{LAG, LEAD}
-import org.apache.gluten.substrait.expression.ExpressionBuilder
+import org.apache.gluten.substrait.SubstraitContext
import org.apache.spark.sql.catalyst.expressions.{EmptyRow, Expression, Lag,
Lead, WindowExpression, WindowFunction}
import scala.util.control.Breaks.{break, breakable}
object WindowFunctionsBuilder {
- def create(args: java.lang.Object, windowFunc: WindowFunction): Long = {
- val functionMap = args.asInstanceOf[java.util.HashMap[String,
java.lang.Long]]
+ def create(context: SubstraitContext, windowFunc: WindowFunction): Long = {
val substraitFunc = windowFunc match {
// Handle lag with negative inputOffset, e.g., converts lag(c1, -1) to
lead(c1, 1).
// Spark uses `-inputOffset` as `offset` for Lag function.
@@ -46,7 +45,7 @@ object WindowFunctionsBuilder {
val functionName =
ConverterUtils.makeFuncName(substraitFunc.get, Seq(windowFunc.dataType),
FunctionConfig.OPT)
- ExpressionBuilder.newScalarFunction(functionMap, functionName)
+ context.registerFunction(functionName)
}
def extractWindowExpression(expr: Expression): WindowExpression = {
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/utils/SubstraitUtil.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/utils/SubstraitUtil.scala
index 41d5b18c38..80b3a365b8 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/utils/SubstraitUtil.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/utils/SubstraitUtil.scala
@@ -84,7 +84,7 @@ object SubstraitUtil {
context: SubstraitContext): ExpressionNode = {
ExpressionConverter
.replaceWithExpressionTransformer(expr, attributeSeq)
- .doTransform(context.registeredFunction)
+ .doTransform(context)
}
def createNameStructBuilder(
diff --git
a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecTransformer.scala
b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecTransformer.scala
index 7b4c09d4f9..553fc15fac 100644
---
a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecTransformer.scala
+++
b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecTransformer.scala
@@ -67,7 +67,6 @@ case class EvalPythonExecTransformer(
}
val context = new SubstraitContext
- val args = context.registeredFunction
val operatorId = context.nextOperatorId(this.nodeName)
val expressionNodes = new JArrayList[ExpressionNode]
@@ -76,7 +75,9 @@ case class EvalPythonExecTransformer(
udfs.foreach(
udf => {
expressionNodes.add(
- ExpressionConverter.replaceWithExpressionTransformer(udf,
child.output).doTransform(args))
+ ExpressionConverter
+ .replaceWithExpressionTransformer(udf, child.output)
+ .doTransform(context))
})
val relNode = RelBuilder.makeProjectRel(null, expressionNodes, context,
operatorId)
@@ -86,7 +87,6 @@ case class EvalPythonExecTransformer(
override protected def doTransform(context: SubstraitContext):
TransformContext = {
val childCtx = child.asInstanceOf[TransformSupport].transform(context)
- val args = context.registeredFunction
val operatorId = context.nextOperatorId(this.nodeName)
val expressionNodes = new JArrayList[ExpressionNode]
child.output.zipWithIndex.foreach(
@@ -94,7 +94,9 @@ case class EvalPythonExecTransformer(
udfs.foreach(
udf => {
expressionNodes.add(
- ExpressionConverter.replaceWithExpressionTransformer(udf,
child.output).doTransform(args))
+ ExpressionConverter
+ .replaceWithExpressionTransformer(udf, child.output)
+ .doTransform(context))
})
val relNode =
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]