This is an automated email from the ASF dual-hosted git repository.
yao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new a4b69a2f1 [CORE] Add decimal precision tests (#5752)
a4b69a2f1 is described below
commit a4b69a2f141bbc9eede669cc337112845cd712ef
Author: Xiduo You <[email protected]>
AuthorDate: Fri May 17 14:06:03 2024 +0800
[CORE] Add decimal precision tests (#5752)
* Add decimal precision tests
* Fix ClickHouse (CK) test
* Fix
* Fix
---------
Co-authored-by: Kent Yao <[email protected]>
---
.../expression/CHExpressionTransformer.scala | 10 +-
.../gluten/expression/ExpressionTransformer.scala | 10 +-
.../apache/spark/sql/expression/UDFResolver.scala | 32 +++--
.../gluten/backendsapi/SparkPlanExecApi.scala | 8 --
.../expression/ArrayExpressionTransformer.scala | 4 +-
.../gluten/expression/ConditionalTransformer.scala | 4 +-
.../DateTimeExpressionsTransformer.scala | 10 +-
.../gluten/expression/ExpressionConverter.scala | 7 --
.../gluten/expression/ExpressionTransformer.scala | 9 ++
.../expression/GenericExpressionTransformer.scala | 2 +-
.../expression/HashExpressionTransformer.scala | 2 +-
.../JsonTupleExpressionTransformer.scala | 2 +-
.../expression/LambdaFunctionTransformer.scala | 2 +-
.../gluten/expression/LiteralTransformer.scala | 4 +-
.../expression/MapExpressionTransformer.scala | 4 +-
.../expression/NamedExpressionsTransformer.scala | 2 +-
.../PredicateExpressionTransformer.scala | 9 +-
.../expression/ScalarSubqueryTransformer.scala | 3 +-
.../expression/StringExpressionTransformer.scala | 2 +-
.../expression/StructExpressionTransformer.scala | 2 +-
.../expression/UnaryExpressionTransformer.scala | 17 +--
.../utils/clickhouse/ClickHouseTestSettings.scala | 1 +
.../gluten/utils/velox/VeloxTestSettings.scala | 1 +
.../expressions/GlutenDecimalPrecisionSuite.scala | 138 +++++++++++++++++++++
.../extension/CustomerExpressionTransformer.scala | 4 +-
.../utils/clickhouse/ClickHouseTestSettings.scala | 1 +
.../gluten/utils/velox/VeloxTestSettings.scala | 3 +-
.../expressions/GlutenDecimalPrecisionSuite.scala | 138 +++++++++++++++++++++
.../utils/clickhouse/ClickHouseTestSettings.scala | 1 +
.../gluten/utils/velox/VeloxTestSettings.scala | 3 +-
.../expressions/GlutenDecimalPrecisionSuite.scala | 138 +++++++++++++++++++++
.../utils/clickhouse/ClickHouseTestSettings.scala | 1 +
.../gluten/utils/velox/VeloxTestSettings.scala | 3 +-
.../expressions/GlutenDecimalPrecisionSuite.scala | 138 +++++++++++++++++++++
34 files changed, 639 insertions(+), 76 deletions(-)
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressionTransformer.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressionTransformer.scala
index 7d9dbaddc..98cc4a930 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressionTransformer.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressionTransformer.scala
@@ -34,7 +34,7 @@ case class CHSizeExpressionTransformer(
substraitExprName: String,
child: ExpressionTransformer,
original: Size)
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
override def doTransform(args: java.lang.Object): ExpressionNode = {
// Pass legacyLiteral as second argument in substrait function
@@ -51,7 +51,7 @@ case class CHTruncTimestampTransformer(
timestamp: ExpressionTransformer,
timeZoneId: Option[String] = None,
original: TruncTimestamp)
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
override def doTransform(args: java.lang.Object): ExpressionNode = {
// The format must be constant string in the function date_trunc of ch.
@@ -126,7 +126,7 @@ case class CHStringTranslateTransformer(
matchingExpr: ExpressionTransformer,
replaceExpr: ExpressionTransformer,
original: StringTranslate)
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
override def doTransform(args: java.lang.Object): ExpressionNode = {
// In CH, translateUTF8 requires matchingExpr and replaceExpr argument
have the same length
@@ -158,7 +158,7 @@ case class CHPosExplodeTransformer(
child: ExpressionTransformer,
original: PosExplode,
attributeSeq: Seq[Attribute])
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
override def doTransform(args: java.lang.Object): ExpressionNode = {
val childNode: ExpressionNode = child.doTransform(args)
@@ -202,7 +202,7 @@ case class CHRegExpReplaceTransformer(
substraitExprName: String,
children: Seq[ExpressionTransformer],
original: RegExpReplace)
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
override def doTransform(args: java.lang.Object): ExpressionNode = {
// In CH: replaceRegexpAll(subject, regexp, rep), which is equivalent
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala
b/backends-velox/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala
index 75a2c3a62..da8433fa2 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala
@@ -35,7 +35,7 @@ case class VeloxAliasTransformer(
substraitExprName: String,
child: ExpressionTransformer,
original: Expression)
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
override def doTransform(args: java.lang.Object): ExpressionNode = {
child.doTransform(args)
@@ -46,7 +46,7 @@ case class VeloxNamedStructTransformer(
substraitExprName: String,
original: CreateNamedStruct,
attributeSeq: Seq[Attribute])
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
override def doTransform(args: Object): ExpressionNode = {
val expressionNodes = Lists.newArrayList[ExpressionNode]()
original.valExprs.foreach(
@@ -67,7 +67,7 @@ case class VeloxGetStructFieldTransformer(
childTransformer: ExpressionTransformer,
ordinal: Int,
original: GetStructField)
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
override def doTransform(args: Object): ExpressionNode = {
val childNode = childTransformer.doTransform(args)
childNode match {
@@ -86,7 +86,7 @@ case class VeloxHashExpressionTransformer(
substraitExprName: String,
exps: Seq[ExpressionTransformer],
original: Expression)
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
override def doTransform(args: java.lang.Object): ExpressionNode = {
// As of Spark 3.3, there are 3 kinds of HashExpression.
// HiveHash is not supported in native backend and will fail native
validation.
@@ -121,7 +121,7 @@ case class VeloxStringSplitTransformer(
regexExpr: ExpressionTransformer,
limitExpr: ExpressionTransformer,
original: StringSplit)
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
override def doTransform(args: java.lang.Object): ExpressionNode = {
if (
diff --git
a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
index bdfd24ed5..847e5a2e6 100644
---
a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
+++
b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
@@ -20,7 +20,7 @@ import
org.apache.gluten.backendsapi.velox.VeloxBackendSettings
import org.apache.gluten.exception.GlutenException
import org.apache.gluten.expression.{ConverterUtils, ExpressionTransformer,
ExpressionType, Transformable}
import org.apache.gluten.expression.ConverterUtils.FunctionConfig
-import org.apache.gluten.substrait.expression.ExpressionBuilder
+import org.apache.gluten.substrait.expression.{ExpressionBuilder,
ExpressionNode}
import org.apache.gluten.udf.UdfJniWrapper
import org.apache.gluten.vectorized.JniWorkspace
@@ -110,18 +110,24 @@ case class UDFExpression(
this.getClass.getSimpleName +
": getTransformer called before children transformer initialized.")
}
- (args: Object) => {
- val transformers = childrenTransformers.map(_.doTransform(args))
- val functionMap = args.asInstanceOf[java.util.HashMap[String,
java.lang.Long]]
- val functionId = ExpressionBuilder.newScalarFunction(
- functionMap,
- ConverterUtils.makeFuncName(name, children.map(_.dataType),
FunctionConfig.REQ))
-
- val typeNode = ConverterUtils.getTypeNode(dataType, nullable)
- ExpressionBuilder.makeScalarFunction(
- functionId,
- Lists.newArrayList(transformers: _*),
- typeNode)
+
+ val localDataType = dataType
+ new ExpressionTransformer {
+ override def doTransform(args: Object): ExpressionNode = {
+ val transformers = childrenTransformers.map(_.doTransform(args))
+ val functionMap = args.asInstanceOf[java.util.HashMap[String,
java.lang.Long]]
+ val functionId = ExpressionBuilder.newScalarFunction(
+ functionMap,
+ ConverterUtils.makeFuncName(name, children.map(_.dataType),
FunctionConfig.REQ))
+
+ val typeNode = ConverterUtils.getTypeNode(dataType, nullable)
+ ExpressionBuilder.makeScalarFunction(
+ functionId,
+ Lists.newArrayList(transformers: _*),
+ typeNode)
+ }
+
+ override def dataType: DataType = localDataType
}
}
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
index 8df74bb88..aa27d1ce1 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
@@ -451,14 +451,6 @@ trait SparkPlanExecApi {
GenericExpressionTransformer(substraitExprName, children, original)
}
- def genEqualNullSafeTransformer(
- substraitExprName: String,
- left: ExpressionTransformer,
- right: ExpressionTransformer,
- original: EqualNullSafe): ExpressionTransformer = {
- GenericExpressionTransformer(substraitExprName, Seq(left, right), original)
- }
-
def genMd5Transformer(
substraitExprName: String,
child: ExpressionTransformer,
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/ArrayExpressionTransformer.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/ArrayExpressionTransformer.scala
index 85a1f58fb..68a464f13 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/ArrayExpressionTransformer.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/expression/ArrayExpressionTransformer.scala
@@ -33,7 +33,7 @@ case class CreateArrayTransformer(
children: Seq[ExpressionTransformer],
useStringTypeWhenEmpty: Boolean,
original: CreateArray)
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
override def doTransform(args: java.lang.Object): ExpressionNode = {
// If children is empty,
@@ -62,7 +62,7 @@ case class GetArrayItemTransformer(
right: ExpressionTransformer,
failOnError: Boolean,
original: GetArrayItem)
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
override def doTransform(args: java.lang.Object): ExpressionNode = {
// Ignore failOnError for clickhouse backend
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/ConditionalTransformer.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/ConditionalTransformer.scala
index 18a46d7ca..0fdd68511 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/ConditionalTransformer.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/expression/ConditionalTransformer.scala
@@ -27,7 +27,7 @@ case class CaseWhenTransformer(
branches: Seq[(ExpressionTransformer, ExpressionTransformer)],
elseValue: Option[ExpressionTransformer],
original: Expression)
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
override def doTransform(args: java.lang.Object): ExpressionNode = {
// generate branches nodes
@@ -52,7 +52,7 @@ case class IfTransformer(
trueValue: ExpressionTransformer,
falseValue: ExpressionTransformer,
original: Expression)
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
override def doTransform(args: java.lang.Object): ExpressionNode = {
val ifNodes = new JArrayList[ExpressionNode]
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/DateTimeExpressionsTransformer.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/DateTimeExpressionsTransformer.scala
index 797dc81d3..66004291a 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/DateTimeExpressionsTransformer.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/expression/DateTimeExpressionsTransformer.scala
@@ -36,7 +36,7 @@ case class ExtractDateTransformer(
substraitExprName: String,
child: ExpressionTransformer,
original: Expression)
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
override def doTransform(args: java.lang.Object): ExpressionNode = {
val childNode = child.doTransform(args)
@@ -65,7 +65,7 @@ case class DateDiffTransformer(
endDate: ExpressionTransformer,
startDate: ExpressionTransformer,
original: DateDiff)
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
override def doTransform(args: java.lang.Object): ExpressionNode = {
val endDateNode = endDate.doTransform(args)
@@ -99,7 +99,7 @@ case class ToUnixTimestampTransformer(
timeZoneId: Option[String],
failOnError: Boolean,
original: ToUnixTimestamp)
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
override def doTransform(args: java.lang.Object): ExpressionNode = {
val dataTypes = Seq(original.timeExp.dataType, StringType)
@@ -124,7 +124,7 @@ case class TruncTimestampTransformer(
timestamp: ExpressionTransformer,
timeZoneId: Option[String] = None,
original: TruncTimestamp)
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
override def doTransform(args: java.lang.Object): ExpressionNode = {
val timestampNode = timestamp.doTransform(args)
@@ -160,7 +160,7 @@ case class MonthsBetweenTransformer(
roundOff: ExpressionTransformer,
timeZoneId: Option[String] = None,
original: MonthsBetween)
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
override def doTransform(args: java.lang.Object): ExpressionNode = {
val date1Node = date1.doTransform(args)
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
index b7b946268..e692890c4 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
@@ -415,13 +415,6 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
),
r
)
- case equal: EqualNullSafe =>
-
BackendsApiManager.getSparkPlanExecApiInstance.genEqualNullSafeTransformer(
- substraitExprName,
- replaceWithExpressionTransformerInternal(equal.left, attributeSeq,
expressionsMap),
- replaceWithExpressionTransformerInternal(equal.right, attributeSeq,
expressionsMap),
- equal
- )
case md5: Md5 =>
BackendsApiManager.getSparkPlanExecApiInstance.genMd5Transformer(
substraitExprName,
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala
index 65badcbae..6b6587862 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala
@@ -18,6 +18,15 @@ package org.apache.gluten.expression
import org.apache.gluten.substrait.expression.ExpressionNode
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.types.DataType
+
trait ExpressionTransformer {
def doTransform(args: java.lang.Object): ExpressionNode
+ def dataType: DataType
+}
+
+trait ExpressionTransformerWithOrigin extends ExpressionTransformer {
+ def original: Expression
+ def dataType: DataType = original.dataType
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/GenericExpressionTransformer.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/GenericExpressionTransformer.scala
index 62afcad28..8faf4965f 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/GenericExpressionTransformer.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/expression/GenericExpressionTransformer.scala
@@ -27,7 +27,7 @@ case class GenericExpressionTransformer(
substraitExprName: String,
children: Seq[ExpressionTransformer],
original: Expression)
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
override def doTransform(args: Object): ExpressionNode = {
val functionMap = args.asInstanceOf[java.util.HashMap[String,
java.lang.Long]]
val functionId = ExpressionBuilder.newScalarFunction(
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/HashExpressionTransformer.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/HashExpressionTransformer.scala
index d813f8250..28f2dda01 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/HashExpressionTransformer.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/expression/HashExpressionTransformer.scala
@@ -25,7 +25,7 @@ case class HashExpressionTransformer(
substraitExprName: String,
exps: Seq[ExpressionTransformer],
original: Expression)
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
override def doTransform(args: java.lang.Object): ExpressionNode = {
val nodes = new java.util.ArrayList[ExpressionNode]()
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/JsonTupleExpressionTransformer.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/JsonTupleExpressionTransformer.scala
index 25e3e12a5..e8ff3d360 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/JsonTupleExpressionTransformer.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/expression/JsonTupleExpressionTransformer.scala
@@ -28,7 +28,7 @@ case class JsonTupleExpressionTransformer(
substraitExprName: String,
children: Seq[ExpressionTransformer],
original: Expression)
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
override def doTransform(args: Object): ExpressionNode = {
val jsonExpr = children.head
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/LambdaFunctionTransformer.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/LambdaFunctionTransformer.scala
index 492de2b76..ce6d13a95 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/LambdaFunctionTransformer.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/expression/LambdaFunctionTransformer.scala
@@ -27,7 +27,7 @@ case class LambdaFunctionTransformer(
arguments: Seq[ExpressionTransformer],
hidden: Boolean = false,
original: LambdaFunction)
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
override def doTransform(args: Object): ExpressionNode = {
// Need to fallback when hidden be true as it's not supported in Velox
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/LiteralTransformer.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/LiteralTransformer.scala
index 05787858e..8fb9943d6 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/LiteralTransformer.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/expression/LiteralTransformer.scala
@@ -20,9 +20,9 @@ import
org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode
import org.apache.spark.sql.catalyst.expressions._
-case class LiteralTransformer(lit: Literal) extends ExpressionTransformer {
+case class LiteralTransformer(original: Literal) extends
ExpressionTransformerWithOrigin {
override def doTransform(args: java.lang.Object): ExpressionNode = {
- ExpressionBuilder.makeLiteral(lit.value, lit.dataType, lit.nullable)
+ ExpressionBuilder.makeLiteral(original.value, original.dataType,
original.nullable)
}
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/MapExpressionTransformer.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/MapExpressionTransformer.scala
index e136f1b3a..c09afaebc 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/MapExpressionTransformer.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/expression/MapExpressionTransformer.scala
@@ -30,7 +30,7 @@ case class CreateMapTransformer(
children: Seq[ExpressionTransformer],
useStringTypeWhenEmpty: Boolean,
original: CreateMap)
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
override def doTransform(args: java.lang.Object): ExpressionNode = {
// If children is empty,
@@ -64,7 +64,7 @@ case class GetMapValueTransformer(
key: ExpressionTransformer,
failOnError: Boolean,
original: GetMapValue)
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
override def doTransform(args: java.lang.Object): ExpressionNode = {
if (BackendsApiManager.getSettings.alwaysFailOnMapExpression()) {
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/NamedExpressionsTransformer.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/NamedExpressionsTransformer.scala
index 70ad13584..2af4a5fa2 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/NamedExpressionsTransformer.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/expression/NamedExpressionsTransformer.scala
@@ -28,7 +28,7 @@ case class AliasTransformer(
substraitExprName: String,
child: ExpressionTransformer,
original: Expression)
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
override def doTransform(args: java.lang.Object): ExpressionNode = {
val childNode = child.doTransform(args)
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/PredicateExpressionTransformer.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/PredicateExpressionTransformer.scala
index dfa4ceed6..7d34466e5 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/PredicateExpressionTransformer.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/expression/PredicateExpressionTransformer.scala
@@ -32,7 +32,7 @@ case class InTransformer(
list: Seq[Expression],
valueType: DataType,
original: Expression)
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
override def doTransform(args: java.lang.Object): ExpressionNode = {
assert(list.forall(_.foldable))
// Stores the values in a List Literal.
@@ -46,7 +46,7 @@ case class InSetTransformer(
hset: Set[Any],
valueType: DataType,
original: Expression)
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
override def doTransform(args: java.lang.Object): ExpressionNode = {
InExpressionTransformer.toTransformer(value.doTransform(args), hset,
valueType)
}
@@ -74,7 +74,7 @@ case class LikeTransformer(
left: ExpressionTransformer,
right: ExpressionTransformer,
original: Expression)
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
override def doTransform(args: java.lang.Object): ExpressionNode = {
val leftNode = left.doTransform(args)
val rightNode = right.doTransform(args)
@@ -108,7 +108,8 @@ case class DecimalArithmeticExpressionTransformer(
right: ExpressionTransformer,
resultType: DecimalType,
original: Expression)
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
+ override def dataType: DataType = resultType
override def doTransform(args: java.lang.Object): ExpressionNode = {
val leftNode = left.doTransform(args)
val rightNode = right.doTransform(args)
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/ScalarSubqueryTransformer.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/ScalarSubqueryTransformer.scala
index 534bde3b3..4f5a43d47 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/ScalarSubqueryTransformer.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/expression/ScalarSubqueryTransformer.scala
@@ -22,7 +22,8 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.{BaseSubqueryExec, ScalarSubquery}
case class ScalarSubqueryTransformer(plan: BaseSubqueryExec, exprId: ExprId,
query: ScalarSubquery)
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
+ override def original: Expression = query
override def doTransform(args: java.lang.Object): ExpressionNode = {
// don't trigger collect when in validation phase
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/StringExpressionTransformer.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/StringExpressionTransformer.scala
index da021be24..b31d66b68 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/StringExpressionTransformer.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/expression/StringExpressionTransformer.scala
@@ -28,7 +28,7 @@ case class String2TrimExpressionTransformer(
trimStr: Option[ExpressionTransformer],
srcStr: ExpressionTransformer,
original: Expression)
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
override def doTransform(args: java.lang.Object): ExpressionNode = {
val trimStrNode = trimStr.map(_.doTransform(args))
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/StructExpressionTransformer.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/StructExpressionTransformer.scala
index c70395a7d..616971b6d 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/StructExpressionTransformer.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/expression/StructExpressionTransformer.scala
@@ -29,7 +29,7 @@ case class GetStructFieldTransformer(
childTransformer: ExpressionTransformer,
ordinal: Int,
original: GetStructField)
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
override def doTransform(args: java.lang.Object): ExpressionNode = {
val childNode = childTransformer.doTransform(args)
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/UnaryExpressionTransformer.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/UnaryExpressionTransformer.scala
index 2d3840ce4..d0ac19b4a 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/UnaryExpressionTransformer.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/expression/UnaryExpressionTransformer.scala
@@ -32,17 +32,18 @@ case class ChildTransformer(child: ExpressionTransformer)
extends ExpressionTran
override def doTransform(args: java.lang.Object): ExpressionNode = {
child.doTransform(args)
}
+ override def dataType: DataType = child.dataType
}
case class CastTransformer(
child: ExpressionTransformer,
- datatype: DataType,
+ dataType: DataType,
timeZoneId: Option[String],
original: Cast)
extends ExpressionTransformer {
override def doTransform(args: java.lang.Object): ExpressionNode = {
- val typeNode = ConverterUtils.getTypeNode(datatype, original.nullable)
+ val typeNode = ConverterUtils.getTypeNode(dataType, original.nullable)
ExpressionBuilder.makeCast(typeNode, child.doTransform(args),
original.ansiEnabled)
}
}
@@ -51,7 +52,7 @@ case class ExplodeTransformer(
substraitExprName: String,
child: ExpressionTransformer,
original: Explode)
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
override def doTransform(args: java.lang.Object): ExpressionNode = {
val childNode: ExpressionNode = child.doTransform(args)
@@ -79,7 +80,7 @@ case class PosExplodeTransformer(
child: ExpressionTransformer,
original: PosExplode,
attributeSeq: Seq[Attribute])
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
override def doTransform(args: java.lang.Object): ExpressionNode = {
val childNode: ExpressionNode = child.doTransform(args)
@@ -154,7 +155,7 @@ case class CheckOverflowTransformer(
child: ExpressionTransformer,
childResultType: DataType,
original: CheckOverflow)
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
override def doTransform(args: java.lang.Object): ExpressionNode = {
BackendsApiManager.getTransformerApiInstance.createCheckOverflowExprNode(
@@ -172,7 +173,7 @@ case class MakeDecimalTransformer(
substraitExprName: String,
child: ExpressionTransformer,
original: MakeDecimal)
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
override def doTransform(args: java.lang.Object): ExpressionNode = {
val childNode = child.doTransform(args)
@@ -202,7 +203,7 @@ case class RandTransformer(
substraitExprName: String,
explicitSeed: ExpressionTransformer,
original: Rand)
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
override def doTransform(args: java.lang.Object): ExpressionNode = {
if (!original.hideSeed) {
@@ -226,7 +227,7 @@ case class GetArrayStructFieldsTransformer(
numFields: Int,
containsNull: Boolean,
original: GetArrayStructFields)
- extends ExpressionTransformer {
+ extends ExpressionTransformerWithOrigin {
override def doTransform(args: java.lang.Object): ExpressionNode = {
val functionMap = args.asInstanceOf[java.util.HashMap[String,
java.lang.Long]]
diff --git
a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index bc0410834..afc427cd3 100644
---
a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++
b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -756,6 +756,7 @@ class ClickHouseTestSettings extends BackendTestSettings {
.excludeGlutenTest("to_unix_timestamp")
.excludeGlutenTest("Hour")
enableSuite[GlutenDecimalExpressionSuite]
+ enableSuite[GlutenDecimalPrecisionSuite]
enableSuite[GlutenHashExpressionsSuite]
.exclude("sha2")
.exclude("murmur3/xxHash64/hive hash:
struct<null:void,boolean:boolean,byte:tinyint,short:smallint,int:int,long:bigint,float:float,double:double,bigDecimal:decimal(38,18),smallDecimal:decimal(10,0),string:string,binary:binary,date:date,timestamp:timestamp,udt:examplepoint>")
diff --git
a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
index 366796a57..5e3591203 100644
---
a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
+++
b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
@@ -226,6 +226,7 @@ class VeloxTestSettings extends BackendTestSettings {
// Replaced by a gluten test to pass timezone through config.
.exclude("from_unixtime")
enableSuite[GlutenDecimalExpressionSuite]
+ enableSuite[GlutenDecimalPrecisionSuite]
enableSuite[GlutenStringFunctionsSuite]
enableSuite[GlutenRegexpExpressionsSuite]
enableSuite[GlutenNullExpressionsSuite]
diff --git
a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenDecimalPrecisionSuite.scala
b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenDecimalPrecisionSuite.scala
new file mode 100644
index 000000000..97e752d7d
--- /dev/null
+++
b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenDecimalPrecisionSuite.scala
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.gluten.expression._
+
+import org.apache.spark.sql.GlutenTestsTrait
+import org.apache.spark.sql.catalyst.analysis.{Analyzer,
EmptyFunctionRegistry, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
+import org.apache.spark.sql.types._
+
+class GlutenDecimalPrecisionSuite extends GlutenTestsTrait {
+ private val catalog = new SessionCatalog(new InMemoryCatalog,
EmptyFunctionRegistry)
+ private val analyzer = new Analyzer(catalog)
+
+ private val relation = LocalRelation(
+ AttributeReference("i", IntegerType)(),
+ AttributeReference("d1", DecimalType(2, 1))(),
+ AttributeReference("d2", DecimalType(5, 2))(),
+ AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(),
+ AttributeReference("f", FloatType)(),
+ AttributeReference("b", DoubleType)()
+ )
+
+ private val i: Expression = UnresolvedAttribute("i")
+ private val d1: Expression = UnresolvedAttribute("d1")
+ private val d2: Expression = UnresolvedAttribute("d2")
+ private val u: Expression = UnresolvedAttribute("u")
+ private val f: Expression = UnresolvedAttribute("f")
+ private val b: Expression = UnresolvedAttribute("b")
+
+ private def checkType(expression: Expression, expectedType: DataType): Unit
= {
+ val plan = analyzer.execute(Project(Seq(Alias(expression, "c")()),
relation))
+ assert(plan.isInstanceOf[Project])
+ val expr = plan.asInstanceOf[Project].projectList.head
+ assert(expr.dataType == expectedType)
+ val transformedExpr =
+ ExpressionConverter.replaceWithExpressionTransformer(expr,
plan.inputSet.toSeq)
+ assert(transformedExpr.dataType == expectedType)
+ }
+
+ private def stripAlias(expr: Expression): Expression = {
+ expr match {
+ case a: Alias => stripAlias(a.child)
+ case _ => expr
+ }
+ }
+
+ private def checkComparison(expression: Expression, expectedType: DataType):
Unit = {
+ val plan = analyzer.execute(Project(Alias(expression, "c")() :: Nil,
relation))
+ assert(plan.isInstanceOf[Project])
+ val expr = stripAlias(plan.asInstanceOf[Project].projectList.head)
+ val transformedExpr =
+ ExpressionConverter.replaceWithExpressionTransformer(expr,
plan.inputSet.toSeq)
+ assert(transformedExpr.isInstanceOf[GenericExpressionTransformer])
+ val binaryComparison =
transformedExpr.asInstanceOf[GenericExpressionTransformer]
+ assert(binaryComparison.original.isInstanceOf[BinaryComparison])
+ assert(binaryComparison.children.size == 2)
+ assert(binaryComparison.children.forall(_.dataType == expectedType))
+ }
+
+ test("basic operations") {
+ checkType(Add(d1, d2), DecimalType(6, 2))
+ checkType(Subtract(d1, d2), DecimalType(6, 2))
+ checkType(Multiply(d1, d2), DecimalType(8, 3))
+ checkType(Divide(d1, d2), DecimalType(10, 7))
+ checkType(Divide(d2, d1), DecimalType(10, 6))
+
+ checkType(Add(Add(d1, d2), d1), DecimalType(7, 2))
+ checkType(Add(Add(d1, d1), d1), DecimalType(4, 1))
+ checkType(Add(d1, Add(d1, d1)), DecimalType(4, 1))
+ checkType(Add(Add(Add(d1, d2), d1), d2), DecimalType(8, 2))
+ checkType(Add(Add(d1, d2), Add(d1, d2)), DecimalType(7, 2))
+ checkType(Subtract(Subtract(d2, d1), d1), DecimalType(7, 2))
+ checkType(Multiply(Multiply(d1, d1), d2), DecimalType(11, 4))
+ checkType(Divide(d2, Add(d1, d1)), DecimalType(10, 6))
+ }
+
+ test("Comparison operations") {
+ checkComparison(EqualTo(i, d1), DecimalType(11, 1))
+ checkComparison(EqualNullSafe(d2, d1), DecimalType(5, 2))
+ checkComparison(LessThan(i, d1), DecimalType(11, 1))
+ checkComparison(LessThanOrEqual(d1, d2), DecimalType(5, 2))
+ checkComparison(GreaterThan(d2, u), DecimalType.SYSTEM_DEFAULT)
+ checkComparison(GreaterThanOrEqual(d1, f), DoubleType)
+ checkComparison(GreaterThan(d2, d2), DecimalType(5, 2))
+ }
+
+ test("bringing in primitive types") {
+ checkType(Add(d1, i), DecimalType(12, 1))
+ checkType(Add(d1, f), DoubleType)
+ checkType(Add(i, d1), DecimalType(12, 1))
+ checkType(Add(f, d1), DoubleType)
+ checkType(Add(d1, Cast(i, LongType)), DecimalType(22, 1))
+ checkType(Add(d1, Cast(i, ShortType)), DecimalType(7, 1))
+ checkType(Add(d1, Cast(i, ByteType)), DecimalType(5, 1))
+ checkType(Add(d1, Cast(i, DoubleType)), DoubleType)
+ }
+
+ test("maximum decimals") {
+ for (expr <- Seq(d1, d2, i, u)) {
+ checkType(Add(expr, u), DecimalType(38, 17))
+ checkType(Subtract(expr, u), DecimalType(38, 17))
+ }
+
+ checkType(Multiply(d1, u), DecimalType(38, 16))
+ checkType(Multiply(d2, u), DecimalType(38, 14))
+ checkType(Multiply(i, u), DecimalType(38, 7))
+ checkType(Multiply(u, u), DecimalType(38, 6))
+
+ checkType(Divide(u, d1), DecimalType(38, 17))
+ checkType(Divide(u, d2), DecimalType(38, 16))
+ checkType(Divide(u, i), DecimalType(38, 18))
+ checkType(Divide(u, u), DecimalType(38, 6))
+
+ for (expr <- Seq(f, b)) {
+ checkType(Add(expr, u), DoubleType)
+ checkType(Subtract(expr, u), DoubleType)
+ checkType(Multiply(expr, u), DoubleType)
+ checkType(Divide(expr, u), DoubleType)
+ }
+ }
+}
diff --git a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/CustomerExpressionTransformer.scala b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/CustomerExpressionTransformer.scala
index a3720fc62..c27159ceb 100644
--- a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/CustomerExpressionTransformer.scala
+++ b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/CustomerExpressionTransformer.scala
@@ -26,12 +26,12 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import com.google.common.collect.Lists
-class CustomAddExpressionTransformer(
+case class CustomAddExpressionTransformer(
substraitExprName: String,
left: ExpressionTransformer,
right: ExpressionTransformer,
original: Expression)
- extends ExpressionTransformer
+ extends ExpressionTransformerWithOrigin
with Logging {
override def doTransform(args: java.lang.Object): ExpressionNode = {
val leftNode = left.doTransform(args)
diff --git a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index 6a403204f..85f3f94cc 100644
--- a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++ b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -800,6 +800,7 @@ class ClickHouseTestSettings extends BackendTestSettings {
.excludeGlutenTest("to_unix_timestamp")
.excludeGlutenTest("Hour")
enableSuite[GlutenDecimalExpressionSuite]
+ enableSuite[GlutenDecimalPrecisionSuite]
enableSuite[GlutenHashExpressionsSuite]
.exclude("sha2")
.exclude("murmur3/xxHash64/hive hash:
struct<null:void,boolean:boolean,byte:tinyint,short:smallint,int:int,long:bigint,float:float,double:double,bigDecimal:decimal(38,18),smallDecimal:decimal(10,0),string:string,binary:binary,date:date,timestamp:timestamp,udt:examplepoint>")
diff --git a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
index 128e52a79..1d796aa1b 100644
--- a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
+++ b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
@@ -19,7 +19,7 @@ package org.apache.gluten.utils.velox
import org.apache.gluten.utils.{BackendTestSettings, SQLQueryTestSettings}
import org.apache.spark.sql._
-import
org.apache.spark.sql.catalyst.expressions.{GlutenAnsiCastSuiteWithAnsiModeOff,
GlutenAnsiCastSuiteWithAnsiModeOn, GlutenArithmeticExpressionSuite,
GlutenBitwiseExpressionsSuite, GlutenCastSuite, GlutenCastSuiteWithAnsiModeOn,
GlutenCollectionExpressionsSuite, GlutenComplexTypeSuite,
GlutenConditionalExpressionSuite, GlutenDateExpressionsSuite,
GlutenDecimalExpressionSuite, GlutenHashExpressionsSuite,
GlutenHigherOrderFunctionsSuite, GlutenIntervalExpressionsSuite,
GlutenLiteralExp [...]
+import
org.apache.spark.sql.catalyst.expressions.{GlutenAnsiCastSuiteWithAnsiModeOff,
GlutenAnsiCastSuiteWithAnsiModeOn, GlutenArithmeticExpressionSuite,
GlutenBitwiseExpressionsSuite, GlutenCastSuite, GlutenCastSuiteWithAnsiModeOn,
GlutenCollectionExpressionsSuite, GlutenComplexTypeSuite,
GlutenConditionalExpressionSuite, GlutenDateExpressionsSuite,
GlutenDecimalExpressionSuite, GlutenDecimalPrecisionSuite,
GlutenHashExpressionsSuite, GlutenHigherOrderFunctionsSuite,
GlutenIntervalExpre [...]
import org.apache.spark.sql.connector._
import org.apache.spark.sql.errors.{GlutenQueryCompilationErrorsDSv2Suite,
GlutenQueryExecutionErrorsSuite, GlutenQueryParsingErrorsSuite}
import org.apache.spark.sql.execution._
@@ -141,6 +141,7 @@ class VeloxTestSettings extends BackendTestSettings {
.exclude("from_unixtime")
.exclude("test timestamp add")
enableSuite[GlutenDecimalExpressionSuite]
+ enableSuite[GlutenDecimalPrecisionSuite]
enableSuite[GlutenHashExpressionsSuite]
enableSuite[GlutenHigherOrderFunctionsSuite]
enableSuite[GlutenIntervalExpressionsSuite]
diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenDecimalPrecisionSuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenDecimalPrecisionSuite.scala
new file mode 100644
index 000000000..97e752d7d
--- /dev/null
+++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenDecimalPrecisionSuite.scala
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.gluten.expression._
+
+import org.apache.spark.sql.GlutenTestsTrait
+import org.apache.spark.sql.catalyst.analysis.{Analyzer,
EmptyFunctionRegistry, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
+import org.apache.spark.sql.types._
+
+class GlutenDecimalPrecisionSuite extends GlutenTestsTrait {
+ private val catalog = new SessionCatalog(new InMemoryCatalog,
EmptyFunctionRegistry)
+ private val analyzer = new Analyzer(catalog)
+
+ private val relation = LocalRelation(
+ AttributeReference("i", IntegerType)(),
+ AttributeReference("d1", DecimalType(2, 1))(),
+ AttributeReference("d2", DecimalType(5, 2))(),
+ AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(),
+ AttributeReference("f", FloatType)(),
+ AttributeReference("b", DoubleType)()
+ )
+
+ private val i: Expression = UnresolvedAttribute("i")
+ private val d1: Expression = UnresolvedAttribute("d1")
+ private val d2: Expression = UnresolvedAttribute("d2")
+ private val u: Expression = UnresolvedAttribute("u")
+ private val f: Expression = UnresolvedAttribute("f")
+ private val b: Expression = UnresolvedAttribute("b")
+
+ private def checkType(expression: Expression, expectedType: DataType): Unit
= {
+ val plan = analyzer.execute(Project(Seq(Alias(expression, "c")()),
relation))
+ assert(plan.isInstanceOf[Project])
+ val expr = plan.asInstanceOf[Project].projectList.head
+ assert(expr.dataType == expectedType)
+ val transformedExpr =
+ ExpressionConverter.replaceWithExpressionTransformer(expr,
plan.inputSet.toSeq)
+ assert(transformedExpr.dataType == expectedType)
+ }
+
+ private def stripAlias(expr: Expression): Expression = {
+ expr match {
+ case a: Alias => stripAlias(a.child)
+ case _ => expr
+ }
+ }
+
+ private def checkComparison(expression: Expression, expectedType: DataType):
Unit = {
+ val plan = analyzer.execute(Project(Alias(expression, "c")() :: Nil,
relation))
+ assert(plan.isInstanceOf[Project])
+ val expr = stripAlias(plan.asInstanceOf[Project].projectList.head)
+ val transformedExpr =
+ ExpressionConverter.replaceWithExpressionTransformer(expr,
plan.inputSet.toSeq)
+ assert(transformedExpr.isInstanceOf[GenericExpressionTransformer])
+ val binaryComparison =
transformedExpr.asInstanceOf[GenericExpressionTransformer]
+ assert(binaryComparison.original.isInstanceOf[BinaryComparison])
+ assert(binaryComparison.children.size == 2)
+ assert(binaryComparison.children.forall(_.dataType == expectedType))
+ }
+
+ test("basic operations") {
+ checkType(Add(d1, d2), DecimalType(6, 2))
+ checkType(Subtract(d1, d2), DecimalType(6, 2))
+ checkType(Multiply(d1, d2), DecimalType(8, 3))
+ checkType(Divide(d1, d2), DecimalType(10, 7))
+ checkType(Divide(d2, d1), DecimalType(10, 6))
+
+ checkType(Add(Add(d1, d2), d1), DecimalType(7, 2))
+ checkType(Add(Add(d1, d1), d1), DecimalType(4, 1))
+ checkType(Add(d1, Add(d1, d1)), DecimalType(4, 1))
+ checkType(Add(Add(Add(d1, d2), d1), d2), DecimalType(8, 2))
+ checkType(Add(Add(d1, d2), Add(d1, d2)), DecimalType(7, 2))
+ checkType(Subtract(Subtract(d2, d1), d1), DecimalType(7, 2))
+ checkType(Multiply(Multiply(d1, d1), d2), DecimalType(11, 4))
+ checkType(Divide(d2, Add(d1, d1)), DecimalType(10, 6))
+ }
+
+ test("Comparison operations") {
+ checkComparison(EqualTo(i, d1), DecimalType(11, 1))
+ checkComparison(EqualNullSafe(d2, d1), DecimalType(5, 2))
+ checkComparison(LessThan(i, d1), DecimalType(11, 1))
+ checkComparison(LessThanOrEqual(d1, d2), DecimalType(5, 2))
+ checkComparison(GreaterThan(d2, u), DecimalType.SYSTEM_DEFAULT)
+ checkComparison(GreaterThanOrEqual(d1, f), DoubleType)
+ checkComparison(GreaterThan(d2, d2), DecimalType(5, 2))
+ }
+
+ test("bringing in primitive types") {
+ checkType(Add(d1, i), DecimalType(12, 1))
+ checkType(Add(d1, f), DoubleType)
+ checkType(Add(i, d1), DecimalType(12, 1))
+ checkType(Add(f, d1), DoubleType)
+ checkType(Add(d1, Cast(i, LongType)), DecimalType(22, 1))
+ checkType(Add(d1, Cast(i, ShortType)), DecimalType(7, 1))
+ checkType(Add(d1, Cast(i, ByteType)), DecimalType(5, 1))
+ checkType(Add(d1, Cast(i, DoubleType)), DoubleType)
+ }
+
+ test("maximum decimals") {
+ for (expr <- Seq(d1, d2, i, u)) {
+ checkType(Add(expr, u), DecimalType(38, 17))
+ checkType(Subtract(expr, u), DecimalType(38, 17))
+ }
+
+ checkType(Multiply(d1, u), DecimalType(38, 16))
+ checkType(Multiply(d2, u), DecimalType(38, 14))
+ checkType(Multiply(i, u), DecimalType(38, 7))
+ checkType(Multiply(u, u), DecimalType(38, 6))
+
+ checkType(Divide(u, d1), DecimalType(38, 17))
+ checkType(Divide(u, d2), DecimalType(38, 16))
+ checkType(Divide(u, i), DecimalType(38, 18))
+ checkType(Divide(u, u), DecimalType(38, 6))
+
+ for (expr <- Seq(f, b)) {
+ checkType(Add(expr, u), DoubleType)
+ checkType(Subtract(expr, u), DoubleType)
+ checkType(Multiply(expr, u), DoubleType)
+ checkType(Divide(expr, u), DoubleType)
+ }
+ }
+}
diff --git a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index 37e4c68f7..069d697bd 100644
--- a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++ b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -639,6 +639,7 @@ class ClickHouseTestSettings extends BackendTestSettings {
.excludeGlutenTest("to_unix_timestamp")
.excludeGlutenTest("Hour")
enableSuite[GlutenDecimalExpressionSuite].exclude("MakeDecimal")
+ enableSuite[GlutenDecimalPrecisionSuite]
enableSuite[GlutenHashExpressionsSuite]
.exclude("sha2")
.exclude("murmur3/xxHash64/hive hash:
struct<null:void,boolean:boolean,byte:tinyint,short:smallint,int:int,long:bigint,float:float,double:double,bigDecimal:decimal(38,18),smallDecimal:decimal(10,0),string:string,binary:binary,date:date,timestamp:timestamp,udt:examplepoint>")
diff --git a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
index 6ea29847b..7c8509f80 100644
--- a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
+++ b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
@@ -19,7 +19,7 @@ package org.apache.gluten.utils.velox
import org.apache.gluten.utils.{BackendTestSettings, SQLQueryTestSettings}
import org.apache.spark.sql._
-import
org.apache.spark.sql.catalyst.expressions.{GlutenArithmeticExpressionSuite,
GlutenBitwiseExpressionsSuite, GlutenCastSuite,
GlutenCollectionExpressionsSuite, GlutenComplexTypeSuite,
GlutenConditionalExpressionSuite, GlutenDateExpressionsSuite,
GlutenDecimalExpressionSuite, GlutenHashExpressionsSuite,
GlutenHigherOrderFunctionsSuite, GlutenIntervalExpressionsSuite,
GlutenLiteralExpressionSuite, GlutenMathExpressionsSuite,
GlutenMiscExpressionsSuite, GlutenNondeterministicSuite, Glu [...]
+import
org.apache.spark.sql.catalyst.expressions.{GlutenArithmeticExpressionSuite,
GlutenBitwiseExpressionsSuite, GlutenCastSuite,
GlutenCollectionExpressionsSuite, GlutenComplexTypeSuite,
GlutenConditionalExpressionSuite, GlutenDateExpressionsSuite,
GlutenDecimalExpressionSuite, GlutenDecimalPrecisionSuite,
GlutenHashExpressionsSuite, GlutenHigherOrderFunctionsSuite,
GlutenIntervalExpressionsSuite, GlutenLiteralExpressionSuite,
GlutenMathExpressionsSuite, GlutenMiscExpressionsSuite, Glu [...]
import
org.apache.spark.sql.connector.{GlutenDataSourceV2DataFrameSessionCatalogSuite,
GlutenDataSourceV2DataFrameSuite, GlutenDataSourceV2FunctionSuite,
GlutenDataSourceV2SQLSessionCatalogSuite, GlutenDataSourceV2SQLSuiteV1Filter,
GlutenDataSourceV2SQLSuiteV2Filter, GlutenDataSourceV2Suite,
GlutenDeleteFromTableSuite, GlutenDeltaBasedDeleteFromTableSuite,
GlutenFileDataSourceV2FallBackSuite, GlutenGroupBasedDeleteFromTableSuite,
GlutenKeyGroupedPartitioningSuite, GlutenLocalScanSuite, G [...]
import org.apache.spark.sql.errors.{GlutenQueryCompilationErrorsDSv2Suite,
GlutenQueryCompilationErrorsSuite, GlutenQueryExecutionErrorsSuite,
GlutenQueryParsingErrorsSuite}
import org.apache.spark.sql.execution.{FallbackStrategiesSuite,
GlutenBroadcastExchangeSuite, GlutenCoalesceShufflePartitionsSuite,
GlutenExchangeSuite, GlutenLocalBroadcastExchangeSuite,
GlutenReplaceHashWithSortAggSuite, GlutenReuseExchangeAndSubquerySuite,
GlutenSameResultSuite, GlutenSortSuite, GlutenSQLAggregateFunctionSuite,
GlutenSQLWindowFunctionSuite, GlutenTakeOrderedAndProjectSuite}
@@ -121,6 +121,7 @@ class VeloxTestSettings extends BackendTestSettings {
// Replaced by a gluten test to pass timezone through config.
.exclude("from_unixtime")
enableSuite[GlutenDecimalExpressionSuite]
+ enableSuite[GlutenDecimalPrecisionSuite]
enableSuite[GlutenHashExpressionsSuite]
enableSuite[GlutenHigherOrderFunctionsSuite]
enableSuite[GlutenIntervalExpressionsSuite]
diff --git a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenDecimalPrecisionSuite.scala b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenDecimalPrecisionSuite.scala
new file mode 100644
index 000000000..97e752d7d
--- /dev/null
+++ b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenDecimalPrecisionSuite.scala
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.gluten.expression._
+
+import org.apache.spark.sql.GlutenTestsTrait
+import org.apache.spark.sql.catalyst.analysis.{Analyzer,
EmptyFunctionRegistry, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
+import org.apache.spark.sql.types._
+
+class GlutenDecimalPrecisionSuite extends GlutenTestsTrait {
+ private val catalog = new SessionCatalog(new InMemoryCatalog,
EmptyFunctionRegistry)
+ private val analyzer = new Analyzer(catalog)
+
+ private val relation = LocalRelation(
+ AttributeReference("i", IntegerType)(),
+ AttributeReference("d1", DecimalType(2, 1))(),
+ AttributeReference("d2", DecimalType(5, 2))(),
+ AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(),
+ AttributeReference("f", FloatType)(),
+ AttributeReference("b", DoubleType)()
+ )
+
+ private val i: Expression = UnresolvedAttribute("i")
+ private val d1: Expression = UnresolvedAttribute("d1")
+ private val d2: Expression = UnresolvedAttribute("d2")
+ private val u: Expression = UnresolvedAttribute("u")
+ private val f: Expression = UnresolvedAttribute("f")
+ private val b: Expression = UnresolvedAttribute("b")
+
+ private def checkType(expression: Expression, expectedType: DataType): Unit
= {
+ val plan = analyzer.execute(Project(Seq(Alias(expression, "c")()),
relation))
+ assert(plan.isInstanceOf[Project])
+ val expr = plan.asInstanceOf[Project].projectList.head
+ assert(expr.dataType == expectedType)
+ val transformedExpr =
+ ExpressionConverter.replaceWithExpressionTransformer(expr,
plan.inputSet.toSeq)
+ assert(transformedExpr.dataType == expectedType)
+ }
+
+ private def stripAlias(expr: Expression): Expression = {
+ expr match {
+ case a: Alias => stripAlias(a.child)
+ case _ => expr
+ }
+ }
+
+ private def checkComparison(expression: Expression, expectedType: DataType):
Unit = {
+ val plan = analyzer.execute(Project(Alias(expression, "c")() :: Nil,
relation))
+ assert(plan.isInstanceOf[Project])
+ val expr = stripAlias(plan.asInstanceOf[Project].projectList.head)
+ val transformedExpr =
+ ExpressionConverter.replaceWithExpressionTransformer(expr,
plan.inputSet.toSeq)
+ assert(transformedExpr.isInstanceOf[GenericExpressionTransformer])
+ val binaryComparison =
transformedExpr.asInstanceOf[GenericExpressionTransformer]
+ assert(binaryComparison.original.isInstanceOf[BinaryComparison])
+ assert(binaryComparison.children.size == 2)
+ assert(binaryComparison.children.forall(_.dataType == expectedType))
+ }
+
+ test("basic operations") {
+ checkType(Add(d1, d2), DecimalType(6, 2))
+ checkType(Subtract(d1, d2), DecimalType(6, 2))
+ checkType(Multiply(d1, d2), DecimalType(8, 3))
+ checkType(Divide(d1, d2), DecimalType(10, 7))
+ checkType(Divide(d2, d1), DecimalType(10, 6))
+
+ checkType(Add(Add(d1, d2), d1), DecimalType(7, 2))
+ checkType(Add(Add(d1, d1), d1), DecimalType(4, 1))
+ checkType(Add(d1, Add(d1, d1)), DecimalType(4, 1))
+ checkType(Add(Add(Add(d1, d2), d1), d2), DecimalType(8, 2))
+ checkType(Add(Add(d1, d2), Add(d1, d2)), DecimalType(7, 2))
+ checkType(Subtract(Subtract(d2, d1), d1), DecimalType(7, 2))
+ checkType(Multiply(Multiply(d1, d1), d2), DecimalType(11, 4))
+ checkType(Divide(d2, Add(d1, d1)), DecimalType(10, 6))
+ }
+
+ test("Comparison operations") {
+ checkComparison(EqualTo(i, d1), DecimalType(11, 1))
+ checkComparison(EqualNullSafe(d2, d1), DecimalType(5, 2))
+ checkComparison(LessThan(i, d1), DecimalType(11, 1))
+ checkComparison(LessThanOrEqual(d1, d2), DecimalType(5, 2))
+ checkComparison(GreaterThan(d2, u), DecimalType.SYSTEM_DEFAULT)
+ checkComparison(GreaterThanOrEqual(d1, f), DoubleType)
+ checkComparison(GreaterThan(d2, d2), DecimalType(5, 2))
+ }
+
+ test("bringing in primitive types") {
+ checkType(Add(d1, i), DecimalType(12, 1))
+ checkType(Add(d1, f), DoubleType)
+ checkType(Add(i, d1), DecimalType(12, 1))
+ checkType(Add(f, d1), DoubleType)
+ checkType(Add(d1, Cast(i, LongType)), DecimalType(22, 1))
+ checkType(Add(d1, Cast(i, ShortType)), DecimalType(7, 1))
+ checkType(Add(d1, Cast(i, ByteType)), DecimalType(5, 1))
+ checkType(Add(d1, Cast(i, DoubleType)), DoubleType)
+ }
+
+ test("maximum decimals") {
+ for (expr <- Seq(d1, d2, i, u)) {
+ checkType(Add(expr, u), DecimalType(38, 17))
+ checkType(Subtract(expr, u), DecimalType(38, 17))
+ }
+
+ checkType(Multiply(d1, u), DecimalType(38, 16))
+ checkType(Multiply(d2, u), DecimalType(38, 14))
+ checkType(Multiply(i, u), DecimalType(38, 7))
+ checkType(Multiply(u, u), DecimalType(38, 6))
+
+ checkType(Divide(u, d1), DecimalType(38, 17))
+ checkType(Divide(u, d2), DecimalType(38, 16))
+ checkType(Divide(u, i), DecimalType(38, 18))
+ checkType(Divide(u, u), DecimalType(38, 6))
+
+ for (expr <- Seq(f, b)) {
+ checkType(Add(expr, u), DoubleType)
+ checkType(Subtract(expr, u), DoubleType)
+ checkType(Multiply(expr, u), DoubleType)
+ checkType(Divide(expr, u), DoubleType)
+ }
+ }
+}
diff --git a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index 37e4c68f7..069d697bd 100644
--- a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++ b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -639,6 +639,7 @@ class ClickHouseTestSettings extends BackendTestSettings {
.excludeGlutenTest("to_unix_timestamp")
.excludeGlutenTest("Hour")
enableSuite[GlutenDecimalExpressionSuite].exclude("MakeDecimal")
+ enableSuite[GlutenDecimalPrecisionSuite]
enableSuite[GlutenHashExpressionsSuite]
.exclude("sha2")
.exclude("murmur3/xxHash64/hive hash:
struct<null:void,boolean:boolean,byte:tinyint,short:smallint,int:int,long:bigint,float:float,double:double,bigDecimal:decimal(38,18),smallDecimal:decimal(10,0),string:string,binary:binary,date:date,timestamp:timestamp,udt:examplepoint>")
diff --git a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
index e6e42acb3..40ecc3c35 100644
--- a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
+++ b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
@@ -19,7 +19,7 @@ package org.apache.gluten.utils.velox
import org.apache.gluten.utils.{BackendTestSettings, SQLQueryTestSettings}
import org.apache.spark.sql._
-import
org.apache.spark.sql.catalyst.expressions.{GlutenArithmeticExpressionSuite,
GlutenBitwiseExpressionsSuite, GlutenCastSuite,
GlutenCollectionExpressionsSuite, GlutenComplexTypeSuite,
GlutenConditionalExpressionSuite, GlutenDateExpressionsSuite,
GlutenDecimalExpressionSuite, GlutenHashExpressionsSuite,
GlutenHigherOrderFunctionsSuite, GlutenIntervalExpressionsSuite,
GlutenLiteralExpressionSuite, GlutenMathExpressionsSuite,
GlutenMiscExpressionsSuite, GlutenNondeterministicSuite, Glu [...]
+import
org.apache.spark.sql.catalyst.expressions.{GlutenArithmeticExpressionSuite,
GlutenBitwiseExpressionsSuite, GlutenCastSuite,
GlutenCollectionExpressionsSuite, GlutenComplexTypeSuite,
GlutenConditionalExpressionSuite, GlutenDateExpressionsSuite,
GlutenDecimalExpressionSuite, GlutenDecimalPrecisionSuite,
GlutenHashExpressionsSuite, GlutenHigherOrderFunctionsSuite,
GlutenIntervalExpressionsSuite, GlutenLiteralExpressionSuite,
GlutenMathExpressionsSuite, GlutenMiscExpressionsSuite, Glu [...]
import org.apache.spark.sql.connector._
import org.apache.spark.sql.errors.{GlutenQueryCompilationErrorsDSv2Suite,
GlutenQueryCompilationErrorsSuite, GlutenQueryExecutionErrorsSuite,
GlutenQueryParsingErrorsSuite}
import org.apache.spark.sql.execution._
@@ -122,6 +122,7 @@ class VeloxTestSettings extends BackendTestSettings {
// Replaced by a gluten test to pass timezone through config.
.exclude("from_unixtime")
enableSuite[GlutenDecimalExpressionSuite]
+ enableSuite[GlutenDecimalPrecisionSuite]
enableSuite[GlutenHashExpressionsSuite]
enableSuite[GlutenHigherOrderFunctionsSuite]
enableSuite[GlutenIntervalExpressionsSuite]
diff --git a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenDecimalPrecisionSuite.scala b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenDecimalPrecisionSuite.scala
new file mode 100644
index 000000000..97e752d7d
--- /dev/null
+++ b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenDecimalPrecisionSuite.scala
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.gluten.expression._
+
+import org.apache.spark.sql.GlutenTestsTrait
+import org.apache.spark.sql.catalyst.analysis.{Analyzer,
EmptyFunctionRegistry, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
+import org.apache.spark.sql.types._
+
/**
 * Checks decimal result-type derivation for arithmetic and comparison expressions.
 *
 * Each check lets Spark's analyzer resolve an expression projected over [[relation]], then
 * converts the analyzed expression with Gluten's [[ExpressionConverter]]. It asserts that the
 * resulting transformer reports the expected data type.
 */
class GlutenDecimalPrecisionSuite extends GlutenTestsTrait {
  private val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry)
  private val analyzer = new Analyzer(catalog)

  /** Input relation; the unresolved attributes below refer to its columns by name. */
  private val relation = LocalRelation(
    AttributeReference("i", IntegerType)(),
    AttributeReference("d1", DecimalType(2, 1))(),
    AttributeReference("d2", DecimalType(5, 2))(),
    AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(),
    AttributeReference("f", FloatType)(),
    AttributeReference("b", DoubleType)()
  )

  private val i: Expression = UnresolvedAttribute("i")
  private val d1: Expression = UnresolvedAttribute("d1")
  private val d2: Expression = UnresolvedAttribute("d2")
  private val u: Expression = UnresolvedAttribute("u")
  private val f: Expression = UnresolvedAttribute("f")
  private val b: Expression = UnresolvedAttribute("b")

  /**
   * Analyzes `expression` as the single projected column `c` over [[relation]].
   *
   * @return the analyzed plan; the test fails if the analyzer did not produce a [[Project]]
   */
  private def analyzeProjection(expression: Expression): Project =
    analyzer.execute(Project(Seq(Alias(expression, "c")()), relation)) match {
      case project: Project => project
      case other => fail(s"Expected the analyzed plan to be a Project, but got:\n$other")
    }

  /**
   * Asserts that `expression` resolves to `expectedType`, and that the Gluten transformer built
   * from the analyzed projection reports the same type.
   */
  private def checkType(expression: Expression, expectedType: DataType): Unit = {
    val project = analyzeProjection(expression)
    // The projected column is still wrapped in its Alias; Alias reports its child's type.
    val projected = project.projectList.head
    assert(projected.dataType == expectedType)
    val transformed =
      ExpressionConverter.replaceWithExpressionTransformer(projected, project.inputSet.toSeq)
    assert(transformed.dataType == expectedType)
  }

  /** Removes any (possibly nested) [[Alias]] wrappers around `expr`. */
  @scala.annotation.tailrec
  private def stripAlias(expr: Expression): Expression = expr match {
    case a: Alias => stripAlias(a.child)
    case _ => expr
  }

  /**
   * Asserts that the comparison `expression` is converted into a
   * [[GenericExpressionTransformer]] that wraps a [[BinaryComparison]] with exactly two
   * operands, both of type `expectedType`. This checks the type of each operand after
   * analysis, not the comparison's result type.
   */
  private def checkComparison(expression: Expression, expectedType: DataType): Unit = {
    val project = analyzeProjection(expression)
    val comparison = stripAlias(project.projectList.head)
    val transformed =
      ExpressionConverter.replaceWithExpressionTransformer(comparison, project.inputSet.toSeq)
    transformed match {
      case transformer: GenericExpressionTransformer =>
        assert(
          transformer.original.isInstanceOf[BinaryComparison],
          s"wrapped expression was ${transformer.original}")
        assert(transformer.children.size == 2)
        val operandTypes = transformer.children.map(_.dataType)
        assert(operandTypes.forall(_ == expectedType), s"operand types were $operandTypes")
      case other =>
        fail(s"Expected a GenericExpressionTransformer, but got ${other.getClass.getName}")
    }
  }

  test("basic operations") {
    checkType(Add(d1, d2), DecimalType(6, 2))
    checkType(Subtract(d1, d2), DecimalType(6, 2))
    checkType(Multiply(d1, d2), DecimalType(8, 3))
    checkType(Divide(d1, d2), DecimalType(10, 7))
    checkType(Divide(d2, d1), DecimalType(10, 6))

    checkType(Add(Add(d1, d2), d1), DecimalType(7, 2))
    checkType(Add(Add(d1, d1), d1), DecimalType(4, 1))
    checkType(Add(d1, Add(d1, d1)), DecimalType(4, 1))
    checkType(Add(Add(Add(d1, d2), d1), d2), DecimalType(8, 2))
    checkType(Add(Add(d1, d2), Add(d1, d2)), DecimalType(7, 2))
    checkType(Subtract(Subtract(d2, d1), d1), DecimalType(7, 2))
    checkType(Multiply(Multiply(d1, d1), d2), DecimalType(11, 4))
    checkType(Divide(d2, Add(d1, d1)), DecimalType(10, 6))
  }

  test("Comparison operations") {
    checkComparison(EqualTo(i, d1), DecimalType(11, 1))
    checkComparison(EqualNullSafe(d2, d1), DecimalType(5, 2))
    checkComparison(LessThan(i, d1), DecimalType(11, 1))
    checkComparison(LessThanOrEqual(d1, d2), DecimalType(5, 2))
    checkComparison(GreaterThan(d2, u), DecimalType.SYSTEM_DEFAULT)
    checkComparison(GreaterThanOrEqual(d1, f), DoubleType)
    checkComparison(GreaterThan(d2, d2), DecimalType(5, 2))
  }

  test("bringing in primitive types") {
    checkType(Add(d1, i), DecimalType(12, 1))
    checkType(Add(d1, f), DoubleType)
    checkType(Add(i, d1), DecimalType(12, 1))
    checkType(Add(f, d1), DoubleType)
    checkType(Add(d1, Cast(i, LongType)), DecimalType(22, 1))
    checkType(Add(d1, Cast(i, ShortType)), DecimalType(7, 1))
    checkType(Add(d1, Cast(i, ByteType)), DecimalType(5, 1))
    checkType(Add(d1, Cast(i, DoubleType)), DoubleType)
  }

  test("maximum decimals") {
    for (expr <- Seq(d1, d2, i, u)) {
      checkType(Add(expr, u), DecimalType(38, 17))
      checkType(Subtract(expr, u), DecimalType(38, 17))
    }

    checkType(Multiply(d1, u), DecimalType(38, 16))
    checkType(Multiply(d2, u), DecimalType(38, 14))
    checkType(Multiply(i, u), DecimalType(38, 7))
    checkType(Multiply(u, u), DecimalType(38, 6))

    checkType(Divide(u, d1), DecimalType(38, 17))
    checkType(Divide(u, d2), DecimalType(38, 16))
    checkType(Divide(u, i), DecimalType(38, 18))
    checkType(Divide(u, u), DecimalType(38, 6))

    for (expr <- Seq(f, b)) {
      checkType(Add(expr, u), DoubleType)
      checkType(Subtract(expr, u), DoubleType)
      checkType(Multiply(expr, u), DoubleType)
      checkType(Divide(expr, u), DoubleType)
    }
  }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]