This is an automated email from the ASF dual-hosted git repository.
ulyssesyou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 683232b74 [CORE] Refactor ExpressionTransformer (#5796)
683232b74 is described below
commit 683232b746263ac6d720a4bf3c61474fe3d39e1a
Author: Xiduo You <[email protected]>
AuthorDate: Tue May 21 15:18:08 2024 +0800
[CORE] Refactor ExpressionTransformer (#5796)
---
.../clickhouse/CHSparkPlanExecApi.scala | 20 +++
.../backendsapi/clickhouse/CHTransformerApi.scala | 12 --
.../expression/CHExpressionTransformer.scala | 44 +++---
.../backendsapi/velox/VeloxSparkPlanExecApi.scala | 26 +++-
.../backendsapi/velox/VeloxTransformerApi.scala | 13 --
.../gluten/expression/ExpressionTransformer.scala | 48 +++---
.../apache/spark/sql/expression/UDFResolver.scala | 25 +---
.../gluten/backendsapi/SparkPlanExecApi.scala | 79 +++-------
.../apache/gluten/backendsapi/TransformerApi.scala | 7 -
.../expression/ArrayExpressionTransformer.scala | 21 +--
.../expression/BoundReferenceTransformer.scala | 29 ----
.../gluten/expression/ConditionalTransformer.scala | 13 +-
.../DateTimeExpressionsTransformer.scala | 164 +++-----------------
.../expression/DecimalRoundTransformer.scala | 27 +---
.../gluten/expression/ExpressionConverter.scala | 147 +++++++-----------
.../gluten/expression/ExpressionTransformer.scala | 71 ++++++++-
.../expression/GenericExpressionTransformer.scala | 46 ------
.../expression/HashExpressionTransformer.scala | 44 ------
.../JsonTupleExpressionTransformer.scala | 2 +-
.../expression/LambdaFunctionTransformer.scala | 22 +--
.../gluten/expression/LiteralTransformer.scala | 28 ----
.../expression/MapExpressionTransformer.scala | 45 +-----
.../expression/NamedExpressionsTransformer.scala | 58 ++-----
.../PredicateExpressionTransformer.scala | 82 ++--------
.../expression/ScalarSubqueryTransformer.scala | 6 +-
.../expression/StringExpressionTransformer.scala | 49 ------
.../expression/StructExpressionTransformer.scala | 54 -------
.../expression/TimestampAddTransformer.scala | 53 -------
.../expression/UnaryExpressionTransformer.scala | 166 ++++-----------------
.../extension/CustomerExpressionTransformer.scala | 26 +---
30 files changed, 331 insertions(+), 1096 deletions(-)
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index d6e323679..45f90719f 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -627,6 +627,15 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
CHSizeExpressionTransformer(substraitExprName, child, original)
}
+ override def genLikeTransformer(
+ substraitExprName: String,
+ left: ExpressionTransformer,
+ right: ExpressionTransformer,
+ original: Like): ExpressionTransformer = {
+ // CH backend does not support escapeChar, so skip it here.
+ GenericExpressionTransformer(substraitExprName, Seq(left, right), original)
+ }
+
/** Generate an ExpressionTransformer to transform TruncTimestamp expression
for CH. */
override def genTruncTimestampTransformer(
substraitExprName: String,
@@ -637,6 +646,17 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
CHTruncTimestampTransformer(substraitExprName, format, timestamp,
timeZoneId, original)
}
+ override def genDateDiffTransformer(
+ substraitExprName: String,
+ endDate: ExpressionTransformer,
+ startDate: ExpressionTransformer,
+ original: DateDiff): ExpressionTransformer = {
+ GenericExpressionTransformer(
+ substraitExprName,
+ Seq(LiteralTransformer("day"), startDate, endDate),
+ original)
+ }
+
override def genPosExplodeTransformer(
substraitExprName: String,
child: ExpressionTransformer,
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHTransformerApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHTransformerApi.scala
index ee46d685c..c75cf4788 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHTransformerApi.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHTransformerApi.scala
@@ -203,18 +203,6 @@ class CHTransformerApi extends TransformerApi with Logging {
}
- override def createDateDiffParamList(
- start: ExpressionNode,
- end: ExpressionNode): Iterable[ExpressionNode] = {
- List(ExpressionBuilder.makeStringLiteral("day"), start, end)
- }
-
- override def createLikeParamList(
- left: ExpressionNode,
- right: ExpressionNode,
- escapeChar: ExpressionNode): Iterable[ExpressionNode] =
- List(left, right)
-
override def createCheckOverflowExprNode(
args: java.lang.Object,
substraitExprName: String,
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressionTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressionTransformer.scala
index 6403471c7..5ca4e0233 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressionTransformer.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressionTransformer.scala
@@ -32,17 +32,12 @@ import java.util.Locale
case class CHSizeExpressionTransformer(
substraitExprName: String,
- child: ExpressionTransformer,
+ expr: ExpressionTransformer,
original: Size)
- extends ExpressionTransformerWithOrigin {
-
- override def doTransform(args: java.lang.Object): ExpressionNode = {
- // Pass legacyLiteral as second argument in substrait function
- val legacyLiteral = new Literal(original.legacySizeOfNull, BooleanType)
- val legacyTransformer = new LiteralTransformer(legacyLiteral)
- GenericExpressionTransformer(substraitExprName, Seq(child,
legacyTransformer), original)
- .doTransform(args)
- }
+ extends BinaryExpressionTransformer {
+ override def left: ExpressionTransformer = expr
+ // Pass legacyLiteral as second argument in substrait function
+ override def right: ExpressionTransformer =
LiteralTransformer(original.legacySizeOfNull)
}
case class CHTruncTimestampTransformer(
@@ -51,7 +46,8 @@ case class CHTruncTimestampTransformer(
timestamp: ExpressionTransformer,
timeZoneId: Option[String] = None,
original: TruncTimestamp)
- extends ExpressionTransformerWithOrigin {
+ extends ExpressionTransformer {
+ override def children: Seq[ExpressionTransformer] = format :: timestamp ::
Nil
override def doTransform(args: java.lang.Object): ExpressionNode = {
// The format must be constant string in the function date_trunc of ch.
@@ -126,7 +122,8 @@ case class CHStringTranslateTransformer(
matchingExpr: ExpressionTransformer,
replaceExpr: ExpressionTransformer,
original: StringTranslate)
- extends ExpressionTransformerWithOrigin {
+ extends ExpressionTransformer {
+ override def children: Seq[ExpressionTransformer] = srcExpr :: matchingExpr
:: replaceExpr :: Nil
override def doTransform(args: java.lang.Object): ExpressionNode = {
// In CH, translateUTF8 requires matchingExpr and replaceExpr argument
have the same length
@@ -145,11 +142,7 @@ case class CHStringTranslateTransformer(
throw new GlutenNotSupportException(s"$original not supported yet.")
}
- GenericExpressionTransformer(
- substraitExprName,
- Seq(srcExpr, matchingExpr, replaceExpr),
- original)
- .doTransform(args)
+ super.doTransform(args)
}
}
@@ -158,7 +151,7 @@ case class CHPosExplodeTransformer(
child: ExpressionTransformer,
original: PosExplode,
attributeSeq: Seq[Attribute])
- extends ExpressionTransformerWithOrigin {
+ extends UnaryExpressionTransformer {
override def doTransform(args: java.lang.Object): ExpressionNode = {
val childNode: ExpressionNode = child.doTransform(args)
@@ -200,14 +193,15 @@ case class CHPosExplodeTransformer(
case class CHRegExpReplaceTransformer(
substraitExprName: String,
- children: Seq[ExpressionTransformer],
+ childrenWithPos: Seq[ExpressionTransformer],
original: RegExpReplace)
- extends ExpressionTransformerWithOrigin {
+ extends ExpressionTransformer {
+ override def children: Seq[ExpressionTransformer] =
childrenWithPos.dropRight(1)
override def doTransform(args: java.lang.Object): ExpressionNode = {
// In CH: replaceRegexpAll(subject, regexp, rep), which is equivalent
// In Spark: regexp_replace(subject, regexp, rep, pos=1)
- val posNode = children(3).doTransform(args)
+ val posNode = childrenWithPos(3).doTransform(args)
if (
!posNode.isInstanceOf[IntLiteralNode] ||
posNode.asInstanceOf[IntLiteralNode].getValue != 1
@@ -215,11 +209,7 @@ case class CHRegExpReplaceTransformer(
throw new UnsupportedOperationException(s"$original not supported yet.")
}
- GenericExpressionTransformer(
- substraitExprName,
- Seq(children(0), children(1), children(2)),
- original)
- .doTransform(args)
+ super.doTransform(args)
}
}
@@ -228,7 +218,7 @@ case class GetArrayItemTransformer(
left: ExpressionTransformer,
right: ExpressionTransformer,
original: Expression)
- extends ExpressionTransformerWithOrigin {
+ extends BinaryExpressionTransformer {
override def doTransform(args: java.lang.Object): ExpressionNode = {
// Ignore failOnError for clickhouse backend
diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index c30e34952..2d37b1185 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -104,6 +104,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
val condFuncName = ExpressionMappings.expressionsMap(classOf[IsNaN])
val newExpr = If(condExpr, original.right, original.left)
IfTransformer(
+ substraitExprName,
GenericExpressionTransformer(condFuncName, Seq(left), condExpr),
right,
left,
@@ -117,7 +118,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
original: Uuid): ExpressionTransformer = {
GenericExpressionTransformer(
substraitExprName,
- Seq(LiteralTransformer(Literal(original.randomSeed.get))),
+ Seq(LiteralTransformer(original.randomSeed.get)),
original)
}
@@ -243,6 +244,17 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
GenericExpressionTransformer(substraitExprName, Seq(child), expr)
}
+ override def genLikeTransformer(
+ substraitExprName: String,
+ left: ExpressionTransformer,
+ right: ExpressionTransformer,
+ original: Like): ExpressionTransformer = {
+ GenericExpressionTransformer(
+ substraitExprName,
+ Seq(left, right, LiteralTransformer(original.escapeChar)),
+ original)
+ }
+
/** Transform make_timestamp to Substrait. */
override def genMakeTimestampTransformer(
substraitExprName: String,
@@ -251,6 +263,14 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
GenericExpressionTransformer(substraitExprName, children, expr)
}
+ override def genDateDiffTransformer(
+ substraitExprName: String,
+ endDate: ExpressionTransformer,
+ startDate: ExpressionTransformer,
+ original: DateDiff): ExpressionTransformer = {
+ GenericExpressionTransformer(substraitExprName, Seq(endDate, startDate),
original)
+ }
+
/**
* Generate FilterExecTransformer.
*
@@ -419,7 +439,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
override def genHashExpressionTransformer(
substraitExprName: String,
exprs: Seq[ExpressionTransformer],
- original: Expression): ExpressionTransformer = {
+ original: HashExpression[_]): ExpressionTransformer = {
VeloxHashExpressionTransformer(substraitExprName, exprs, original)
}
@@ -612,7 +632,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
childTransformer: ExpressionTransformer,
ordinal: Int,
original: GetStructField): ExpressionTransformer = {
- VeloxGetStructFieldTransformer(substraitExprName, childTransformer,
ordinal, original)
+ VeloxGetStructFieldTransformer(substraitExprName, childTransformer,
original)
}
/**
diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxTransformerApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxTransformerApi.scala
index 33f612440..aadfcd9b7 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxTransformerApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxTransformerApi.scala
@@ -63,19 +63,6 @@ class VeloxTransformerApi extends TransformerApi with Logging {
// TODO: IMPLEMENT SPECIAL PROCESS FOR VELOX BACKEND
}
- override def createDateDiffParamList(
- start: ExpressionNode,
- end: ExpressionNode): Iterable[ExpressionNode] = {
- List(end, start)
- }
-
- override def createLikeParamList(
- left: ExpressionNode,
- right: ExpressionNode,
- escapeChar: ExpressionNode): Iterable[ExpressionNode] = {
- List(left, right, escapeChar)
- }
-
override def createCheckOverflowExprNode(
args: java.lang.Object,
substraitExprName: String,
diff --git a/backends-velox/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala b/backends-velox/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala
index da8433fa2..0f0eb2969 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala
@@ -24,8 +24,6 @@ import org.apache.gluten.substrait.expression._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{IntegerType, LongType}
-import com.google.common.collect.Lists
-
import java.lang.{Integer => JInteger, Long => JLong}
import java.util.{ArrayList => JArrayList, HashMap => JHashMap}
@@ -35,7 +33,7 @@ case class VeloxAliasTransformer(
substraitExprName: String,
child: ExpressionTransformer,
original: Expression)
- extends ExpressionTransformerWithOrigin {
+ extends UnaryExpressionTransformer {
override def doTransform(args: java.lang.Object): ExpressionNode = {
child.doTransform(args)
@@ -46,36 +44,25 @@ case class VeloxNamedStructTransformer(
substraitExprName: String,
original: CreateNamedStruct,
attributeSeq: Seq[Attribute])
- extends ExpressionTransformerWithOrigin {
- override def doTransform(args: Object): ExpressionNode = {
- val expressionNodes = Lists.newArrayList[ExpressionNode]()
- original.valExprs.foreach(
- child =>
- expressionNodes.add(
- replaceWithExpressionTransformer(child,
attributeSeq).doTransform(args)))
- val functionMap = args.asInstanceOf[JHashMap[String, JLong]]
- val functionName = ConverterUtils
- .makeFuncName(substraitExprName, Seq(original.dataType),
FunctionConfig.OPT)
- val functionId = ExpressionBuilder.newScalarFunction(functionMap,
functionName)
- val typeNode = ConverterUtils.getTypeNode(original.dataType,
original.nullable)
- ExpressionBuilder.makeScalarFunction(functionId, expressionNodes, typeNode)
+ extends ExpressionTransformer {
+ override def children: Seq[ExpressionTransformer] = {
+ original.valExprs.map(replaceWithExpressionTransformer(_, attributeSeq))
}
}
case class VeloxGetStructFieldTransformer(
substraitExprName: String,
- childTransformer: ExpressionTransformer,
- ordinal: Int,
+ child: ExpressionTransformer,
original: GetStructField)
- extends ExpressionTransformerWithOrigin {
+ extends UnaryExpressionTransformer {
override def doTransform(args: Object): ExpressionNode = {
- val childNode = childTransformer.doTransform(args)
+ val childNode = child.doTransform(args)
childNode match {
case node: StructLiteralNode =>
- node.getFieldLiteral(ordinal)
+ node.getFieldLiteral(original.ordinal)
case node: SelectionNode =>
// Append the nested index to selection node.
- node.addNestedChildIdx(JInteger.valueOf(ordinal))
+ node.addNestedChildIdx(JInteger.valueOf(original.ordinal))
case other =>
throw new GlutenNotSupportException(s"$other is not supported.")
}
@@ -84,9 +71,10 @@ case class VeloxGetStructFieldTransformer(
case class VeloxHashExpressionTransformer(
substraitExprName: String,
- exps: Seq[ExpressionTransformer],
- original: Expression)
- extends ExpressionTransformerWithOrigin {
+ children: Seq[ExpressionTransformer],
+ original: HashExpression[_])
+ extends ExpressionTransformer {
+
override def doTransform(args: java.lang.Object): ExpressionNode = {
// As of Spark 3.3, there are 3 kinds of HashExpression.
// HiveHash is not supported in native backend and will fail native
validation.
@@ -101,7 +89,7 @@ case class VeloxHashExpressionTransformer(
val nodes = new JArrayList[ExpressionNode]()
// Seed as the first argument
nodes.add(seedNode)
- exps.foreach(
+ children.foreach(
expression => {
nodes.add(expression.doTransform(args))
})
@@ -121,7 +109,9 @@ case class VeloxStringSplitTransformer(
regexExpr: ExpressionTransformer,
limitExpr: ExpressionTransformer,
original: StringSplit)
- extends ExpressionTransformerWithOrigin {
+ extends ExpressionTransformer {
+ // TODO: split function support limit arg
+ override def children: Seq[ExpressionTransformer] = srcExpr :: regexExpr ::
Nil
override def doTransform(args: java.lang.Object): ExpressionNode = {
if (
@@ -139,8 +129,6 @@ case class VeloxStringSplitTransformer(
s"$original supported single-length regex and negative limit, but
given $limit and $regex")
}
- // TODO: split function support limit arg
- GenericExpressionTransformer(substraitExprName, Seq(srcExpr, regexExpr),
original)
- .doTransform(args)
+ super.doTransform(args)
}
}
diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
index c34c1ae7f..ec98e98f1 100644
--- a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
+++ b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
@@ -18,9 +18,7 @@ package org.apache.spark.sql.expression
import org.apache.gluten.backendsapi.velox.VeloxBackendSettings
import org.apache.gluten.exception.GlutenException
-import org.apache.gluten.expression.{ConverterUtils, ExpressionTransformer,
ExpressionType, Transformable}
-import org.apache.gluten.expression.ConverterUtils.FunctionConfig
-import org.apache.gluten.substrait.expression.{ExpressionBuilder,
ExpressionNode}
+import org.apache.gluten.expression.{ConverterUtils, ExpressionTransformer,
ExpressionType, GenericExpressionTransformer, Transformable}
import org.apache.gluten.udf.UdfJniWrapper
import org.apache.gluten.vectorized.JniWorkspace
@@ -37,8 +35,6 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.util.Utils
-import com.google.common.collect.Lists
-
import java.io.File
import java.net.URI
import java.nio.file.{Files, FileVisitOption, Paths}
@@ -112,24 +108,7 @@ case class UDFExpression(
": getTransformer called before children transformer initialized.")
}
- val localDataType = dataType
- new ExpressionTransformer {
- override def doTransform(args: Object): ExpressionNode = {
- val transformers = childrenTransformers.map(_.doTransform(args))
- val functionMap = args.asInstanceOf[java.util.HashMap[String,
java.lang.Long]]
- val functionId = ExpressionBuilder.newScalarFunction(
- functionMap,
- ConverterUtils.makeFuncName(name, children.map(_.dataType),
FunctionConfig.REQ))
-
- val typeNode = ConverterUtils.getTypeNode(dataType, nullable)
- ExpressionBuilder.makeScalarFunction(
- functionId,
- Lists.newArrayList(transformers: _*),
- typeNode)
- }
-
- override def dataType: DataType = localDataType
- }
+ GenericExpressionTransformer(name, childrenTransformers, this)
}
}
diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
index 7e72b1758..69777f77a 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
@@ -214,7 +214,7 @@ trait SparkPlanExecApi {
throw new GlutenNotSupportException("try_add is not supported")
}
- def genTryAddTransformer(
+ def genTryEvalTransformer(
substraitExprName: String,
child: ExpressionTransformer,
original: TryEval): ExpressionTransformer = {
@@ -286,9 +286,7 @@ trait SparkPlanExecApi {
substraitExprName: String,
child: ExpressionTransformer,
original: PosExplode,
- attributeSeq: Seq[Attribute]): ExpressionTransformer = {
- PosExplodeTransformer(substraitExprName, child, original, attributeSeq)
- }
+ attributeSeq: Seq[Attribute]): ExpressionTransformer
/** Transform make_timestamp to Substrait. */
def genMakeTimestampTransformer(
@@ -427,7 +425,7 @@ trait SparkPlanExecApi {
childTransformer: ExpressionTransformer,
ordinal: Int,
original: GetStructField): ExpressionTransformer = {
- GetStructFieldTransformer(substraitExprName, childTransformer, ordinal,
original)
+ GetStructFieldTransformer(substraitExprName, childTransformer, original)
}
def genNamedStructTransformer(
@@ -438,13 +436,6 @@ trait SparkPlanExecApi {
GenericExpressionTransformer(substraitExprName, children, original)
}
- def genMd5Transformer(
- substraitExprName: String,
- child: ExpressionTransformer,
- original: Md5): ExpressionTransformer = {
- GenericExpressionTransformer(substraitExprName, Seq(child), original)
- }
-
def genStringTranslateTransformer(
substraitExprName: String,
srcExpr: ExpressionTransformer,
@@ -457,38 +448,6 @@ trait SparkPlanExecApi {
original)
}
- def genStringLocateTransformer(
- substraitExprName: String,
- first: ExpressionTransformer,
- second: ExpressionTransformer,
- third: ExpressionTransformer,
- original: StringLocate): ExpressionTransformer = {
- GenericExpressionTransformer(substraitExprName, Seq(first, second, third),
original)
- }
-
- /**
- * Generate an ExpressionTransformer to transform Sha2 expression.
Sha2Transformer is the default
- * implementation.
- */
- def genSha2Transformer(
- substraitExprName: String,
- left: ExpressionTransformer,
- right: ExpressionTransformer,
- original: Sha2): ExpressionTransformer = {
- GenericExpressionTransformer(substraitExprName, Seq(left, right), original)
- }
-
- /**
- * Generate an ExpressionTransformer to transform Sha1 expression.
Sha1Transformer is the default
- * implementation.
- */
- def genSha1Transformer(
- substraitExprName: String,
- child: ExpressionTransformer,
- original: Sha1): ExpressionTransformer = {
- GenericExpressionTransformer(substraitExprName, Seq(child), original)
- }
-
def genSizeExpressionTransformer(
substraitExprName: String,
child: ExpressionTransformer,
@@ -496,6 +455,12 @@ trait SparkPlanExecApi {
GenericExpressionTransformer(substraitExprName, Seq(child), original)
}
+ def genLikeTransformer(
+ substraitExprName: String,
+ left: ExpressionTransformer,
+ right: ExpressionTransformer,
+ original: Like): ExpressionTransformer
+
/**
* Generate an ExpressionTransformer to transform TruncTimestamp expression.
* TruncTimestampTransformer is the default implementation.
@@ -506,30 +471,22 @@ trait SparkPlanExecApi {
timestamp: ExpressionTransformer,
timeZoneId: Option[String] = None,
original: TruncTimestamp): ExpressionTransformer = {
- TruncTimestampTransformer(substraitExprName, format, timestamp,
timeZoneId, original)
+ TruncTimestampTransformer(substraitExprName, format, timestamp, original)
}
+ def genDateDiffTransformer(
+ substraitExprName: String,
+ endDate: ExpressionTransformer,
+ startDate: ExpressionTransformer,
+ original: DateDiff): ExpressionTransformer
+
def genCastWithNewChild(c: Cast): Cast = c
def genHashExpressionTransformer(
substraitExprName: String,
exprs: Seq[ExpressionTransformer],
- original: Expression): ExpressionTransformer = {
- HashExpressionTransformer(substraitExprName, exprs, original)
- }
-
- def genUnixTimestampTransformer(
- substraitExprName: String,
- timeExp: ExpressionTransformer,
- format: ExpressionTransformer,
- original: ToUnixTimestamp): ExpressionTransformer = {
- ToUnixTimestampTransformer(
- substraitExprName,
- timeExp,
- format,
- original.timeZoneId,
- original.failOnError,
- original)
+ original: HashExpression[_]): ExpressionTransformer = {
+ GenericExpressionTransformer(substraitExprName, exprs, original)
}
/** Define backend specfic expression mappings. */
diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/TransformerApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/TransformerApi.scala
index 7a10dc68c..e41df0f2f 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/TransformerApi.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/TransformerApi.scala
@@ -58,13 +58,6 @@ trait TransformerApi {
plan.output
}
- def createDateDiffParamList(start: ExpressionNode, end: ExpressionNode):
Iterable[ExpressionNode]
-
- def createLikeParamList(
- left: ExpressionNode,
- right: ExpressionNode,
- escapeChar: ExpressionNode): Iterable[ExpressionNode]
-
def createCheckOverflowExprNode(
args: java.lang.Object,
substraitExprName: String,
diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ArrayExpressionTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ArrayExpressionTransformer.scala
index 38f65c178..2a09e039e 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/expression/ArrayExpressionTransformer.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ArrayExpressionTransformer.scala
@@ -17,37 +17,24 @@
package org.apache.gluten.expression
import org.apache.gluten.exception.GlutenNotSupportException
-import org.apache.gluten.expression.ConverterUtils.FunctionConfig
-import org.apache.gluten.substrait.expression.{ExpressionBuilder,
ExpressionNode}
+import org.apache.gluten.substrait.expression.ExpressionNode
import org.apache.spark.sql.catalyst.expressions._
-import scala.collection.JavaConverters._
-
case class CreateArrayTransformer(
substraitExprName: String,
children: Seq[ExpressionTransformer],
- useStringTypeWhenEmpty: Boolean,
original: CreateArray)
- extends ExpressionTransformerWithOrigin {
+ extends ExpressionTransformer {
override def doTransform(args: java.lang.Object): ExpressionNode = {
// If children is empty,
// transformation is only supported when useStringTypeWhenEmpty is false
// because ClickHouse and Velox currently doesn't support this config.
- if (useStringTypeWhenEmpty && children.isEmpty) {
+ if (original.useStringTypeWhenEmpty && children.isEmpty) {
throw new GlutenNotSupportException(s"$original not supported yet.")
}
- val childNodes = children.map(_.doTransform(args)).asJava
-
- val functionMap = args.asInstanceOf[java.util.HashMap[String,
java.lang.Long]]
- val functionName = ConverterUtils.makeFuncName(
- substraitExprName,
- original.children.map(_.dataType),
- FunctionConfig.OPT)
- val functionId = ExpressionBuilder.newScalarFunction(functionMap,
functionName)
- val typeNode = ConverterUtils.getTypeNode(original.dataType,
original.nullable)
- ExpressionBuilder.makeScalarFunction(functionId, childNodes, typeNode)
+ super.doTransform(args)
}
}
diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/BoundReferenceTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/BoundReferenceTransformer.scala
deleted file mode 100644
index 2cfced13b..000000000
--- a/gluten-core/src/main/scala/org/apache/gluten/expression/BoundReferenceTransformer.scala
+++ /dev/null
@@ -1,29 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.gluten.expression
-
-import org.apache.gluten.substrait.expression.{ExpressionBuilder,
ExpressionNode}
-
-import org.apache.spark.sql.types._
-
-case class BoundReferenceTransformer(ordinal: Int, dataType: DataType,
nullable: Boolean)
- extends ExpressionTransformer {
-
- override def doTransform(args: java.lang.Object): ExpressionNode = {
- ExpressionBuilder.makeSelection(ordinal.asInstanceOf[java.lang.Integer])
- }
-}
diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ConditionalTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ConditionalTransformer.scala
index 0fdd68511..1dffd3906 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/expression/ConditionalTransformer.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ConditionalTransformer.scala
@@ -24,10 +24,13 @@ import java.util.{ArrayList => JArrayList}
/** A version of substring that supports columnar processing for utf8. */
case class CaseWhenTransformer(
+ substraitExprName: String,
branches: Seq[(ExpressionTransformer, ExpressionTransformer)],
elseValue: Option[ExpressionTransformer],
- original: Expression)
- extends ExpressionTransformerWithOrigin {
+ original: CaseWhen)
+ extends ExpressionTransformer {
+ override def children: Seq[ExpressionTransformer] =
+ branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue
override def doTransform(args: java.lang.Object): ExpressionNode = {
// generate branches nodes
@@ -48,11 +51,13 @@ case class CaseWhenTransformer(
}
case class IfTransformer(
+ substraitExprName: String,
predicate: ExpressionTransformer,
trueValue: ExpressionTransformer,
falseValue: ExpressionTransformer,
- original: Expression)
- extends ExpressionTransformerWithOrigin {
+ original: If)
+ extends ExpressionTransformer {
+ override def children: Seq[ExpressionTransformer] = predicate :: trueValue
:: falseValue :: Nil
override def doTransform(args: java.lang.Object): ExpressionNode = {
val ifNodes = new JArrayList[ExpressionNode]
diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/DateTimeExpressionsTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/DateTimeExpressionsTransformer.scala
index 66004291a..505ca33ea 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/expression/DateTimeExpressionsTransformer.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/expression/DateTimeExpressionsTransformer.scala
@@ -16,140 +16,36 @@
*/
package org.apache.gluten.expression
-import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.exception.GlutenNotSupportException
-import org.apache.gluten.expression.ConverterUtils.FunctionConfig
-import org.apache.gluten.substrait.expression.{ExpressionBuilder,
ExpressionNode}
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types._
-
-import com.google.common.collect.Lists
-
-import java.lang.{Long => JLong}
-import java.util.{ArrayList => JArrayList, HashMap => JHashMap}
-
-import scala.collection.JavaConverters._
/** The extract trait for 'GetDateField' from Date */
case class ExtractDateTransformer(
substraitExprName: String,
child: ExpressionTransformer,
original: Expression)
- extends ExpressionTransformerWithOrigin {
-
- override def doTransform(args: java.lang.Object): ExpressionNode = {
- val childNode = child.doTransform(args)
-
- val functionMap = args.asInstanceOf[JHashMap[String, JLong]]
- val functionName = ConverterUtils.makeFuncName(
- substraitExprName,
- original.children.map(_.dataType),
- FunctionConfig.OPT)
- val functionId = ExpressionBuilder.newScalarFunction(functionMap,
functionName)
+ extends BinaryExpressionTransformer {
+ override def left: ExpressionTransformer = {
val dateFieldName =
DateTimeExpressionsTransformer.EXTRACT_DATE_FIELD_MAPPING.get(original.getClass)
if (dateFieldName.isEmpty) {
throw new GlutenNotSupportException(s"$original not supported yet.")
}
- val fieldNode = ExpressionBuilder.makeStringLiteral(dateFieldName.get)
- val expressNodes = Lists.newArrayList(fieldNode, childNode)
- val typeNode = ConverterUtils.getTypeNode(original.dataType,
original.nullable)
-
- ExpressionBuilder.makeScalarFunction(functionId, expressNodes, typeNode)
- }
-}
-
-case class DateDiffTransformer(
- substraitExprName: String,
- endDate: ExpressionTransformer,
- startDate: ExpressionTransformer,
- original: DateDiff)
- extends ExpressionTransformerWithOrigin {
-
- override def doTransform(args: java.lang.Object): ExpressionNode = {
- val endDateNode = endDate.doTransform(args)
- val startDateNode = startDate.doTransform(args)
-
- val functionMap = args.asInstanceOf[JHashMap[String, JLong]]
- val functionName = ConverterUtils.makeFuncName(
- substraitExprName,
- Seq(StringType, original.startDate.dataType, original.endDate.dataType),
- FunctionConfig.OPT)
- val functionId = ExpressionBuilder.newScalarFunction(functionMap,
functionName)
-
- val expressionNodes =
BackendsApiManager.getTransformerApiInstance.createDateDiffParamList(
- startDateNode,
- endDateNode)
- ExpressionBuilder.makeScalarFunction(
- functionId,
- expressionNodes.toList.asJava,
- ConverterUtils.getTypeNode(original.dataType, original.nullable))
- }
-}
-
-/**
- * The failOnError depends on the config for ANSI. ANSI is not supported
currently. And timeZoneId
- * is passed to backend config.
- */
-case class ToUnixTimestampTransformer(
- substraitExprName: String,
- timeExp: ExpressionTransformer,
- format: ExpressionTransformer,
- timeZoneId: Option[String],
- failOnError: Boolean,
- original: ToUnixTimestamp)
- extends ExpressionTransformerWithOrigin {
-
- override def doTransform(args: java.lang.Object): ExpressionNode = {
- val dataTypes = Seq(original.timeExp.dataType, StringType)
- val functionMap = args.asInstanceOf[JHashMap[String, JLong]]
- val functionId = ExpressionBuilder.newScalarFunction(
- functionMap,
- ConverterUtils.makeFuncName(substraitExprName, dataTypes))
-
- val expressionNodes = new JArrayList[ExpressionNode]()
- val timeExpNode = timeExp.doTransform(args)
- expressionNodes.add(timeExpNode)
- val formatNode = format.doTransform(args)
- expressionNodes.add(formatNode)
- val typeNode = ConverterUtils.getTypeNode(original.dataType,
original.nullable)
- ExpressionBuilder.makeScalarFunction(functionId, expressionNodes, typeNode)
+ LiteralTransformer(dateFieldName.get)
}
+ override def right: ExpressionTransformer = child
}
case class TruncTimestampTransformer(
substraitExprName: String,
format: ExpressionTransformer,
timestamp: ExpressionTransformer,
- timeZoneId: Option[String] = None,
original: TruncTimestamp)
- extends ExpressionTransformerWithOrigin {
-
- override def doTransform(args: java.lang.Object): ExpressionNode = {
- val timestampNode = timestamp.doTransform(args)
- val formatNode = format.doTransform(args)
-
- val functionMap = args.asInstanceOf[JHashMap[String, JLong]]
- val dataTypes = if (timeZoneId.isDefined) {
- Seq(original.format.dataType, original.timestamp.dataType, StringType)
- } else {
- Seq(original.format.dataType, original.timestamp.dataType)
- }
-
- val functionId = ExpressionBuilder.newScalarFunction(
- functionMap,
- ConverterUtils.makeFuncName(substraitExprName, dataTypes))
-
- val expressionNodes = new JArrayList[ExpressionNode]()
- expressionNodes.add(formatNode)
- expressionNodes.add(timestampNode)
- if (timeZoneId.isDefined) {
- expressionNodes.add(ExpressionBuilder.makeStringLiteral(timeZoneId.get))
- }
-
- val typeNode = ConverterUtils.getTypeNode(original.dataType,
original.nullable)
- ExpressionBuilder.makeScalarFunction(functionId, expressionNodes, typeNode)
+ extends ExpressionTransformer {
+ override def children: Seq[ExpressionTransformer] = {
+ val timeZoneId = original.timeZoneId.map(timeZoneId =>
LiteralTransformer(timeZoneId))
+ Seq(format, timestamp) ++ timeZoneId
}
}
@@ -158,36 +54,24 @@ case class MonthsBetweenTransformer(
date1: ExpressionTransformer,
date2: ExpressionTransformer,
roundOff: ExpressionTransformer,
- timeZoneId: Option[String] = None,
original: MonthsBetween)
- extends ExpressionTransformerWithOrigin {
-
- override def doTransform(args: java.lang.Object): ExpressionNode = {
- val date1Node = date1.doTransform(args)
- val data2Node = date2.doTransform(args)
- val roundOffNode = roundOff.doTransform(args)
-
- val functionMap = args.asInstanceOf[JHashMap[String, JLong]]
- val dataTypes = if (timeZoneId.isDefined) {
- Seq(original.date1.dataType, original.date2.dataType,
original.roundOff.dataType, StringType)
- } else {
- Seq(original.date1.dataType, original.date2.dataType,
original.roundOff.dataType)
- }
-
- val functionId = ExpressionBuilder.newScalarFunction(
- functionMap,
- ConverterUtils.makeFuncName(substraitExprName, dataTypes))
-
- val expressionNodes = new JArrayList[ExpressionNode]()
- expressionNodes.add(date1Node)
- expressionNodes.add(data2Node)
- expressionNodes.add(roundOffNode)
- if (timeZoneId.isDefined) {
- expressionNodes.add(ExpressionBuilder.makeStringLiteral(timeZoneId.get))
- }
+ extends ExpressionTransformer {
+ override def children: Seq[ExpressionTransformer] = {
+ val timeZoneId = original.timeZoneId.map(timeZoneId =>
LiteralTransformer(timeZoneId))
+ Seq(date1, date2, roundOff) ++ timeZoneId
+ }
+}
- val typeNode = ConverterUtils.getTypeNode(original.dataType,
original.nullable)
- ExpressionBuilder.makeScalarFunction(functionId, expressionNodes, typeNode)
+case class TimestampAddTransformer(
+ substraitExprName: String,
+ unit: String,
+ left: ExpressionTransformer,
+ right: ExpressionTransformer,
+ timeZoneId: String,
+ original: Expression)
+ extends ExpressionTransformer {
+ override def children: Seq[ExpressionTransformer] = {
+ Seq(LiteralTransformer(unit), left, right, LiteralTransformer(timeZoneId))
}
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/DecimalRoundTransformer.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/DecimalRoundTransformer.scala
index 60e64cd95..305d4feb9 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/DecimalRoundTransformer.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/expression/DecimalRoundTransformer.scala
@@ -17,24 +17,20 @@
package org.apache.gluten.expression
import org.apache.gluten.exception.GlutenNotSupportException
-import org.apache.gluten.expression.ConverterUtils.FunctionConfig
-import org.apache.gluten.substrait.expression.{ExpressionBuilder,
ExpressionNode}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{DataType, DecimalType}
-import com.google.common.collect.Lists
-
case class DecimalRoundTransformer(
substraitExprName: String,
child: ExpressionTransformer,
original: Round)
- extends ExpressionTransformer {
+ extends BinaryExpressionTransformer {
val toScale: Int = original.scale.eval(EmptyRow).asInstanceOf[Int]
// Use the same result type for different Spark versions.
- val dataType: DataType = original.child.dataType match {
+ override val dataType: DataType = original.child.dataType match {
case decimalType: DecimalType =>
val p = decimalType.precision
val s = decimalType.scale
@@ -57,21 +53,6 @@ case class DecimalRoundTransformer(
s"Decimal type is expected but received
${original.child.dataType.typeName}.")
}
- override def doTransform(args: Object): ExpressionNode = {
- val functionMap = args.asInstanceOf[java.util.HashMap[String,
java.lang.Long]]
- val functionId = ExpressionBuilder.newScalarFunction(
- functionMap,
- ConverterUtils.makeFuncName(
- substraitExprName,
- Seq(original.child.dataType),
- FunctionConfig.OPT))
-
- ExpressionBuilder.makeScalarFunction(
- functionId,
- Lists.newArrayList[ExpressionNode](
- child.doTransform(args),
- ExpressionBuilder.makeIntLiteral(toScale)),
- ConverterUtils.getTypeNode(dataType, original.nullable)
- )
- }
+ override def left: ExpressionTransformer = child
+ override def right: ExpressionTransformer = LiteralTransformer(toScale)
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
index e22a20e0d..6e1427e2f 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
@@ -25,9 +25,8 @@ import org.apache.gluten.test.TestStats
import org.apache.gluten.utils.DecimalArithmeticUtil
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
+import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext,
ExprCode}
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
import org.apache.spark.sql.execution.{ScalarSubquery, _}
@@ -37,13 +36,8 @@ import org.apache.spark.sql.hive.HiveUDFTransformer
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
-trait Transformable extends Expression {
+trait Transformable extends Unevaluable {
def getTransformer(childrenTransformers: Seq[ExpressionTransformer]):
ExpressionTransformer
-
- override def eval(input: InternalRow): Any = throw new
UnsupportedOperationException()
-
- override protected def doGenCode(ctx: CodegenContext, ev: ExprCode):
ExprCode =
- throw new UnsupportedOperationException()
}
object ExpressionConverter extends SQLConfHelper with Logging {
@@ -172,7 +166,7 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
case c: CreateArray =>
val children =
c.children.map(replaceWithExpressionTransformerInternal(_,
attributeSeq, expressionsMap))
- CreateArrayTransformer(substraitExprName, children,
useStringTypeWhenEmpty = true, c)
+ CreateArrayTransformer(substraitExprName, children, c)
case g: GetArrayItem =>
BackendsApiManager.getSparkPlanExecApiInstance.genGetArrayItemTransformer(
substraitExprName,
@@ -183,7 +177,7 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
case c: CreateMap =>
val children =
c.children.map(replaceWithExpressionTransformerInternal(_,
attributeSeq, expressionsMap))
- CreateMapTransformer(substraitExprName, children,
c.useStringTypeWhenEmpty, c)
+ CreateMapTransformer(substraitExprName, children, c)
case g: GetMapValue =>
BackendsApiManager.getSparkPlanExecApiInstance.genGetMapValueTransformer(
substraitExprName,
@@ -225,14 +219,7 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
val bindReference =
BindReferences.bindReference(expr, attributeSeq, allowFailures =
false)
val b = bindReference.asInstanceOf[BoundReference]
- AttributeReferenceTransformer(
- a.name,
- b.ordinal,
- a.dataType,
- b.nullable,
- a.exprId,
- a.qualifier,
- a.metadata)
+ AttributeReferenceTransformer(substraitExprName, a, b)
} catch {
case e: IllegalStateException =>
// This situation may need developers to fix, although we just
throw the below
@@ -241,11 +228,11 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
s"Failed to bind reference for $expr: ${e.getMessage}")
}
case b: BoundReference =>
- BoundReferenceTransformer(b.ordinal, b.dataType, b.nullable)
+ BoundReferenceTransformer(substraitExprName, b)
case l: Literal =>
LiteralTransformer(l)
case d: DateDiff =>
- DateDiffTransformer(
+ BackendsApiManager.getSparkPlanExecApiInstance.genDateDiffTransformer(
substraitExprName,
replaceWithExpressionTransformerInternal(d.endDate, attributeSeq,
expressionsMap),
replaceWithExpressionTransformerInternal(d.startDate, attributeSeq,
expressionsMap),
@@ -257,17 +244,23 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
replaceWithExpressionTransformerInternal(r.child, attributeSeq,
expressionsMap),
r)
case t: ToUnixTimestamp =>
-
BackendsApiManager.getSparkPlanExecApiInstance.genUnixTimestampTransformer(
+ // The failOnError depends on the config for ANSI. ANSI is not
supported currently.
+ // And timeZoneId is passed to backend config.
+ GenericExpressionTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(t.timeExp, attributeSeq,
expressionsMap),
- replaceWithExpressionTransformerInternal(t.format, attributeSeq,
expressionsMap),
+ Seq(
+ replaceWithExpressionTransformerInternal(t.timeExp, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformerInternal(t.format, attributeSeq,
expressionsMap)
+ ),
t
)
case u: UnixTimestamp =>
-
BackendsApiManager.getSparkPlanExecApiInstance.genUnixTimestampTransformer(
+ GenericExpressionTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(u.timeExp, attributeSeq,
expressionsMap),
- replaceWithExpressionTransformerInternal(u.format, attributeSeq,
expressionsMap),
+ Seq(
+ replaceWithExpressionTransformerInternal(u.timeExp, attributeSeq,
expressionsMap),
+ replaceWithExpressionTransformerInternal(u.format, attributeSeq,
expressionsMap)
+ ),
ToUnixTimestamp(u.timeExp, u.format, u.timeZoneId, u.failOnError)
)
case t: TruncTimestamp =>
@@ -284,11 +277,11 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
replaceWithExpressionTransformerInternal(m.date1, attributeSeq,
expressionsMap),
replaceWithExpressionTransformerInternal(m.date2, attributeSeq,
expressionsMap),
replaceWithExpressionTransformerInternal(m.roundOff, attributeSeq,
expressionsMap),
- m.timeZoneId,
m
)
case i: If =>
IfTransformer(
+ substraitExprName,
replaceWithExpressionTransformerInternal(i.predicate, attributeSeq,
expressionsMap),
replaceWithExpressionTransformerInternal(i.trueValue, attributeSeq,
expressionsMap),
replaceWithExpressionTransformerInternal(i.falseValue, attributeSeq,
expressionsMap),
@@ -296,6 +289,7 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
)
case cw: CaseWhen =>
CaseWhenTransformer(
+ substraitExprName,
cw.branches.map {
expr =>
{
@@ -318,26 +312,23 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
s"In list option does not support non-foldable expression,
${i.list.map(_.sql)}")
}
InTransformer(
+ substraitExprName,
replaceWithExpressionTransformerInternal(i.value, attributeSeq,
expressionsMap),
- i.list,
- i.value.dataType,
i)
case i: InSet =>
InSetTransformer(
+ substraitExprName,
replaceWithExpressionTransformerInternal(i.child, attributeSeq,
expressionsMap),
- i.hset,
- i.child.dataType,
i)
case s: ScalarSubquery =>
- ScalarSubqueryTransformer(s.plan, s.exprId, s)
+ ScalarSubqueryTransformer(substraitExprName, s)
case c: Cast =>
// Add trim node, as necessary.
val newCast =
BackendsApiManager.getSparkPlanExecApiInstance.genCastWithNewChild(c)
CastTransformer(
+ substraitExprName,
replaceWithExpressionTransformerInternal(newCast.child,
attributeSeq, expressionsMap),
- newCast.dataType,
- newCast.timeZoneId,
newCast)
case s: String2TrimExpression =>
val (srcStr, trimStr) = s match {
@@ -345,10 +336,13 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
case StringTrimLeft(srcStr, trimStr) => (srcStr, trimStr)
case StringTrimRight(srcStr, trimStr) => (srcStr, trimStr)
}
- String2TrimExpressionTransformer(
+ val children = trimStr
+ .map(replaceWithExpressionTransformerInternal(_, attributeSeq,
expressionsMap))
+ .toSeq ++
+ Seq(replaceWithExpressionTransformerInternal(srcStr, attributeSeq,
expressionsMap))
+ GenericExpressionTransformer(
substraitExprName,
- trimStr.map(replaceWithExpressionTransformerInternal(_,
attributeSeq, expressionsMap)),
- replaceWithExpressionTransformerInternal(srcStr, attributeSeq,
expressionsMap),
+ children,
s
)
case m: HashExpression[_] =>
@@ -368,15 +362,14 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
getStructField.ordinal,
getStructField)
case getArrayStructFields: GetArrayStructFields =>
- GetArrayStructFieldsTransformer(
+ GenericExpressionTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(
- getArrayStructFields.child,
- attributeSeq,
- expressionsMap),
- getArrayStructFields.ordinal,
- getArrayStructFields.numFields,
- getArrayStructFields.containsNull,
+ Seq(
+ replaceWithExpressionTransformerInternal(
+ getArrayStructFields.child,
+ attributeSeq,
+ expressionsMap),
+ LiteralTransformer(getArrayStructFields.ordinal)),
getArrayStructFields
)
case t: StringTranslate =>
@@ -387,14 +380,6 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
replaceWithExpressionTransformerInternal(t.replaceExpr,
attributeSeq, expressionsMap),
t
)
- case l: StringLocate =>
-
BackendsApiManager.getSparkPlanExecApiInstance.genStringLocateTransformer(
- substraitExprName,
- replaceWithExpressionTransformerInternal(l.first, attributeSeq,
expressionsMap),
- replaceWithExpressionTransformerInternal(l.second, attributeSeq,
expressionsMap),
- replaceWithExpressionTransformerInternal(l.third, attributeSeq,
expressionsMap),
- l
- )
case s: StringSplit =>
BackendsApiManager.getSparkPlanExecApiInstance.genStringSplitTransformer(
substraitExprName,
@@ -414,23 +399,6 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
),
r
)
- case md5: Md5 =>
- BackendsApiManager.getSparkPlanExecApiInstance.genMd5Transformer(
- substraitExprName,
- replaceWithExpressionTransformerInternal(md5.child, attributeSeq,
expressionsMap),
- md5)
- case sha1: Sha1 =>
- BackendsApiManager.getSparkPlanExecApiInstance.genSha1Transformer(
- substraitExprName,
- replaceWithExpressionTransformerInternal(sha1.child, attributeSeq,
expressionsMap),
- sha1)
- case sha2: Sha2 =>
- BackendsApiManager.getSparkPlanExecApiInstance.genSha2Transformer(
- substraitExprName,
- replaceWithExpressionTransformerInternal(sha2.left, attributeSeq,
expressionsMap),
- replaceWithExpressionTransformerInternal(sha2.right, attributeSeq,
expressionsMap),
- sha2
- )
case size: Size =>
if (size.legacySizeOfNull != SQLConf.get.legacySizeOfNull) {
throw new GlutenNotSupportException(
@@ -449,12 +417,11 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
namedStruct,
attributeSeq)
case namedLambdaVariable: NamedLambdaVariable =>
- NamedLambdaVariableTransformer(
+ // namedlambdavariable('acc')-> <Integer, notnull>
+ GenericExpressionTransformer(
substraitExprName,
- name = namedLambdaVariable.name,
- dataType = namedLambdaVariable.dataType,
- nullable = namedLambdaVariable.nullable,
- exprId = namedLambdaVariable.exprId
+ LiteralTransformer(namedLambdaVariable.name),
+ namedLambdaVariable
)
case lambdaFunction: LambdaFunction =>
LambdaFunctionTransformer(
@@ -472,25 +439,33 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
j.children.map(replaceWithExpressionTransformerInternal(_,
attributeSeq, expressionsMap))
JsonTupleExpressionTransformer(substraitExprName, children, j)
case l: Like =>
- LikeTransformer(
+ BackendsApiManager.getSparkPlanExecApiInstance.genLikeTransformer(
substraitExprName,
replaceWithExpressionTransformerInternal(l.left, attributeSeq,
expressionsMap),
replaceWithExpressionTransformerInternal(l.right, attributeSeq,
expressionsMap),
l
)
case m: MakeDecimal =>
- MakeDecimalTransformer(
+ GenericExpressionTransformer(
substraitExprName,
- replaceWithExpressionTransformerInternal(m.child, attributeSeq,
expressionsMap),
- m)
+ Seq(
+ replaceWithExpressionTransformerInternal(m.child, attributeSeq,
expressionsMap),
+ LiteralTransformer(m.nullOnOverflow)),
+ m
+ )
case rand: Rand =>
BackendsApiManager.getSparkPlanExecApiInstance.genRandTransformer(
substraitExprName,
replaceWithExpressionTransformerInternal(rand.child, attributeSeq,
expressionsMap),
rand)
- case _: NormalizeNaNAndZero | _: PromotePrecision =>
+ case _: NormalizeNaNAndZero | _: PromotePrecision | _: TaggingExpression
=>
ChildTransformer(
- replaceWithExpressionTransformerInternal(expr.children.head,
attributeSeq, expressionsMap)
+ substraitExprName,
+ replaceWithExpressionTransformerInternal(
+ expr.children.head,
+ attributeSeq,
+ expressionsMap),
+ expr
)
case _: GetDateField | _: GetTimeField =>
ExtractDateTransformer(
@@ -524,7 +499,6 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
CheckOverflowTransformer(
substraitExprName,
replaceWithExpressionTransformerInternal(c.child, attributeSeq,
expressionsMap),
- c.child.dataType,
c)
case b: BinaryArithmetic if DecimalArithmeticUtil.isDecimalArithmetic(b)
=>
DecimalArithmeticUtil.checkAllowDecimalArithmetic()
@@ -565,12 +539,7 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
replaceWithExpressionTransformerInternal(add.left, attributeSeq,
expressionsMap),
replaceWithExpressionTransformerInternal(add.right, attributeSeq,
expressionsMap),
extract.get.last,
- add.dataType,
- add.nullable
- )
- case e: TaggingExpression =>
- ChildTransformer(
- replaceWithExpressionTransformerInternal(e.child, attributeSeq,
expressionsMap)
+ add
)
case e: Transformable =>
val childrenTransformers =
@@ -614,7 +583,7 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
)
case tryEval: TryEval =>
// This is a placeholder to handle try_eval(other expressions).
- BackendsApiManager.getSparkPlanExecApiInstance.genTryAddTransformer(
+ BackendsApiManager.getSparkPlanExecApiInstance.genTryEvalTransformer(
substraitExprName,
replaceWithExpressionTransformerInternal(tryEval.child,
attributeSeq, expressionsMap),
tryEval
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala
index 6b6587862..ebb9db3e8 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala
@@ -16,17 +16,74 @@
*/
package org.apache.gluten.expression
-import org.apache.gluten.substrait.expression.ExpressionNode
+import org.apache.gluten.substrait.expression.{ExpressionBuilder,
ExpressionNode}
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
import org.apache.spark.sql.types.DataType
-trait ExpressionTransformer {
- def doTransform(args: java.lang.Object): ExpressionNode
- def dataType: DataType
-}
+import scala.collection.JavaConverters._
+
+// ==== Expression transformer basic interface start ====
-trait ExpressionTransformerWithOrigin extends ExpressionTransformer {
+trait ExpressionTransformer {
+ def substraitExprName: String
+ def children: Seq[ExpressionTransformer]
def original: Expression
def dataType: DataType = original.dataType
+ def nullable: Boolean = original.nullable
+
+ def doTransform(args: java.lang.Object): ExpressionNode = {
+ val functionMap = args.asInstanceOf[java.util.HashMap[String,
java.lang.Long]]
+ // TODO: the funcName seems can be simplified to `substraitExprName`
+ val funcName: String =
+ ConverterUtils.makeFuncName(substraitExprName,
original.children.map(_.dataType))
+ val functionId = ExpressionBuilder.newScalarFunction(functionMap, funcName)
+ val childNodes = children.map(_.doTransform(args)).asJava
+ val typeNode = ConverterUtils.getTypeNode(dataType, nullable)
+ ExpressionBuilder.makeScalarFunction(functionId, childNodes, typeNode)
+ }
+}
+
+trait LeafExpressionTransformer extends ExpressionTransformer {
+ final override def children: Seq[ExpressionTransformer] = Nil
+}
+
+trait UnaryExpressionTransformer extends ExpressionTransformer {
+ def child: ExpressionTransformer
+ final override def children: Seq[ExpressionTransformer] = child :: Nil
+}
+
+trait BinaryExpressionTransformer extends ExpressionTransformer {
+ def left: ExpressionTransformer
+ def right: ExpressionTransformer
+ final override def children: Seq[ExpressionTransformer] = left :: right ::
Nil
+}
+
+// ==== Expression transformer basic interface end ====
+
+case class GenericExpressionTransformer(
+ substraitExprName: String,
+ children: Seq[ExpressionTransformer],
+ original: Expression)
+ extends ExpressionTransformer
+
+object GenericExpressionTransformer {
+ def apply(
+ substraitExprName: String,
+ child: ExpressionTransformer,
+ original: Expression): GenericExpressionTransformer = {
+ GenericExpressionTransformer(substraitExprName, child :: Nil, original)
+ }
+}
+
+case class LiteralTransformer(original: Literal) extends
LeafExpressionTransformer {
+ override def substraitExprName: String = "literal"
+ override def doTransform(args: java.lang.Object): ExpressionNode = {
+ ExpressionBuilder.makeLiteral(original.value, original.dataType,
original.nullable)
+ }
+}
+object LiteralTransformer {
+ def apply(v: Any): LiteralTransformer = {
+ LiteralTransformer(Literal(v))
+ }
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/GenericExpressionTransformer.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/GenericExpressionTransformer.scala
deleted file mode 100644
index 8faf4965f..000000000
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/GenericExpressionTransformer.scala
+++ /dev/null
@@ -1,46 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.gluten.expression
-
-import org.apache.gluten.expression.ConverterUtils.FunctionConfig
-import org.apache.gluten.substrait.expression.{ExpressionBuilder,
ExpressionNode}
-
-import org.apache.spark.sql.catalyst.expressions._
-
-import com.google.common.collect.Lists
-
-case class GenericExpressionTransformer(
- substraitExprName: String,
- children: Seq[ExpressionTransformer],
- original: Expression)
- extends ExpressionTransformerWithOrigin {
- override def doTransform(args: Object): ExpressionNode = {
- val functionMap = args.asInstanceOf[java.util.HashMap[String,
java.lang.Long]]
- val functionId = ExpressionBuilder.newScalarFunction(
- functionMap,
- ConverterUtils.makeFuncName(
- substraitExprName,
- original.children.map(_.dataType),
- FunctionConfig.OPT))
-
- val exprNodes = Lists.newArrayList[ExpressionNode]()
- children.foreach(expr => exprNodes.add(expr.doTransform(args)))
-
- val typeNode = ConverterUtils.getTypeNode(original.dataType,
original.nullable)
- ExpressionBuilder.makeScalarFunction(functionId, exprNodes, typeNode)
- }
-}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/HashExpressionTransformer.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/HashExpressionTransformer.scala
deleted file mode 100644
index 28f2dda01..000000000
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/HashExpressionTransformer.scala
+++ /dev/null
@@ -1,44 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.gluten.expression
-
-import org.apache.gluten.expression.ConverterUtils.FunctionConfig
-import org.apache.gluten.substrait.expression.{ExpressionBuilder,
ExpressionNode}
-
-import org.apache.spark.sql.catalyst.expressions._
-
-case class HashExpressionTransformer(
- substraitExprName: String,
- exps: Seq[ExpressionTransformer],
- original: Expression)
- extends ExpressionTransformerWithOrigin {
-
- override def doTransform(args: java.lang.Object): ExpressionNode = {
- val nodes = new java.util.ArrayList[ExpressionNode]()
- exps.foreach(
- expression => {
- nodes.add(expression.doTransform(args))
- })
- val childrenTypes = original.children.map(child => child.dataType)
- val functionMap = args.asInstanceOf[java.util.HashMap[String,
java.lang.Long]]
- val functionName =
- ConverterUtils.makeFuncName(substraitExprName, childrenTypes,
FunctionConfig.OPT)
- val functionId = ExpressionBuilder.newScalarFunction(functionMap,
functionName)
- val typeNode = ConverterUtils.getTypeNode(original.dataType,
original.nullable)
- ExpressionBuilder.makeScalarFunction(functionId, nodes, typeNode)
- }
-}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/JsonTupleExpressionTransformer.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/JsonTupleExpressionTransformer.scala
index e8ff3d360..25e3e12a5 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/JsonTupleExpressionTransformer.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/expression/JsonTupleExpressionTransformer.scala
@@ -28,7 +28,7 @@ case class JsonTupleExpressionTransformer(
substraitExprName: String,
children: Seq[ExpressionTransformer],
original: Expression)
- extends ExpressionTransformerWithOrigin {
+ extends ExpressionTransformer {
override def doTransform(args: Object): ExpressionNode = {
val jsonExpr = children.head
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/LambdaFunctionTransformer.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/LambdaFunctionTransformer.scala
index ce6d13a95..9e7285ac3 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/LambdaFunctionTransformer.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/expression/LambdaFunctionTransformer.scala
@@ -17,7 +17,7 @@
package org.apache.gluten.expression
import org.apache.gluten.exception.GlutenNotSupportException
-import org.apache.gluten.substrait.expression.{ExpressionBuilder,
ExpressionNode}
+import org.apache.gluten.substrait.expression.ExpressionNode
import org.apache.spark.sql.catalyst.expressions.LambdaFunction
@@ -25,27 +25,15 @@ case class LambdaFunctionTransformer(
substraitExprName: String,
function: ExpressionTransformer,
arguments: Seq[ExpressionTransformer],
- hidden: Boolean = false,
original: LambdaFunction)
- extends ExpressionTransformerWithOrigin {
+ extends ExpressionTransformer {
+ override def children: Seq[ExpressionTransformer] = function +: arguments
override def doTransform(args: Object): ExpressionNode = {
// Need to fallback when hidden be true as it's not supported in Velox
- if (hidden) {
+ if (original.hidden) {
throw new GlutenNotSupportException(s"Unsupported LambdaFunction with
hidden be true.")
}
- val functionMap = args.asInstanceOf[java.util.HashMap[String,
java.lang.Long]]
- val functionId = ExpressionBuilder.newScalarFunction(
- functionMap,
- ConverterUtils.makeFuncName(
- substraitExprName,
- Seq(original.dataType),
- ConverterUtils.FunctionConfig.OPT))
- val expressionNodes = new java.util.ArrayList[ExpressionNode]
- expressionNodes.add(function.doTransform(args))
- arguments.foreach(argument =>
expressionNodes.add(argument.doTransform(args)))
- val typeNode = ConverterUtils.getTypeNode(original.dataType,
original.nullable)
- ExpressionBuilder.makeScalarFunction(functionId, expressionNodes, typeNode)
+ super.doTransform(args)
}
-
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/LiteralTransformer.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/LiteralTransformer.scala
deleted file mode 100644
index 8fb9943d6..000000000
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/LiteralTransformer.scala
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.gluten.expression
-
-import org.apache.gluten.substrait.expression.{ExpressionBuilder,
ExpressionNode}
-
-import org.apache.spark.sql.catalyst.expressions._
-
-case class LiteralTransformer(original: Literal) extends
ExpressionTransformerWithOrigin {
-
- override def doTransform(args: java.lang.Object): ExpressionNode = {
- ExpressionBuilder.makeLiteral(original.value, original.dataType,
original.nullable)
- }
-}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/MapExpressionTransformer.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/MapExpressionTransformer.scala
index c09afaebc..fe715979b 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/MapExpressionTransformer.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/expression/MapExpressionTransformer.scala
@@ -18,53 +18,35 @@ package org.apache.gluten.expression
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.exception.GlutenNotSupportException
-import org.apache.gluten.expression.ConverterUtils.FunctionConfig
-import org.apache.gluten.substrait.expression.{ExpressionBuilder,
ExpressionNode}
+import org.apache.gluten.substrait.expression.ExpressionNode
import org.apache.spark.sql.catalyst.expressions._
-import com.google.common.collect.Lists
-
case class CreateMapTransformer(
substraitExprName: String,
children: Seq[ExpressionTransformer],
- useStringTypeWhenEmpty: Boolean,
original: CreateMap)
- extends ExpressionTransformerWithOrigin {
+ extends ExpressionTransformer {
override def doTransform(args: java.lang.Object): ExpressionNode = {
// If children is empty,
// transformation is only supported when useStringTypeWhenEmpty is false
// because ClickHouse and Velox currently doesn't support this config.
- if (children.isEmpty && useStringTypeWhenEmpty) {
+ if (children.isEmpty && original.useStringTypeWhenEmpty) {
throw new GlutenNotSupportException(s"$original not supported yet.")
}
- val childNodes = new java.util.ArrayList[ExpressionNode]()
- children.foreach(
- child => {
- val childNode = child.doTransform(args)
- childNodes.add(childNode)
- })
-
- val functionMap = args.asInstanceOf[java.util.HashMap[String,
java.lang.Long]]
- val functionName = ConverterUtils.makeFuncName(
- substraitExprName,
- original.children.map(_.dataType),
- FunctionConfig.OPT)
- val functionId = ExpressionBuilder.newScalarFunction(functionMap,
functionName)
- val typeNode = ConverterUtils.getTypeNode(original.dataType,
original.nullable)
- ExpressionBuilder.makeScalarFunction(functionId, childNodes, typeNode)
+ super.doTransform(args)
}
}
case class GetMapValueTransformer(
substraitExprName: String,
- child: ExpressionTransformer,
- key: ExpressionTransformer,
+ left: ExpressionTransformer,
+ right: ExpressionTransformer,
failOnError: Boolean,
original: GetMapValue)
- extends ExpressionTransformerWithOrigin {
+ extends BinaryExpressionTransformer {
override def doTransform(args: java.lang.Object): ExpressionNode = {
if (BackendsApiManager.getSettings.alwaysFailOnMapExpression()) {
@@ -75,17 +57,6 @@ case class GetMapValueTransformer(
throw new GlutenNotSupportException(s"$original not supported yet.")
}
- val childNode = child.doTransform(args)
- val keyNode = key.doTransform(args)
-
- val functionMap = args.asInstanceOf[java.util.HashMap[String,
java.lang.Long]]
- val functionName = ConverterUtils.makeFuncName(
- substraitExprName,
- Seq(original.child.dataType, original.key.dataType),
- FunctionConfig.OPT)
- val functionId = ExpressionBuilder.newScalarFunction(functionMap,
functionName)
- val exprNodes = Lists.newArrayList(childNode, keyNode)
- val typeNode = ConverterUtils.getTypeNode(original.dataType,
original.nullable)
- ExpressionBuilder.makeScalarFunction(functionId, exprNodes, typeNode)
+ super.doTransform(args)
}
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/NamedExpressionsTransformer.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/NamedExpressionsTransformer.scala
index 2af4a5fa2..f4c703d88 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/NamedExpressionsTransformer.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/expression/NamedExpressionsTransformer.scala
@@ -16,67 +16,29 @@
*/
package org.apache.gluten.expression
-import org.apache.gluten.expression.ConverterUtils.FunctionConfig
import org.apache.gluten.substrait.expression.{ExpressionBuilder,
ExpressionNode}
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types._
-
-import com.google.common.collect.Lists
case class AliasTransformer(
substraitExprName: String,
child: ExpressionTransformer,
original: Expression)
- extends ExpressionTransformerWithOrigin {
-
- override def doTransform(args: java.lang.Object): ExpressionNode = {
- val childNode = child.doTransform(args)
- val functionMap = args.asInstanceOf[java.util.HashMap[String,
java.lang.Long]]
- val functionId = ExpressionBuilder.newScalarFunction(
- functionMap,
- ConverterUtils.makeFuncName(
- substraitExprName,
- original.children.map(_.dataType),
- FunctionConfig.REQ))
- val expressionNodes = Lists.newArrayList(childNode)
- val typeNode = ConverterUtils.getTypeNode(original.dataType,
original.nullable)
- ExpressionBuilder.makeScalarFunction(functionId, expressionNodes, typeNode)
- }
-}
+ extends UnaryExpressionTransformer {}
-case class NamedLambdaVariableTransformer(
+case class AttributeReferenceTransformer(
substraitExprName: String,
- name: String,
- dataType: DataType,
- nullable: Boolean,
- exprId: ExprId)
- extends ExpressionTransformer {
- override def doTransform(args: Object): ExpressionNode = {
- val functionMap = args.asInstanceOf[java.util.HashMap[String,
java.lang.Long]]
- val namedLambdaVarFunctionName =
- ConverterUtils.makeFuncName(substraitExprName, Seq(dataType),
FunctionConfig.OPT)
- val arrayAggFunctionId =
- ExpressionBuilder.newScalarFunction(functionMap,
namedLambdaVarFunctionName)
- val exprNodes = Lists.newArrayList(
- ExpressionBuilder.makeLiteral(name, StringType,
false).asInstanceOf[ExpressionNode])
- val typeNode = ConverterUtils.getTypeNode(dataType, nullable)
- // namedlambdavariable('acc')-> <Integer, notnull>
- ExpressionBuilder.makeScalarFunction(arrayAggFunctionId, exprNodes,
typeNode)
+ original: AttributeReference,
+ bound: BoundReference)
+ extends LeafExpressionTransformer {
+ override def doTransform(args: java.lang.Object): ExpressionNode = {
+
ExpressionBuilder.makeSelection(bound.ordinal.asInstanceOf[java.lang.Integer])
}
}
-case class AttributeReferenceTransformer(
- name: String,
- ordinal: Int,
- dataType: DataType,
- nullable: Boolean = true,
- exprId: ExprId,
- qualifier: Seq[String],
- metadata: Metadata = Metadata.empty)
- extends ExpressionTransformer {
-
+case class BoundReferenceTransformer(substraitExprName: String, original:
BoundReference)
+ extends LeafExpressionTransformer {
override def doTransform(args: java.lang.Object): ExpressionNode = {
- ExpressionBuilder.makeSelection(ordinal.asInstanceOf[java.lang.Integer])
+
ExpressionBuilder.makeSelection(original.ordinal.asInstanceOf[java.lang.Integer])
}
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/PredicateExpressionTransformer.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/PredicateExpressionTransformer.scala
index 7d34466e5..d13c61d64 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/PredicateExpressionTransformer.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/expression/PredicateExpressionTransformer.scala
@@ -16,39 +16,33 @@
*/
package org.apache.gluten.expression
-import org.apache.gluten.backendsapi.BackendsApiManager
-import org.apache.gluten.expression.ConverterUtils.FunctionConfig
import org.apache.gluten.substrait.expression.{ExpressionBuilder,
ExpressionNode}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
-import com.google.common.collect.Lists
-
import scala.collection.JavaConverters._
-case class InTransformer(
- value: ExpressionTransformer,
- list: Seq[Expression],
- valueType: DataType,
- original: Expression)
- extends ExpressionTransformerWithOrigin {
+case class InTransformer(substraitExprName: String, child:
ExpressionTransformer, original: In)
+ extends UnaryExpressionTransformer {
override def doTransform(args: java.lang.Object): ExpressionNode = {
- assert(list.forall(_.foldable))
+ assert(original.list.forall(_.foldable))
// Stores the values in a List Literal.
- val values: Set[Any] = list.map(_.eval()).toSet
- InExpressionTransformer.toTransformer(value.doTransform(args), values,
valueType)
+ val values: Set[Any] = original.list.map(_.eval()).toSet
+ InExpressionTransformer.toTransformer(child.doTransform(args), values,
child.dataType)
}
}
case class InSetTransformer(
- value: ExpressionTransformer,
- hset: Set[Any],
- valueType: DataType,
- original: Expression)
- extends ExpressionTransformerWithOrigin {
+ substraitExprName: String,
+ child: ExpressionTransformer,
+ original: InSet)
+ extends UnaryExpressionTransformer {
override def doTransform(args: java.lang.Object): ExpressionNode = {
- InExpressionTransformer.toTransformer(value.doTransform(args), hset,
valueType)
+ InExpressionTransformer.toTransformer(
+ child.doTransform(args),
+ original.hset,
+ original.child.dataType)
}
}
@@ -69,60 +63,12 @@ object InExpressionTransformer {
}
}
-case class LikeTransformer(
- substraitExprName: String,
- left: ExpressionTransformer,
- right: ExpressionTransformer,
- original: Expression)
- extends ExpressionTransformerWithOrigin {
- override def doTransform(args: java.lang.Object): ExpressionNode = {
- val leftNode = left.doTransform(args)
- val rightNode = right.doTransform(args)
- val escapeCharNode = ExpressionBuilder.makeLiteral(
- original.asInstanceOf[Like].escapeChar.toString,
- StringType,
- false)
-
- val functionMap = args.asInstanceOf[java.util.HashMap[String,
java.lang.Long]]
- val functionId = ExpressionBuilder.newScalarFunction(
- functionMap,
- ConverterUtils.makeFuncName(
- substraitExprName,
- original.children.map(_.dataType),
- FunctionConfig.OPT))
-
- // CH backend does not support escapeChar, so skip it here.
- val expressionNodes =
- BackendsApiManager.getTransformerApiInstance.createLikeParamList(
- leftNode,
- rightNode,
- escapeCharNode)
- val typeNode = ConverterUtils.getTypeNode(original.dataType,
original.nullable)
- ExpressionBuilder.makeScalarFunction(functionId,
expressionNodes.toList.asJava, typeNode)
- }
-}
-
case class DecimalArithmeticExpressionTransformer(
substraitExprName: String,
left: ExpressionTransformer,
right: ExpressionTransformer,
resultType: DecimalType,
original: Expression)
- extends ExpressionTransformerWithOrigin {
+ extends BinaryExpressionTransformer {
override def dataType: DataType = resultType
- override def doTransform(args: java.lang.Object): ExpressionNode = {
- val leftNode = left.doTransform(args)
- val rightNode = right.doTransform(args)
- val functionMap = args.asInstanceOf[java.util.HashMap[String,
java.lang.Long]]
- val functionId = ExpressionBuilder.newScalarFunction(
- functionMap,
- ConverterUtils.makeFuncName(
- substraitExprName,
- original.children.map(_.dataType),
- FunctionConfig.OPT))
-
- val expressionNodes = Lists.newArrayList(leftNode, rightNode)
- val typeNode = ConverterUtils.getTypeNode(resultType, original.nullable)
- ExpressionBuilder.makeScalarFunction(functionId, expressionNodes, typeNode)
- }
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/ScalarSubqueryTransformer.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/ScalarSubqueryTransformer.scala
index 4f5a43d47..0accf9ffd 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/ScalarSubqueryTransformer.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/expression/ScalarSubqueryTransformer.scala
@@ -19,10 +19,10 @@ package org.apache.gluten.expression
import org.apache.gluten.substrait.expression.{ExpressionBuilder,
ExpressionNode}
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.{BaseSubqueryExec, ScalarSubquery}
+import org.apache.spark.sql.execution.ScalarSubquery
-case class ScalarSubqueryTransformer(plan: BaseSubqueryExec, exprId: ExprId,
query: ScalarSubquery)
- extends ExpressionTransformerWithOrigin {
+case class ScalarSubqueryTransformer(substraitExprName: String, query:
ScalarSubquery)
+ extends LeafExpressionTransformer {
override def original: Expression = query
override def doTransform(args: java.lang.Object): ExpressionNode = {
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/StringExpressionTransformer.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/StringExpressionTransformer.scala
deleted file mode 100644
index b31d66b68..000000000
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/StringExpressionTransformer.scala
+++ /dev/null
@@ -1,49 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.gluten.expression
-
-import org.apache.gluten.expression.ConverterUtils.FunctionConfig
-import org.apache.gluten.substrait.expression._
-
-import org.apache.spark.sql.catalyst.expressions._
-
-import com.google.common.collect.Lists
-
-case class String2TrimExpressionTransformer(
- substraitExprName: String,
- trimStr: Option[ExpressionTransformer],
- srcStr: ExpressionTransformer,
- original: Expression)
- extends ExpressionTransformerWithOrigin {
-
- override def doTransform(args: java.lang.Object): ExpressionNode = {
- val trimStrNode = trimStr.map(_.doTransform(args))
- val srcStrNode = srcStr.doTransform(args)
- val functionMap = args.asInstanceOf[java.util.HashMap[String,
java.lang.Long]]
- val functionName =
- ConverterUtils.makeFuncName(
- substraitExprName,
- original.children.map(_.dataType),
- FunctionConfig.REQ)
- val functionId = ExpressionBuilder.newScalarFunction(functionMap,
functionName)
- val expressNodes = Lists.newArrayList[ExpressionNode]()
- trimStrNode.foreach(expressNodes.add)
- expressNodes.add(srcStrNode)
- val typeNode = ConverterUtils.getTypeNode(original.dataType,
original.nullable)
- ExpressionBuilder.makeScalarFunction(functionId, expressNodes, typeNode)
- }
-}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/StructExpressionTransformer.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/StructExpressionTransformer.scala
deleted file mode 100644
index 616971b6d..000000000
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/StructExpressionTransformer.scala
+++ /dev/null
@@ -1,54 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.gluten.expression
-
-import org.apache.gluten.expression.ConverterUtils.FunctionConfig
-import org.apache.gluten.substrait.expression.{ExpressionBuilder,
ExpressionNode, StructLiteralNode}
-
-import org.apache.spark.sql.catalyst.expressions.GetStructField
-import org.apache.spark.sql.types.IntegerType
-
-import com.google.common.collect.Lists
-
-case class GetStructFieldTransformer(
- substraitExprName: String,
- childTransformer: ExpressionTransformer,
- ordinal: Int,
- original: GetStructField)
- extends ExpressionTransformerWithOrigin {
-
- override def doTransform(args: java.lang.Object): ExpressionNode = {
- val childNode = childTransformer.doTransform(args)
- childNode match {
- case node: StructLiteralNode =>
- return node.getFieldLiteral(ordinal)
- case _ =>
- }
-
- val ordinalNode = ExpressionBuilder.makeLiteral(ordinal, IntegerType,
false)
- val exprNodes = Lists.newArrayList(childNode, ordinalNode)
- val functionMap = args.asInstanceOf[java.util.HashMap[String,
java.lang.Long]]
- val fieldDataType = original.dataType
- val functionName = ConverterUtils.makeFuncName(
- substraitExprName,
- Seq(original.child.dataType, fieldDataType),
- FunctionConfig.OPT)
- val functionId = ExpressionBuilder.newScalarFunction(functionMap,
functionName)
- val typeNode = ConverterUtils.getTypeNode(original.dataType,
original.nullable)
- ExpressionBuilder.makeScalarFunction(functionId, exprNodes, typeNode)
- }
-}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/TimestampAddTransformer.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/TimestampAddTransformer.scala
deleted file mode 100644
index acede4523..000000000
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/TimestampAddTransformer.scala
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.gluten.expression
-
-import org.apache.gluten.expression.ConverterUtils.FunctionConfig
-import org.apache.gluten.substrait.expression.{ExpressionBuilder,
ExpressionNode}
-
-import org.apache.spark.sql.types.DataType
-
-import com.google.common.collect.Lists
-
-case class TimestampAddTransformer(
- substraitExprName: String,
- unit: String,
- left: ExpressionTransformer,
- right: ExpressionTransformer,
- timeZoneId: String,
- dataType: DataType,
- nullable: Boolean)
- extends ExpressionTransformer {
-
- override def doTransform(args: java.lang.Object): ExpressionNode = {
- val leftNode = left.doTransform(args)
- val rightNode = right.doTransform(args)
- val functionMap = args.asInstanceOf[java.util.HashMap[String,
java.lang.Long]]
- val functionId = ExpressionBuilder.newScalarFunction(
- functionMap,
- ConverterUtils.makeFuncName(substraitExprName, Seq(), FunctionConfig.REQ)
- )
-
- val expressionNodes = Lists.newArrayList(
- ExpressionBuilder.makeStringLiteral(unit),
- leftNode,
- rightNode,
- ExpressionBuilder.makeStringLiteral(timeZoneId))
- val outputType = ConverterUtils.getTypeNode(dataType, nullable)
- ExpressionBuilder.makeScalarFunction(functionId, expressionNodes,
outputType)
- }
-}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/UnaryExpressionTransformer.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/UnaryExpressionTransformer.scala
index d0ac19b4a..27f839525 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/UnaryExpressionTransformer.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/expression/UnaryExpressionTransformer.scala
@@ -18,30 +18,29 @@ package org.apache.gluten.expression
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.exception.GlutenNotSupportException
-import org.apache.gluten.expression.ConverterUtils.FunctionConfig
import org.apache.gluten.substrait.`type`.ListNode
import org.apache.gluten.substrait.`type`.MapNode
-import org.apache.gluten.substrait.expression.{BooleanLiteralNode,
ExpressionBuilder, ExpressionNode, IntLiteralNode}
+import org.apache.gluten.substrait.expression.{ExpressionBuilder,
ExpressionNode, StructLiteralNode}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
import com.google.common.collect.Lists
-case class ChildTransformer(child: ExpressionTransformer) extends
ExpressionTransformer {
+case class ChildTransformer(
+ substraitExprName: String,
+ child: ExpressionTransformer,
+ original: Expression)
+ extends UnaryExpressionTransformer {
+ override def dataType: DataType = child.dataType
+
override def doTransform(args: java.lang.Object): ExpressionNode = {
child.doTransform(args)
}
- override def dataType: DataType = child.dataType
}
-case class CastTransformer(
- child: ExpressionTransformer,
- dataType: DataType,
- timeZoneId: Option[String],
- original: Cast)
- extends ExpressionTransformer {
-
+case class CastTransformer(substraitExprName: String, child:
ExpressionTransformer, original: Cast)
+ extends UnaryExpressionTransformer {
override def doTransform(args: java.lang.Object): ExpressionNode = {
val typeNode = ConverterUtils.getTypeNode(dataType, original.nullable)
ExpressionBuilder.makeCast(typeNode, child.doTransform(args),
original.ansiEnabled)
@@ -52,7 +51,7 @@ case class ExplodeTransformer(
substraitExprName: String,
child: ExpressionTransformer,
original: Explode)
- extends ExpressionTransformerWithOrigin {
+ extends UnaryExpressionTransformer {
override def doTransform(args: java.lang.Object): ExpressionNode = {
val childNode: ExpressionNode = child.doTransform(args)
@@ -75,123 +74,23 @@ case class ExplodeTransformer(
}
}
-case class PosExplodeTransformer(
- substraitExprName: String,
- child: ExpressionTransformer,
- original: PosExplode,
- attributeSeq: Seq[Attribute])
- extends ExpressionTransformerWithOrigin {
-
- override def doTransform(args: java.lang.Object): ExpressionNode = {
- val childNode: ExpressionNode = child.doTransform(args)
-
- // sequence(1, size(array_or_map))
- val startExpr = new Literal(1, IntegerType)
- val stopExpr = new Size(Size(original.child, false))
- val stepExpr = new Literal(1, IntegerType)
- val sequenceExpr = new Sequence(startExpr, stopExpr, stepExpr)
- val sequenceExprNode = ExpressionConverter
- .replaceWithExpressionTransformer(sequenceExpr, attributeSeq)
- .doTransform(args)
-
- val funcMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]]
-
- val mapFromArraysFuncId = ExpressionBuilder.newScalarFunction(
- funcMap,
- ConverterUtils.makeFuncName(
- ExpressionNames.MAP_FROM_ARRAYS,
- Seq(sequenceExpr.dataType, original.child.dataType),
- FunctionConfig.OPT))
-
- val keyType = IntegerType
- val (valType, valContainsNull) = original.child.dataType match {
- case a: ArrayType => (a.elementType, a.containsNull)
- case _ =>
- throw new GlutenNotSupportException(
- s"posexplode(${original.child.dataType}) not supported yet.")
- }
- val outputType = MapType(keyType, valType, valContainsNull)
- val mapFromArraysExprNode = ExpressionBuilder.makeScalarFunction(
- mapFromArraysFuncId,
- Lists.newArrayList(sequenceExprNode, childNode),
- ConverterUtils.getTypeNode(outputType, original.child.nullable))
-
- // posexplode(map_from_arrays(sequence(1, size(array_or_map)),
array_or_map))
- val funcId = ExpressionBuilder.newScalarFunction(
- funcMap,
- ConverterUtils.makeFuncName(ExpressionNames.POSEXPLODE, Seq(outputType),
FunctionConfig.OPT))
-
- val childType = original.child.dataType
- childType match {
- case a: ArrayType =>
- // Output pos, col when input is array
- val structType = StructType(
- Array(
- StructField("pos", IntegerType, false),
- StructField("col", a.elementType, a.containsNull)))
- ExpressionBuilder.makeScalarFunction(
- funcId,
- Lists.newArrayList(mapFromArraysExprNode),
- ConverterUtils.getTypeNode(structType, false))
- case m: MapType =>
- // Output pos, key, value when input is map
- val structType = StructType(
- Array(
- StructField("pos", IntegerType, false),
- StructField("key", m.keyType, false),
- StructField("value", m.valueType, m.valueContainsNull)))
- ExpressionBuilder.makeScalarFunction(
- funcId,
- Lists.newArrayList(mapFromArraysExprNode),
- ConverterUtils.getTypeNode(structType, false))
- case _ =>
- throw new GlutenNotSupportException(s"posexplode($childType) not
supported yet.")
- }
- }
-}
-
case class CheckOverflowTransformer(
substraitExprName: String,
child: ExpressionTransformer,
- childResultType: DataType,
original: CheckOverflow)
- extends ExpressionTransformerWithOrigin {
-
+ extends UnaryExpressionTransformer {
override def doTransform(args: java.lang.Object): ExpressionNode = {
BackendsApiManager.getTransformerApiInstance.createCheckOverflowExprNode(
args,
substraitExprName,
child.doTransform(args),
- childResultType,
+ original.child.dataType,
original.dataType,
original.nullable,
original.nullOnOverflow)
}
}
-case class MakeDecimalTransformer(
- substraitExprName: String,
- child: ExpressionTransformer,
- original: MakeDecimal)
- extends ExpressionTransformerWithOrigin {
-
- override def doTransform(args: java.lang.Object): ExpressionNode = {
- val childNode = child.doTransform(args)
- val functionMap = args.asInstanceOf[java.util.HashMap[String,
java.lang.Long]]
- val functionId = ExpressionBuilder.newScalarFunction(
- functionMap,
- ConverterUtils.makeFuncName(
- substraitExprName,
- Seq(original.dataType, BooleanType),
- FunctionConfig.OPT))
-
- val expressionNodes =
- Lists.newArrayList(childNode, new
BooleanLiteralNode(original.nullOnOverflow))
- val typeNode = ConverterUtils.getTypeNode(original.dataType,
original.nullable)
- ExpressionBuilder.makeScalarFunction(functionId, expressionNodes, typeNode)
- }
-}
-
/**
* User can specify a seed for this function. If lacked, spark will generate a
random number as
* seed. We also need to pass a unique partitionIndex provided by framework to
native library for
@@ -203,43 +102,32 @@ case class RandTransformer(
substraitExprName: String,
explicitSeed: ExpressionTransformer,
original: Rand)
- extends ExpressionTransformerWithOrigin {
+ extends LeafExpressionTransformer {
override def doTransform(args: java.lang.Object): ExpressionNode = {
if (!original.hideSeed) {
// TODO: for user-specified seed, we need to pass partition index to
native engine.
throw new GlutenNotSupportException("User-specified seed is not
supported.")
}
- val functionMap = args.asInstanceOf[java.util.HashMap[String,
java.lang.Long]]
- val functionId = ExpressionBuilder.newScalarFunction(
- functionMap,
- ConverterUtils.makeFuncName(substraitExprName,
Seq(original.child.dataType)))
- val inputNodes = Lists.newArrayList[ExpressionNode]()
- val typeNode = ConverterUtils.getTypeNode(original.dataType,
original.nullable)
- ExpressionBuilder.makeScalarFunction(functionId, inputNodes, typeNode)
+ super.doTransform(args)
}
}
-case class GetArrayStructFieldsTransformer(
+case class GetStructFieldTransformer(
substraitExprName: String,
child: ExpressionTransformer,
- ordinal: Int,
- numFields: Int,
- containsNull: Boolean,
- original: GetArrayStructFields)
- extends ExpressionTransformerWithOrigin {
+ original: GetStructField)
+ extends BinaryExpressionTransformer {
+ override def left: ExpressionTransformer = child
+ override def right: ExpressionTransformer =
LiteralTransformer(original.ordinal)
override def doTransform(args: java.lang.Object): ExpressionNode = {
- val functionMap = args.asInstanceOf[java.util.HashMap[String,
java.lang.Long]]
- val functionId = ExpressionBuilder.newScalarFunction(
- functionMap,
- ConverterUtils.makeFuncName(
- substraitExprName,
- Seq(original.child.dataType, IntegerType),
- FunctionConfig.OPT))
- val inputNodes =
- Lists.newArrayList(child.doTransform(args), new IntLiteralNode(ordinal))
- val typeNode = ConverterUtils.getTypeNode(original.dataType,
original.nullable)
- ExpressionBuilder.makeScalarFunction(functionId, inputNodes, typeNode)
+ val childNode = child.doTransform(args)
+ childNode match {
+ case node: StructLiteralNode =>
+ node.getFieldLiteral(original.ordinal)
+ case _ =>
+ super.doTransform(args)
+ }
}
}
diff --git
a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/CustomerExpressionTransformer.scala
b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/CustomerExpressionTransformer.scala
index c27159ceb..f6ff0ff45 100644
---
a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/CustomerExpressionTransformer.scala
+++
b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/CustomerExpressionTransformer.scala
@@ -17,38 +17,16 @@
package org.apache.spark.sql.extension
import org.apache.gluten.expression._
-import org.apache.gluten.expression.ConverterUtils.FunctionConfig
import org.apache.gluten.extension.ExpressionExtensionTrait
-import org.apache.gluten.substrait.expression.{ExpressionBuilder,
ExpressionNode}
-import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
-import com.google.common.collect.Lists
-
case class CustomAddExpressionTransformer(
substraitExprName: String,
left: ExpressionTransformer,
right: ExpressionTransformer,
original: Expression)
- extends ExpressionTransformerWithOrigin
- with Logging {
- override def doTransform(args: java.lang.Object): ExpressionNode = {
- val leftNode = left.doTransform(args)
- val rightNode = right.doTransform(args)
- val functionMap = args.asInstanceOf[java.util.HashMap[String,
java.lang.Long]]
- val functionId = ExpressionBuilder.newScalarFunction(
- functionMap,
- ConverterUtils.makeFuncName(
- substraitExprName,
- original.children.map(_.dataType),
- FunctionConfig.OPT))
-
- val expressionNodes = Lists.newArrayList(leftNode, rightNode)
- val typeNode = ConverterUtils.getTypeNode(original.dataType,
original.nullable)
- ExpressionBuilder.makeScalarFunction(functionId, expressionNodes, typeNode)
- }
-}
+ extends BinaryExpressionTransformer
case class CustomerExpressionTransformer() extends ExpressionExtensionTrait {
@@ -65,7 +43,7 @@ case class CustomerExpressionTransformer() extends
ExpressionExtensionTrait {
expr: Expression,
attributeSeq: Seq[Attribute]): ExpressionTransformer = expr match {
case custom: CustomAdd =>
- new CustomAddExpressionTransformer(
+ CustomAddExpressionTransformer(
substraitExprName,
ExpressionConverter.replaceWithExpressionTransformer(custom.left,
attributeSeq),
ExpressionConverter.replaceWithExpressionTransformer(custom.right,
attributeSeq),
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]