This is an automated email from the ASF dual-hosted git repository.
hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 9bb4b28a5 [GLUTEN-6951][CORE][CH] Move CustomerExpressionTransformer to CH backend (#6993)
9bb4b28a5 is described below
commit 9bb4b28a52193aa446be90706df1465ca35460b5
Author: Hongze Zhang <[email protected]>
AuthorDate: Mon Aug 26 15:29:13 2024 +0800
[GLUTEN-6951][CORE][CH] Move CustomerExpressionTransformer to CH backend (#6993)
Closes #6951
---
.../backendsapi/clickhouse/CHListenerApi.scala | 9 +
.../clickhouse/CHSparkPlanExecApi.scala | 22 +-
.../execution/CHHashAggregateExecTransformer.scala | 13 +-
.../apache/gluten/expression/CHExpressions.scala | 45 ++++
.../extension/ExpressionExtensionTrait.scala | 11 +-
.../apache/spark/sql/utils/ExpressionUtil.scala | 3 +-
.../extension/CustomerExpressionTransformer.scala | 0
...ckhouseCustomerExpressionTransformerSuite.scala | 24 +-
.../scala/org/apache/gluten/GlutenPlugin.scala | 12 +-
.../gluten/backendsapi/SparkPlanExecApi.scala | 9 +-
.../expression/AggregateFunctionsBuilder.scala | 34 +--
.../gluten/expression/ExpressionConverter.scala | 254 +++++++++------------
.../gluten/expression/ExpressionMappings.scala | 22 +-
.../utils/clickhouse/ClickHouseTestSettings.scala | 3 +-
.../gluten/utils/velox/VeloxTestSettings.scala | 3 +-
.../scala/org/apache/gluten/GlutenConfig.scala | 2 +
.../spark/sql/catalyst/expressions/EvalMode.scala | 36 +++
.../spark/sql/catalyst/expressions/EvalMode.scala | 36 +++
18 files changed, 323 insertions(+), 215 deletions(-)
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
index 69797feb6..60dc3dad0 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
@@ -21,6 +21,7 @@ import org.apache.gluten.backendsapi.ListenerApi
import org.apache.gluten.execution.CHBroadcastBuildSideCache
import org.apache.gluten.execution.datasource.{GlutenOrcWriterInjects,
GlutenParquetWriterInjects, GlutenRowSplitter}
import org.apache.gluten.expression.UDFMappings
+import org.apache.gluten.extension.ExpressionExtensionTrait
import org.apache.gluten.vectorized.{CHNativeExpressionEvaluator, JniLibLoader}
import org.apache.spark.{SparkConf, SparkContext}
@@ -30,6 +31,7 @@ import org.apache.spark.listener.CHGlutenSQLAppStatusListener
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rpc.{GlutenDriverEndpoint, GlutenExecutorEndpoint}
import org.apache.spark.sql.execution.datasources.v1._
+import org.apache.spark.sql.utils.ExpressionUtil
import org.apache.spark.util.SparkDirectoryUtil
import org.apache.commons.lang3.StringUtils
@@ -42,6 +44,13 @@ class CHListenerApi extends ListenerApi with Logging {
GlutenDriverEndpoint.glutenDriverEndpointRef = (new
GlutenDriverEndpoint).self
CHGlutenSQLAppStatusListener.registerListener(sc)
initialize(pc.conf, isDriver = true)
+
+ val expressionExtensionTransformer =
ExpressionUtil.extendedExpressionTransformer(
+ pc.conf.get(GlutenConfig.GLUTEN_EXTENDED_EXPRESSION_TRAN_CONF, "")
+ )
+ if (expressionExtensionTransformer != null) {
+ ExpressionExtensionTrait.expressionExtensionTransformer =
expressionExtensionTransformer
+ }
}
override def onDriverShutdown(): Unit = shutdown()
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index bfa59aee7..a8996c4d2 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -18,10 +18,10 @@ package org.apache.gluten.backendsapi.clickhouse
import org.apache.gluten.GlutenConfig
import org.apache.gluten.backendsapi.{BackendsApiManager, SparkPlanExecApi}
-import org.apache.gluten.exception.GlutenException
-import org.apache.gluten.exception.GlutenNotSupportException
+import org.apache.gluten.exception.{GlutenException, GlutenNotSupportException}
import org.apache.gluten.execution._
import org.apache.gluten.expression._
+import org.apache.gluten.extension.ExpressionExtensionTrait
import org.apache.gluten.extension.columnar.AddFallbackTagRule
import
org.apache.gluten.extension.columnar.MiscColumnarRules.TransformPreOverrides
import org.apache.gluten.extension.columnar.transition.Convention
@@ -558,9 +558,25 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with
Logging {
Sig[CollectList](ExpressionNames.COLLECT_LIST),
Sig[CollectSet](ExpressionNames.COLLECT_SET)
) ++
+
ExpressionExtensionTrait.expressionExtensionTransformer.expressionSigList ++
SparkShimLoader.getSparkShims.bloomFilterExpressionMappings()
}
+ /** Define backend-specific expression converter. */
+ override def extraExpressionConverter(
+ substraitExprName: String,
+ expr: Expression,
+ attributeSeq: Seq[Attribute]): Option[ExpressionTransformer] = expr
match {
+ case e
+ if
ExpressionExtensionTrait.expressionExtensionTransformer.extensionExpressionsMapping
+ .contains(e.getClass) =>
+ // Use extended expression transformer to replace custom expression first
+ Some(
+ ExpressionExtensionTrait.expressionExtensionTransformer
+ .replaceWithExtensionExpressionTransformer(substraitExprName, e,
attributeSeq))
+ case _ => None
+ }
+
override def genStringTranslateTransformer(
substraitExprName: String,
srcExpr: ExpressionTransformer,
@@ -700,7 +716,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with
Logging {
.doTransform(args)))
val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
- AggregateFunctionsBuilder.create(args,
aggExpression.aggregateFunction).toInt,
+ CHExpressions.createAggregateFunction(args,
aggExpression.aggregateFunction).toInt,
childrenNodeList,
columnName,
ConverterUtils.getTypeNode(aggExpression.dataType,
aggExpression.nullable),
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
index 6c1fee39c..d641c05cd 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
@@ -20,6 +20,7 @@ import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.exception.GlutenNotSupportException
import
org.apache.gluten.execution.CHHashAggregateExecTransformer.getAggregateResultAttributes
import org.apache.gluten.expression._
+import org.apache.gluten.extension.ExpressionExtensionTrait
import org.apache.gluten.substrait.`type`.{TypeBuilder, TypeNode}
import org.apache.gluten.substrait.{AggregationParams, SubstraitContext}
import org.apache.gluten.substrait.expression.{AggregateFunctionNode,
ExpressionBuilder, ExpressionNode}
@@ -249,7 +250,7 @@ case class CHHashAggregateExecTransformer(
childrenNodeList.add(node)
}
val aggFunctionNode = ExpressionBuilder.makeAggregateFunction(
- AggregateFunctionsBuilder.create(args, aggregateFunc),
+ CHExpressions.createAggregateFunction(args, aggregateFunc),
childrenNodeList,
modeToKeyWord(aggExpr.mode),
ConverterUtils.getTypeNode(aggregateFunc.dataType,
aggregateFunc.nullable)
@@ -286,10 +287,10 @@ case class CHHashAggregateExecTransformer(
val aggregateFunc = aggExpr.aggregateFunction
var aggFunctionName =
if (
-
ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains(
- aggregateFunc.getClass)
+
ExpressionExtensionTrait.expressionExtensionTransformer.extensionExpressionsMapping
+ .contains(aggregateFunc.getClass)
) {
- ExpressionMappings.expressionExtensionTransformer
+ ExpressionExtensionTrait.expressionExtensionTransformer
.buildCustomAggregateFunction(aggregateFunc)
._1
.get
@@ -437,10 +438,10 @@ case class CHHashAggregateExecPullOutHelper(
val aggregateFunc = exp.aggregateFunction
// First handle the custom aggregate functions
if (
-
ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains(
+
ExpressionExtensionTrait.expressionExtensionTransformer.extensionExpressionsMapping.contains(
aggregateFunc.getClass)
) {
- ExpressionMappings.expressionExtensionTransformer
+ ExpressionExtensionTrait.expressionExtensionTransformer
.getAttrsIndexForExtensionAggregateExpr(
aggregateFunc,
exp.mode,
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressions.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressions.scala
new file mode 100644
index 000000000..af1ac52b1
--- /dev/null
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressions.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.expression
+
+import org.apache.gluten.expression.ConverterUtils.FunctionConfig
+import org.apache.gluten.extension.ExpressionExtensionTrait
+import org.apache.gluten.substrait.expression.ExpressionBuilder
+
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
+
+// Static helper object for handling expressions that are specifically used in
CH backend.
+object CHExpressions {
+ // Since https://github.com/apache/incubator-gluten/pull/1937.
+ def createAggregateFunction(args: java.lang.Object, aggregateFunc:
AggregateFunction): Long = {
+ val functionMap = args.asInstanceOf[java.util.HashMap[String,
java.lang.Long]]
+ if (
+
ExpressionExtensionTrait.expressionExtensionTransformer.extensionExpressionsMapping.contains(
+ aggregateFunc.getClass)
+ ) {
+ val (substraitAggFuncName, inputTypes) =
+
ExpressionExtensionTrait.expressionExtensionTransformer.buildCustomAggregateFunction(
+ aggregateFunc)
+ assert(substraitAggFuncName.isDefined)
+ return ExpressionBuilder.newScalarFunction(
+ functionMap,
+ ConverterUtils.makeFuncName(substraitAggFuncName.get, inputTypes,
FunctionConfig.REQ))
+ }
+
+ AggregateFunctionsBuilder.create(args, aggregateFunc)
+ }
+}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/ExpressionExtensionTrait.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ExpressionExtensionTrait.scala
similarity index 86%
rename from
gluten-core/src/main/scala/org/apache/gluten/extension/ExpressionExtensionTrait.scala
rename to
backends-clickhouse/src/main/scala/org/apache/gluten/extension/ExpressionExtensionTrait.scala
index 89bcb7064..c64f26869 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/ExpressionExtensionTrait.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ExpressionExtensionTrait.scala
@@ -63,8 +63,13 @@ trait ExpressionExtensionTrait {
}
}
-case class DefaultExpressionExtensionTransformer() extends
ExpressionExtensionTrait with Logging {
+object ExpressionExtensionTrait {
+ var expressionExtensionTransformer: ExpressionExtensionTrait =
+ DefaultExpressionExtensionTransformer()
- /** Generate the extension expressions list, format:
Sig[XXXExpression]("XXXExpressionName") */
- override def expressionSigList: Seq[Sig] = Seq.empty[Sig]
+ case class DefaultExpressionExtensionTransformer() extends
ExpressionExtensionTrait with Logging {
+
+ /** Generate the extension expressions list, format:
Sig[XXXExpression]("XXXExpressionName") */
+ override def expressionSigList: Seq[Sig] = Seq.empty[Sig]
+ }
}
diff --git
a/gluten-core/src/main/scala/org/apache/spark/sql/utils/ExpressionUtil.scala
b/backends-clickhouse/src/main/scala/org/apache/spark/sql/utils/ExpressionUtil.scala
similarity index 92%
rename from
gluten-core/src/main/scala/org/apache/spark/sql/utils/ExpressionUtil.scala
rename to
backends-clickhouse/src/main/scala/org/apache/spark/sql/utils/ExpressionUtil.scala
index b5c45e090..852b34a09 100644
--- a/gluten-core/src/main/scala/org/apache/spark/sql/utils/ExpressionUtil.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/spark/sql/utils/ExpressionUtil.scala
@@ -16,7 +16,8 @@
*/
package org.apache.spark.sql.utils
-import org.apache.gluten.extension.{DefaultExpressionExtensionTransformer,
ExpressionExtensionTrait}
+import org.apache.gluten.extension.ExpressionExtensionTrait
+import
org.apache.gluten.extension.ExpressionExtensionTrait.DefaultExpressionExtensionTransformer
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils
diff --git
a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/CustomerExpressionTransformer.scala
b/backends-clickhouse/src/test/scala/org/apache/spark/sql/extension/CustomerExpressionTransformer.scala
similarity index 100%
rename from
gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/CustomerExpressionTransformer.scala
rename to
backends-clickhouse/src/test/scala/org/apache/spark/sql/extension/CustomerExpressionTransformer.scala
diff --git
a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/GlutenCustomerExpressionTransformerSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/spark/sql/extension/GlutenClickhouseCustomerExpressionTransformerSuite.scala
similarity index 87%
rename from
gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/GlutenCustomerExpressionTransformerSuite.scala
rename to
backends-clickhouse/src/test/scala/org/apache/spark/sql/extension/GlutenClickhouseCustomerExpressionTransformerSuite.scala
index 91344f877..cd8bf579f 100644
---
a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/GlutenCustomerExpressionTransformerSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/spark/sql/extension/GlutenClickhouseCustomerExpressionTransformerSuite.scala
@@ -16,24 +16,25 @@
*/
package org.apache.spark.sql.extension
-import org.apache.gluten.execution.ProjectExecTransformer
+import
org.apache.gluten.execution.{GlutenClickHouseWholeStageTransformerSuite,
ProjectExecTransformer}
import org.apache.gluten.expression.ExpressionConverter
import org.apache.spark.SparkConf
-import org.apache.spark.sql.{GlutenSQLTestsTrait, Row}
+import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.FunctionRegistryBase
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.{IntervalUtils, TypeUtils}
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{AbstractDataType, CalendarIntervalType,
DayTimeIntervalType, TypeCollection, YearMonthIntervalType}
+import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
case class CustomAdd(
left: Expression,
right: Expression,
- failOnError: Boolean = SQLConf.get.ansiEnabled)
- extends BinaryArithmetic {
+ override val failOnError: Boolean = SQLConf.get.ansiEnabled)
+ extends BinaryArithmetic
+ with CustomAdd.Compatibility {
def this(left: Expression, right: Expression) = this(left, right,
SQLConf.get.ansiEnabled)
@@ -69,9 +70,18 @@ case class CustomAdd(
newLeft: Expression,
newRight: Expression
): CustomAdd = copy(left = newLeft, right = newRight)
+
+ override protected val evalMode: EvalMode.Value = EvalMode.LEGACY
+}
+
+object CustomAdd {
+ trait Compatibility {
+ protected val evalMode: EvalMode.Value
+ }
}
-class GlutenCustomerExpressionTransformerSuite extends GlutenSQLTestsTrait {
+class GlutenClickhouseCustomerExpressionTransformerSuite
+ extends GlutenClickHouseWholeStageTransformerSuite {
override def sparkConf: SparkConf = {
super.sparkConf
@@ -92,7 +102,7 @@ class GlutenCustomerExpressionTransformerSuite extends
GlutenSQLTestsTrait {
)
}
- testGluten("test custom expression transformer") {
+ test("test custom expression transformer") {
spark
.createDataFrame(Seq((1, 1.1), (2, 2.2)))
.createOrReplaceTempView("custom_table")
diff --git a/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala
b/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala
index f775d78a1..6d0cdd0f8 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala
@@ -20,7 +20,6 @@ import
org.apache.gluten.GlutenConfig.GLUTEN_DEFAULT_SESSION_TIMEZONE_KEY
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.events.GlutenBuildInfoEvent
import org.apache.gluten.exception.GlutenException
-import org.apache.gluten.expression.ExpressionMappings
import
org.apache.gluten.extension.GlutenSessionExtensions.{GLUTEN_SESSION_EXTENSION_NAME,
SPARK_SESSION_EXTS_KEY}
import org.apache.gluten.test.TestStats
import org.apache.gluten.utils.TaskListener
@@ -32,7 +31,6 @@ import org.apache.spark.listener.GlutenListenerFactory
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.sql.execution.ui.GlutenEventUtils
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.utils.ExpressionUtil
import org.apache.spark.util.{SparkResourceUtil, TaskResources}
import java.util
@@ -73,14 +71,6 @@ private[gluten] class GlutenDriverPlugin extends
DriverPlugin with Logging {
BackendsApiManager.getListenerApiInstance.onDriverStart(sc, pluginContext)
GlutenListenerFactory.addToSparkListenerBus(sc)
- val expressionExtensionTransformer =
ExpressionUtil.extendedExpressionTransformer(
- conf.get(GlutenConfig.GLUTEN_EXTENDED_EXPRESSION_TRAN_CONF, "")
- )
-
- if (expressionExtensionTransformer != null) {
- ExpressionMappings.expressionExtensionTransformer =
expressionExtensionTransformer
- }
-
Collections.emptyMap()
}
@@ -275,7 +265,7 @@ private[gluten] class GlutenDriverPlugin extends
DriverPlugin with Logging {
}
private[gluten] class GlutenExecutorPlugin extends ExecutorPlugin {
- private val taskListeners: Seq[TaskListener] = Array(TaskResources)
+ private val taskListeners: Seq[TaskListener] = Seq(TaskResources)
/** Initialize the executor plugin. */
override def init(ctx: PluginContext, extraConf: util.Map[String, String]):
Unit = {
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
index fb87a9ac9..a55926d76 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
@@ -448,9 +448,16 @@ trait SparkPlanExecApi {
GenericExpressionTransformer(substraitExprName, exprs, original)
}
- /** Define backend specfic expression mappings. */
+ /** Define backend-specific expression mappings. */
def extraExpressionMappings: Seq[Sig] = Seq.empty
+ /** Define backend-specific expression converter. */
+ def extraExpressionConverter(
+ substraitExprName: String,
+ expr: Expression,
+ attributeSeq: Seq[Attribute]): Option[ExpressionTransformer] =
+ None
+
/**
* Define whether the join operator is fallback because of the join operator
is not supported by
* backend
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/AggregateFunctionsBuilder.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/AggregateFunctionsBuilder.scala
index 6ac2c67eb..bd73b7b7a 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/AggregateFunctionsBuilder.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/expression/AggregateFunctionsBuilder.scala
@@ -29,32 +29,18 @@ object AggregateFunctionsBuilder {
val functionMap = args.asInstanceOf[java.util.HashMap[String,
java.lang.Long]]
// First handle the custom aggregate functions
- val (substraitAggFuncName, inputTypes) =
- if (
-
ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains(
- aggregateFunc.getClass)
- ) {
- val (substraitAggFuncName, inputTypes) =
-
ExpressionMappings.expressionExtensionTransformer.buildCustomAggregateFunction(
- aggregateFunc)
- assert(substraitAggFuncName.isDefined)
- (substraitAggFuncName.get, inputTypes)
- } else {
- val substraitAggFuncName = getSubstraitFunctionName(aggregateFunc)
+ val substraitAggFuncName = getSubstraitFunctionName(aggregateFunc)
- // Check whether each backend supports this aggregate function.
- if (
- !BackendsApiManager.getValidatorApiInstance.doExprValidate(
- substraitAggFuncName,
- aggregateFunc)
- ) {
- throw new GlutenNotSupportException(
- s"Aggregate function not supported for $aggregateFunc.")
- }
+ // Check whether each backend supports this aggregate function.
+ if (
+ !BackendsApiManager.getValidatorApiInstance.doExprValidate(
+ substraitAggFuncName,
+ aggregateFunc)
+ ) {
+ throw new GlutenNotSupportException(s"Aggregate function not supported
for $aggregateFunc.")
+ }
- val inputTypes: Seq[DataType] = aggregateFunc.children.map(child =>
child.dataType)
- (substraitAggFuncName, inputTypes)
- }
+ val inputTypes: Seq[DataType] = aggregateFunc.children.map(child =>
child.dataType)
ExpressionBuilder.newScalarFunction(
functionMap,
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
index d5ca31bb5..c5ba3a8a7 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
@@ -43,16 +43,14 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
exprs: Seq[Expression],
attributeSeq: Seq[Attribute]): Seq[ExpressionTransformer] = {
val expressionsMap = ExpressionMappings.expressionsMap
- exprs.map {
- expr => replaceWithExpressionTransformerInternal(expr, attributeSeq,
expressionsMap)
- }
+ exprs.map(expr => replaceWithExpressionTransformer0(expr, attributeSeq,
expressionsMap))
}
def replaceWithExpressionTransformer(
expr: Expression,
attributeSeq: Seq[Attribute]): ExpressionTransformer = {
val expressionsMap = ExpressionMappings.expressionsMap
- replaceWithExpressionTransformerInternal(expr, attributeSeq,
expressionsMap)
+ replaceWithExpressionTransformer0(expr, attributeSeq, expressionsMap)
}
private def replacePythonUDFWithExpressionTransformer(
@@ -64,8 +62,7 @@ object ExpressionConverter extends SQLConfHelper with Logging
{
case Some(name) =>
GenericExpressionTransformer(
name,
- udf.children.map(
- replaceWithExpressionTransformerInternal(_, attributeSeq,
expressionsMap)),
+ udf.children.map(replaceWithExpressionTransformer0(_, attributeSeq,
expressionsMap)),
udf)
case _ =>
throw new GlutenNotSupportException(s"Not supported python udf: $udf.")
@@ -84,8 +81,7 @@ object ExpressionConverter extends SQLConfHelper with Logging
{
case Some(name) =>
GenericExpressionTransformer(
name,
- udf.children.map(
- replaceWithExpressionTransformerInternal(_, attributeSeq,
expressionsMap)),
+ udf.children.map(replaceWithExpressionTransformer0(_, attributeSeq,
expressionsMap)),
udf)
case _ =>
throw new GlutenNotSupportException(s"Not supported scala udf: $udf.")
@@ -108,13 +104,13 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
)
val leftChild =
- replaceWithExpressionTransformerInternal(left, attributeSeq,
expressionsMap)
+ replaceWithExpressionTransformer0(left, attributeSeq, expressionsMap)
val rightChild =
- replaceWithExpressionTransformerInternal(right, attributeSeq,
expressionsMap)
+ replaceWithExpressionTransformer0(right, attributeSeq, expressionsMap)
DecimalArithmeticExpressionTransformer(substraitName, leftChild,
rightChild, resultType, b)
}
- private def replaceWithExpressionTransformerInternal(
+ private def replaceWithExpressionTransformer0(
expr: Expression,
attributeSeq: Seq[Attribute],
expressionsMap: Map[Class[_], String]): ExpressionTransformer = {
@@ -139,14 +135,12 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
case "decode" =>
return GenericExpressionTransformer(
ExpressionNames.URL_DECODE,
- child.map(
- replaceWithExpressionTransformerInternal(_, attributeSeq,
expressionsMap)),
+ child.map(replaceWithExpressionTransformer0(_, attributeSeq,
expressionsMap)),
i)
case "encode" =>
return GenericExpressionTransformer(
ExpressionNames.URL_ENCODE,
- child.map(
- replaceWithExpressionTransformerInternal(_, attributeSeq,
expressionsMap)),
+ child.map(replaceWithExpressionTransformer0(_, attributeSeq,
expressionsMap)),
i)
}
}
@@ -154,61 +148,61 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
}
val substraitExprName: String = getAndCheckSubstraitName(expr,
expressionsMap)
-
+ val backendConverted =
BackendsApiManager.getSparkPlanExecApiInstance.extraExpressionConverter(
+ substraitExprName,
+ expr,
+ attributeSeq)
+ if (backendConverted.isDefined) {
+ return backendConverted.get
+ }
expr match {
- case extendedExpr
- if
ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains(
- extendedExpr.getClass) =>
- // Use extended expression transformer to replace custom expression
first
- ExpressionMappings.expressionExtensionTransformer
- .replaceWithExtensionExpressionTransformer(substraitExprName,
extendedExpr, attributeSeq)
case c: CreateArray =>
val children =
- c.children.map(replaceWithExpressionTransformerInternal(_,
attributeSeq, expressionsMap))
+ c.children.map(replaceWithExpressionTransformer0(_, attributeSeq,
expressionsMap))
CreateArrayTransformer(substraitExprName, children, c)
case g: GetArrayItem =>
BackendsApiManager.getSparkPlanExecApiInstance.genGetArrayItemTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(g.left, attributeSeq,
expressionsMap),
- replaceWithExpressionTransformerInternal(g.right, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(g.left, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(g.right, attributeSeq,
expressionsMap),
g
)
case c: CreateMap =>
val children =
- c.children.map(replaceWithExpressionTransformerInternal(_,
attributeSeq, expressionsMap))
+ c.children.map(replaceWithExpressionTransformer0(_, attributeSeq,
expressionsMap))
CreateMapTransformer(substraitExprName, children, c)
case g: GetMapValue =>
BackendsApiManager.getSparkPlanExecApiInstance.genGetMapValueTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(g.child, attributeSeq,
expressionsMap),
- replaceWithExpressionTransformerInternal(g.key, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(g.child, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(g.key, attributeSeq,
expressionsMap),
g
)
case m: MapEntries =>
BackendsApiManager.getSparkPlanExecApiInstance.genMapEntriesTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(m.child, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(m.child, attributeSeq,
expressionsMap),
m)
case e: Explode =>
ExplodeTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(e.child, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(e.child, attributeSeq,
expressionsMap),
e)
case p: PosExplode =>
BackendsApiManager.getSparkPlanExecApiInstance.genPosExplodeTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(p.child, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(p.child, attributeSeq,
expressionsMap),
p,
attributeSeq)
case i: Inline =>
BackendsApiManager.getSparkPlanExecApiInstance.genInlineTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(i.child, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(i.child, attributeSeq,
expressionsMap),
i)
case a: Alias =>
BackendsApiManager.getSparkPlanExecApiInstance.genAliasTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(a.child, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(a.child, attributeSeq,
expressionsMap),
a)
case a: AttributeReference =>
if (attributeSeq == null) {
@@ -233,14 +227,14 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
case d: DateDiff =>
BackendsApiManager.getSparkPlanExecApiInstance.genDateDiffTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(d.endDate, attributeSeq,
expressionsMap),
- replaceWithExpressionTransformerInternal(d.startDate, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(d.endDate, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(d.startDate, attributeSeq,
expressionsMap),
d
)
case r: Round if r.child.dataType.isInstanceOf[DecimalType] =>
DecimalRoundTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(r.child, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(r.child, attributeSeq,
expressionsMap),
r)
case t: ToUnixTimestamp =>
// The failOnError depends on the config for ANSI. ANSI is not
supported currently.
@@ -248,8 +242,8 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
GenericExpressionTransformer(
substraitExprName,
Seq(
- replaceWithExpressionTransformerInternal(t.timeExp, attributeSeq,
expressionsMap),
- replaceWithExpressionTransformerInternal(t.format, attributeSeq,
expressionsMap)
+ replaceWithExpressionTransformer0(t.timeExp, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(t.format, attributeSeq,
expressionsMap)
),
t
)
@@ -257,33 +251,33 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
GenericExpressionTransformer(
substraitExprName,
Seq(
- replaceWithExpressionTransformerInternal(u.timeExp, attributeSeq,
expressionsMap),
- replaceWithExpressionTransformerInternal(u.format, attributeSeq,
expressionsMap)
+ replaceWithExpressionTransformer0(u.timeExp, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(u.format, attributeSeq,
expressionsMap)
),
ToUnixTimestamp(u.timeExp, u.format, u.timeZoneId, u.failOnError)
)
case t: TruncTimestamp =>
BackendsApiManager.getSparkPlanExecApiInstance.genTruncTimestampTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(t.format, attributeSeq,
expressionsMap),
- replaceWithExpressionTransformerInternal(t.timestamp, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(t.format, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(t.timestamp, attributeSeq,
expressionsMap),
t.timeZoneId,
t
)
case m: MonthsBetween =>
MonthsBetweenTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(m.date1, attributeSeq,
expressionsMap),
- replaceWithExpressionTransformerInternal(m.date2, attributeSeq,
expressionsMap),
- replaceWithExpressionTransformerInternal(m.roundOff, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(m.date1, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(m.date2, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(m.roundOff, attributeSeq,
expressionsMap),
m
)
case i: If =>
IfTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(i.predicate, attributeSeq,
expressionsMap),
- replaceWithExpressionTransformerInternal(i.trueValue, attributeSeq,
expressionsMap),
- replaceWithExpressionTransformerInternal(i.falseValue, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(i.predicate, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(i.trueValue, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(i.falseValue, attributeSeq,
expressionsMap),
i
)
case cw: CaseWhen =>
@@ -293,14 +287,14 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
expr =>
{
(
- replaceWithExpressionTransformerInternal(expr._1,
attributeSeq, expressionsMap),
- replaceWithExpressionTransformerInternal(expr._2,
attributeSeq, expressionsMap))
+ replaceWithExpressionTransformer0(expr._1, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(expr._2, attributeSeq,
expressionsMap))
}
},
cw.elseValue.map {
expr =>
{
- replaceWithExpressionTransformerInternal(expr, attributeSeq,
expressionsMap)
+ replaceWithExpressionTransformer0(expr, attributeSeq,
expressionsMap)
}
},
cw
@@ -312,12 +306,12 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
}
InTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(i.value, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(i.value, attributeSeq,
expressionsMap),
i)
case i: InSet =>
InSetTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(i.child, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(i.child, attributeSeq,
expressionsMap),
i)
case s: ScalarSubquery =>
ScalarSubqueryTransformer(substraitExprName, s)
@@ -327,7 +321,7 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
BackendsApiManager.getSparkPlanExecApiInstance.genCastWithNewChild(c)
CastTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(newCast.child,
attributeSeq, expressionsMap),
+ replaceWithExpressionTransformer0(newCast.child, attributeSeq,
expressionsMap),
newCast)
case s: String2TrimExpression =>
val (srcStr, trimStr) = s match {
@@ -336,9 +330,9 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
case StringTrimRight(srcStr, trimStr) => (srcStr, trimStr)
}
val children = trimStr
- .map(replaceWithExpressionTransformerInternal(_, attributeSeq,
expressionsMap))
+ .map(replaceWithExpressionTransformer0(_, attributeSeq,
expressionsMap))
.toSeq ++
- Seq(replaceWithExpressionTransformerInternal(srcStr, attributeSeq,
expressionsMap))
+ Seq(replaceWithExpressionTransformer0(srcStr, attributeSeq,
expressionsMap))
GenericExpressionTransformer(
substraitExprName,
children,
@@ -348,23 +342,20 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
BackendsApiManager.getSparkPlanExecApiInstance.genHashExpressionTransformer(
substraitExprName,
m.children.map(
- expr => replaceWithExpressionTransformerInternal(expr,
attributeSeq, expressionsMap)),
+ expr => replaceWithExpressionTransformer0(expr, attributeSeq,
expressionsMap)),
m)
case getStructField: GetStructField =>
// Different backends may have different result.
BackendsApiManager.getSparkPlanExecApiInstance.genGetStructFieldTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(
- getStructField.child,
- attributeSeq,
- expressionsMap),
+ replaceWithExpressionTransformer0(getStructField.child,
attributeSeq, expressionsMap),
getStructField.ordinal,
getStructField)
case getArrayStructFields: GetArrayStructFields =>
GenericExpressionTransformer(
substraitExprName,
Seq(
- replaceWithExpressionTransformerInternal(
+ replaceWithExpressionTransformer0(
getArrayStructFields.child,
attributeSeq,
expressionsMap),
@@ -374,26 +365,26 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
case t: StringTranslate =>
BackendsApiManager.getSparkPlanExecApiInstance.genStringTranslateTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(t.srcExpr, attributeSeq,
expressionsMap),
- replaceWithExpressionTransformerInternal(t.matchingExpr,
attributeSeq, expressionsMap),
- replaceWithExpressionTransformerInternal(t.replaceExpr,
attributeSeq, expressionsMap),
+ replaceWithExpressionTransformer0(t.srcExpr, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(t.matchingExpr, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(t.replaceExpr, attributeSeq,
expressionsMap),
t
)
case r: RegExpReplace =>
BackendsApiManager.getSparkPlanExecApiInstance.genRegexpReplaceTransformer(
substraitExprName,
Seq(
- replaceWithExpressionTransformerInternal(r.subject, attributeSeq,
expressionsMap),
- replaceWithExpressionTransformerInternal(r.regexp, attributeSeq,
expressionsMap),
- replaceWithExpressionTransformerInternal(r.rep, attributeSeq,
expressionsMap),
- replaceWithExpressionTransformerInternal(r.pos, attributeSeq,
expressionsMap)
+ replaceWithExpressionTransformer0(r.subject, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(r.regexp, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(r.rep, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(r.pos, attributeSeq,
expressionsMap)
),
r
)
case size: Size =>
// Covers Spark ArraySize which is replaced by Size(child, false).
val child =
- replaceWithExpressionTransformerInternal(size.child, attributeSeq,
expressionsMap)
+ replaceWithExpressionTransformer0(size.child, attributeSeq,
expressionsMap)
GenericExpressionTransformer(
substraitExprName,
Seq(child, LiteralTransformer(size.legacySizeOfNull)),
@@ -402,7 +393,7 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
BackendsApiManager.getSparkPlanExecApiInstance.genNamedStructTransformer(
substraitExprName,
namedStruct.children.map(
- replaceWithExpressionTransformerInternal(_, attributeSeq,
expressionsMap)),
+ replaceWithExpressionTransformer0(_, attributeSeq,
expressionsMap)),
namedStruct,
attributeSeq)
case namedLambdaVariable: NamedLambdaVariable =>
@@ -415,64 +406,57 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
case lambdaFunction: LambdaFunction =>
LambdaFunctionTransformer(
substraitExprName,
- function = replaceWithExpressionTransformerInternal(
+ function = replaceWithExpressionTransformer0(
lambdaFunction.function,
attributeSeq,
expressionsMap),
arguments = lambdaFunction.arguments.map(
- replaceWithExpressionTransformerInternal(_, attributeSeq,
expressionsMap)),
+ replaceWithExpressionTransformer0(_, attributeSeq,
expressionsMap)),
original = lambdaFunction
)
case j: JsonTuple =>
val children =
- j.children.map(replaceWithExpressionTransformerInternal(_,
attributeSeq, expressionsMap))
+ j.children.map(replaceWithExpressionTransformer0(_, attributeSeq,
expressionsMap))
JsonTupleExpressionTransformer(substraitExprName, children, j)
case l: Like =>
BackendsApiManager.getSparkPlanExecApiInstance.genLikeTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(l.left, attributeSeq,
expressionsMap),
- replaceWithExpressionTransformerInternal(l.right, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(l.left, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(l.right, attributeSeq,
expressionsMap),
l
)
case m: MakeDecimal =>
GenericExpressionTransformer(
substraitExprName,
Seq(
- replaceWithExpressionTransformerInternal(m.child, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(m.child, attributeSeq,
expressionsMap),
LiteralTransformer(m.nullOnOverflow)),
m
)
case _: NormalizeNaNAndZero | _: PromotePrecision | _: TaggingExpression
=>
ChildTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(
- expr.children.head,
- attributeSeq,
- expressionsMap),
+ replaceWithExpressionTransformer0(expr.children.head, attributeSeq,
expressionsMap),
expr
)
case _: GetDateField | _: GetTimeField =>
ExtractDateTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(
- expr.children.head,
- attributeSeq,
- expressionsMap),
+ replaceWithExpressionTransformer0(expr.children.head, attributeSeq,
expressionsMap),
expr)
case _: StringToMap =>
BackendsApiManager.getSparkPlanExecApiInstance.genStringToMapTransformer(
substraitExprName,
- expr.children.map(
- replaceWithExpressionTransformerInternal(_, attributeSeq,
expressionsMap)),
+ expr.children.map(replaceWithExpressionTransformer0(_, attributeSeq,
expressionsMap)),
expr)
case CheckOverflow(b: BinaryArithmetic, decimalType, _)
if !BackendsApiManager.getSettings.transformCheckOverflow &&
DecimalArithmeticUtil.isDecimalArithmetic(b) =>
DecimalArithmeticUtil.checkAllowDecimalArithmetic()
val leftChild =
- replaceWithExpressionTransformerInternal(b.left, attributeSeq,
expressionsMap)
+ replaceWithExpressionTransformer0(b.left, attributeSeq,
expressionsMap)
val rightChild =
- replaceWithExpressionTransformerInternal(b.right, attributeSeq,
expressionsMap)
+ replaceWithExpressionTransformer0(b.right, attributeSeq,
expressionsMap)
DecimalArithmeticExpressionTransformer(
getAndCheckSubstraitName(b, expressionsMap),
leftChild,
@@ -482,15 +466,14 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
case c: CheckOverflow =>
CheckOverflowTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(c.child, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(c.child, attributeSeq,
expressionsMap),
c)
case b: BinaryArithmetic if DecimalArithmeticUtil.isDecimalArithmetic(b)
=>
DecimalArithmeticUtil.checkAllowDecimalArithmetic()
if (!BackendsApiManager.getSettings.transformCheckOverflow) {
GenericExpressionTransformer(
substraitExprName,
- expr.children.map(
- replaceWithExpressionTransformerInternal(_, attributeSeq,
expressionsMap)),
+ expr.children.map(replaceWithExpressionTransformer0(_,
attributeSeq, expressionsMap)),
expr
)
} else {
@@ -501,14 +484,14 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
case n: NaNvl =>
BackendsApiManager.getSparkPlanExecApiInstance.genNaNvlTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(n.left, attributeSeq,
expressionsMap),
- replaceWithExpressionTransformerInternal(n.right, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(n.left, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(n.right, attributeSeq,
expressionsMap),
n
)
case m: MakeTimestamp =>
BackendsApiManager.getSparkPlanExecApiInstance.genMakeTimestampTransformer(
substraitExprName,
- m.children.map(replaceWithExpressionTransformerInternal(_,
attributeSeq, expressionsMap)),
+ m.children.map(replaceWithExpressionTransformer0(_, attributeSeq,
expressionsMap)),
m)
case timestampAdd if
timestampAdd.getClass.getSimpleName.equals("TimestampAdd") =>
// for spark3.3
@@ -520,111 +503,99 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
TimestampAddTransformer(
substraitExprName,
extract.get.head,
- replaceWithExpressionTransformerInternal(add.left, attributeSeq,
expressionsMap),
- replaceWithExpressionTransformerInternal(add.right, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(add.left, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(add.right, attributeSeq,
expressionsMap),
extract.get.last,
add
)
case e: Transformable =>
val childrenTransformers =
- e.children.map(replaceWithExpressionTransformerInternal(_,
attributeSeq, expressionsMap))
+ e.children.map(replaceWithExpressionTransformer0(_, attributeSeq,
expressionsMap))
e.getTransformer(childrenTransformers)
case u: Uuid =>
BackendsApiManager.getSparkPlanExecApiInstance.genUuidTransformer(substraitExprName,
u)
case f: ArrayFilter =>
BackendsApiManager.getSparkPlanExecApiInstance.genArrayFilterTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(f.argument, attributeSeq,
expressionsMap),
- replaceWithExpressionTransformerInternal(f.function, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(f.argument, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(f.function, attributeSeq,
expressionsMap),
f
)
case arrayTransform: ArrayTransform =>
BackendsApiManager.getSparkPlanExecApiInstance.genArrayTransformTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(
- arrayTransform.argument,
- attributeSeq,
- expressionsMap),
- replaceWithExpressionTransformerInternal(
- arrayTransform.function,
- attributeSeq,
- expressionsMap),
+ replaceWithExpressionTransformer0(arrayTransform.argument,
attributeSeq, expressionsMap),
+ replaceWithExpressionTransformer0(arrayTransform.function,
attributeSeq, expressionsMap),
arrayTransform
)
case arraySort: ArraySort =>
BackendsApiManager.getSparkPlanExecApiInstance.genArraySortTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(
- arraySort.argument,
- attributeSeq,
- expressionsMap),
- replaceWithExpressionTransformerInternal(
- arraySort.function,
- attributeSeq,
- expressionsMap),
+ replaceWithExpressionTransformer0(arraySort.argument, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(arraySort.function, attributeSeq,
expressionsMap),
arraySort
)
case tryEval @ TryEval(a: Add) =>
BackendsApiManager.getSparkPlanExecApiInstance.genTryArithmeticTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(a.left, attributeSeq,
expressionsMap),
- replaceWithExpressionTransformerInternal(a.right, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(a.left, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(a.right, attributeSeq,
expressionsMap),
tryEval,
ExpressionNames.CHECKED_ADD
)
case tryEval @ TryEval(a: Subtract) =>
BackendsApiManager.getSparkPlanExecApiInstance.genTryArithmeticTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(a.left, attributeSeq,
expressionsMap),
- replaceWithExpressionTransformerInternal(a.right, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(a.left, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(a.right, attributeSeq,
expressionsMap),
tryEval,
ExpressionNames.CHECKED_SUBTRACT
)
case tryEval @ TryEval(a: Divide) =>
BackendsApiManager.getSparkPlanExecApiInstance.genTryArithmeticTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(a.left, attributeSeq,
expressionsMap),
- replaceWithExpressionTransformerInternal(a.right, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(a.left, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(a.right, attributeSeq,
expressionsMap),
tryEval,
ExpressionNames.CHECKED_DIVIDE
)
case tryEval @ TryEval(a: Multiply) =>
BackendsApiManager.getSparkPlanExecApiInstance.genTryArithmeticTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(a.left, attributeSeq,
expressionsMap),
- replaceWithExpressionTransformerInternal(a.right, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(a.left, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(a.right, attributeSeq,
expressionsMap),
tryEval,
ExpressionNames.CHECKED_MULTIPLY
)
case a: Add =>
BackendsApiManager.getSparkPlanExecApiInstance.genArithmeticTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(a.left, attributeSeq,
expressionsMap),
- replaceWithExpressionTransformerInternal(a.right, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(a.left, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(a.right, attributeSeq,
expressionsMap),
a,
ExpressionNames.CHECKED_ADD
)
case a: Subtract =>
BackendsApiManager.getSparkPlanExecApiInstance.genArithmeticTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(a.left, attributeSeq,
expressionsMap),
- replaceWithExpressionTransformerInternal(a.right, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(a.left, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(a.right, attributeSeq,
expressionsMap),
a,
ExpressionNames.CHECKED_SUBTRACT
)
case a: Multiply =>
BackendsApiManager.getSparkPlanExecApiInstance.genArithmeticTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(a.left, attributeSeq,
expressionsMap),
- replaceWithExpressionTransformerInternal(a.right, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(a.left, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(a.right, attributeSeq,
expressionsMap),
a,
ExpressionNames.CHECKED_MULTIPLY
)
case a: Divide =>
BackendsApiManager.getSparkPlanExecApiInstance.genArithmeticTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(a.left, attributeSeq,
expressionsMap),
- replaceWithExpressionTransformerInternal(a.right, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(a.left, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(a.right, attributeSeq,
expressionsMap),
a,
ExpressionNames.CHECKED_DIVIDE
)
@@ -632,34 +603,34 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
// This is a placeholder to handle try_eval(other expressions).
BackendsApiManager.getSparkPlanExecApiInstance.genTryEvalTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(tryEval.child,
attributeSeq, expressionsMap),
+ replaceWithExpressionTransformer0(tryEval.child, attributeSeq,
expressionsMap),
tryEval
)
case a: ArrayForAll =>
BackendsApiManager.getSparkPlanExecApiInstance.genArrayForAllTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(a.argument, attributeSeq,
expressionsMap),
- replaceWithExpressionTransformerInternal(a.function, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(a.argument, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(a.function, attributeSeq,
expressionsMap),
a
)
case a: ArrayExists =>
BackendsApiManager.getSparkPlanExecApiInstance.genArrayExistsTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(a.argument, attributeSeq,
expressionsMap),
- replaceWithExpressionTransformerInternal(a.function, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(a.argument, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(a.function, attributeSeq,
expressionsMap),
a
)
case s: Shuffle =>
GenericExpressionTransformer(
substraitExprName,
Seq(
- replaceWithExpressionTransformerInternal(s.child, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformer0(s.child, attributeSeq,
expressionsMap),
LiteralTransformer(Literal(s.randomSeed.get))),
s)
case c: PreciseTimestampConversion =>
BackendsApiManager.getSparkPlanExecApiInstance.genPreciseTimestampConversionTransformer(
substraitExprName,
- Seq(replaceWithExpressionTransformerInternal(c.child, attributeSeq,
expressionsMap)),
+ Seq(replaceWithExpressionTransformer0(c.child, attributeSeq,
expressionsMap)),
c
)
case t: TransformKeys =>
@@ -674,7 +645,7 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
}
GenericExpressionTransformer(
substraitExprName,
- t.children.map(replaceWithExpressionTransformerInternal(_,
attributeSeq, expressionsMap)),
+ t.children.map(replaceWithExpressionTransformer0(_, attributeSeq,
expressionsMap)),
t
)
case e: EulerNumber =>
@@ -700,8 +671,7 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
case expr =>
GenericExpressionTransformer(
substraitExprName,
- expr.children.map(
- replaceWithExpressionTransformerInternal(_, attributeSeq,
expressionsMap)),
+ expr.children.map(replaceWithExpressionTransformer0(_, attributeSeq,
expressionsMap)),
expr
)
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
index f2bb4a906..38f9de629 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
@@ -19,7 +19,6 @@ package org.apache.gluten.expression
import org.apache.gluten.GlutenConfig
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.expression.ExpressionNames._
-import org.apache.gluten.extension.{DefaultExpressionExtensionTransformer,
ExpressionExtensionTrait}
import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.spark.sql.catalyst.expressions._
@@ -338,22 +337,19 @@ object ExpressionMappings {
def expressionsMap: Map[Class[_], String] = {
val blacklist = GlutenConfig.getConf.expressionBlacklist
- val supportedExprs = defaultExpressionsMap ++
- expressionExtensionTransformer.extensionExpressionsMapping
- if (blacklist.isEmpty) {
- supportedExprs
- } else {
- supportedExprs.filterNot(kv => blacklist.contains(kv._2))
- }
+ val filtered = (defaultExpressionsMap ++ toMap(
+ BackendsApiManager.getSparkPlanExecApiInstance.extraExpressionMappings)).filterNot(
+ kv => blacklist.contains(kv._2))
+ filtered
}
private lazy val defaultExpressionsMap: Map[Class[_], String] = {
- (SCALAR_SIGS ++ AGGREGATE_SIGS ++ WINDOW_SIGS ++
- BackendsApiManager.getSparkPlanExecApiInstance.extraExpressionMappings)
+ toMap(SCALAR_SIGS ++ AGGREGATE_SIGS ++ WINDOW_SIGS)
+ }
+
+ private def toMap(sigs: Seq[Sig]): Map[Class[_], String] = {
+ sigs
.map(s => (s.expClass, s.name))
.toMap[Class[_], String]
}
-
- var expressionExtensionTransformer: ExpressionExtensionTrait =
- DefaultExpressionExtensionTransformer()
}
diff --git
a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index 5c2833de4..9b2e2ab95 100644
---
a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++
b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -34,7 +34,7 @@ import
org.apache.spark.sql.execution.datasources.text.{GlutenTextV1Suite, Glute
import org.apache.spark.sql.execution.datasources.v2._
import org.apache.spark.sql.execution.exchange.GlutenEnsureRequirementsSuite
import org.apache.spark.sql.execution.joins.{GlutenExistenceJoinSuite,
GlutenInnerJoinSuite, GlutenOuterJoinSuite}
-import
org.apache.spark.sql.extension.{GlutenCustomerExpressionTransformerSuite,
GlutenCustomerExtensionSuite, GlutenSessionExtensionSuite}
+import org.apache.spark.sql.extension.{GlutenCustomerExtensionSuite,
GlutenSessionExtensionSuite}
import org.apache.spark.sql.hive.execution.GlutenHiveSQLQueryCHSuite
import org.apache.spark.sql.sources._
import org.apache.spark.sql.statistics.SparkFunctionStatistics
@@ -2133,7 +2133,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
.exclude("right outer join with unique keys using ShuffledHashJoin
(whole-stage-codegen on)")
.exclude("right outer join with unique keys using SortMergeJoin
(whole-stage-codegen off)")
.exclude("right outer join with unique keys using SortMergeJoin
(whole-stage-codegen on)")
- enableSuite[GlutenCustomerExpressionTransformerSuite]
enableSuite[GlutenCustomerExtensionSuite]
enableSuite[GlutenSessionExtensionSuite]
enableSuite[GlutenBucketedReadWithoutHiveSupportSuite]
diff --git
a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
index c4799366d..e064f2afc 100644
---
a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
+++
b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
@@ -34,7 +34,7 @@ import
org.apache.spark.sql.execution.datasources.text.{GlutenTextV1Suite, Glute
import org.apache.spark.sql.execution.datasources.v2.GlutenFileTableSuite
import org.apache.spark.sql.execution.exchange.GlutenEnsureRequirementsSuite
import org.apache.spark.sql.execution.joins.{GlutenBroadcastJoinSuite,
GlutenExistenceJoinSuite, GlutenInnerJoinSuite, GlutenOuterJoinSuite}
-import
org.apache.spark.sql.extension.{GlutenCollapseProjectExecTransformerSuite,
GlutenCustomerExpressionTransformerSuite, GlutenSessionExtensionSuite}
+import
org.apache.spark.sql.extension.{GlutenCollapseProjectExecTransformerSuite,
GlutenSessionExtensionSuite}
import org.apache.spark.sql.hive.execution.GlutenHiveSQLQuerySuite
import org.apache.spark.sql.sources._
@@ -44,7 +44,6 @@ import org.apache.spark.sql.sources._
class VeloxTestSettings extends BackendTestSettings {
enableSuite[GlutenSessionExtensionSuite]
- enableSuite[GlutenCustomerExpressionTransformerSuite]
enableSuite[GlutenDataFrameAggregateSuite]
.exclude(
diff --git a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
index bb0e683c2..5c032d4b0 100644
--- a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
@@ -610,6 +610,7 @@ object GlutenConfig {
val GLUTEN_SUPPORTED_PYTHON_UDFS = "spark.gluten.supported.python.udfs"
val GLUTEN_SUPPORTED_SCALA_UDFS = "spark.gluten.supported.scala.udfs"
+ // FIXME: This only works with CH backend.
val GLUTEN_EXTENDED_EXPRESSION_TRAN_CONF =
"spark.gluten.sql.columnar.extended.expressions.transformer"
@@ -1686,6 +1687,7 @@ object GlutenConfig {
.stringConf
.createWithDefaultString("")
+ // FIXME: This only works with CH backend.
val EXTENDED_EXPRESSION_TRAN_CONF =
buildConf(GLUTEN_EXTENDED_EXPRESSION_TRAN_CONF)
.doc("A class for the extended expressions transformer.")
diff --git
a/shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/expressions/EvalMode.scala
b/shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/expressions/EvalMode.scala
new file mode 100644
index 000000000..0a3c63ccd
--- /dev/null
+++
b/shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/expressions/EvalMode.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.internal.SQLConf
+
+/** For compatibility with Spark version <= 3.3. The class was added in vanilla Spark since 3.4. */
+object EvalMode extends Enumeration {
+ val LEGACY, ANSI, TRY = Value
+
+ def fromSQLConf(conf: SQLConf): Value = if (conf.ansiEnabled) {
+ ANSI
+ } else {
+ LEGACY
+ }
+
+ def fromBoolean(ansiEnabled: Boolean): Value = if (ansiEnabled) {
+ ANSI
+ } else {
+ LEGACY
+ }
+}
diff --git
a/shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/expressions/EvalMode.scala
b/shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/expressions/EvalMode.scala
new file mode 100644
index 000000000..0a3c63ccd
--- /dev/null
+++
b/shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/expressions/EvalMode.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.internal.SQLConf
+
+/** For compatibility with Spark version <= 3.3. The class was added in vanilla Spark since 3.4. */
+object EvalMode extends Enumeration {
+ val LEGACY, ANSI, TRY = Value
+
+ def fromSQLConf(conf: SQLConf): Value = if (conf.ansiEnabled) {
+ ANSI
+ } else {
+ LEGACY
+ }
+
+ def fromBoolean(ansiEnabled: Boolean): Value = if (ansiEnabled) {
+ ANSI
+ } else {
+ LEGACY
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]