This is an automated email from the ASF dual-hosted git repository.
yuanzhou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new e80785625 [GLUTEN-5620][CORE] Remove check_overflow and refactor code
(#5654)
e80785625 is described below
commit e80785625c345271ecb279540468d74bdcafd394
Author: Jin Chengcheng <[email protected]>
AuthorDate: Tue May 14 17:54:57 2024 +0800
[GLUTEN-5620][CORE] Remove check_overflow and refactor code (#5654)
---
.../backendsapi/clickhouse/CHTransformerApi.scala | 1 +
.../gluten/backendsapi/velox/VeloxBackend.scala | 4 +-
.../backendsapi/velox/VeloxTransformerApi.scala | 9 +-
.../gluten/backendsapi/BackendSettingsApi.scala | 4 +-
.../apache/gluten/backendsapi/TransformerApi.scala | 3 +-
.../gluten/expression/ExpressionConverter.scala | 71 +++++-----
.../expression/UnaryExpressionTransformer.scala | 2 +
.../gluten/utils/DecimalArithmeticUtil.scala | 143 +++++----------------
.../apache/spark/sql/utils/DecimalTypeUtil.scala | 26 ++++
9 files changed, 104 insertions(+), 159 deletions(-)
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHTransformerApi.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHTransformerApi.scala
index df1ca9c68..ee46d685c 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHTransformerApi.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHTransformerApi.scala
@@ -219,6 +219,7 @@ class CHTransformerApi extends TransformerApi with Logging {
args: java.lang.Object,
substraitExprName: String,
childNode: ExpressionNode,
+ childResultType: DataType,
dataType: DecimalType,
nullable: Boolean,
nullOnOverflow: Boolean): ExpressionNode = {
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
index 5509d37e8..c16b3624f 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
@@ -449,14 +449,12 @@ object VeloxBackendSettings extends BackendSettingsApi {
override def fallbackAggregateWithEmptyOutputChild(): Boolean = true
override def recreateJoinExecOnFallback(): Boolean = true
- override def rescaleDecimalLiteral(): Boolean = true
+ override def rescaleDecimalArithmetic(): Boolean = true
/** Get the config prefix for each backend */
override def getBackendConfigPrefix(): String =
GlutenConfig.GLUTEN_CONFIG_PREFIX + VeloxBackend.BACKEND_NAME
- override def rescaleDecimalIntegralExpression(): Boolean = true
-
override def shuffleSupportedCodec(): Set[String] = SHUFFLE_SUPPORTED_CODEC
override def resolveNativeConf(nativeConf: java.util.Map[String, String]):
Unit = {}
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxTransformerApi.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxTransformerApi.scala
index e5aa281a8..33f612440 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxTransformerApi.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxTransformerApi.scala
@@ -80,11 +80,16 @@ class VeloxTransformerApi extends TransformerApi with
Logging {
args: java.lang.Object,
substraitExprName: String,
childNode: ExpressionNode,
+ childResultType: DataType,
dataType: DecimalType,
nullable: Boolean,
nullOnOverflow: Boolean): ExpressionNode = {
- val typeNode = ConverterUtils.getTypeNode(dataType, nullable)
- ExpressionBuilder.makeCast(typeNode, childNode, !nullOnOverflow)
+ if (childResultType.equals(dataType)) {
+ childNode
+ } else {
+ val typeNode = ConverterUtils.getTypeNode(dataType, nullable)
+ ExpressionBuilder.makeCast(typeNode, childNode, !nullOnOverflow)
+ }
}
override def getNativePlanString(substraitPlan: Array[Byte], details:
Boolean): String = {
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
index c8729561d..9c5c13271 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
@@ -85,7 +85,7 @@ trait BackendSettingsApi {
def supportShuffleWithProject(outputPartitioning: Partitioning, child:
SparkPlan): Boolean = false
def utilizeShuffledHashJoinHint(): Boolean = false
def excludeScanExecFromCollapsedStage(): Boolean = false
- def rescaleDecimalLiteral: Boolean = false
+ def rescaleDecimalArithmetic: Boolean = false
/**
* Whether to replace sort agg with hash agg., e.g., sort agg will be used
in spark's planning for
@@ -106,8 +106,6 @@ trait BackendSettingsApi {
*/
def transformCheckOverflow: Boolean = true
- def rescaleDecimalIntegralExpression(): Boolean = false
-
def shuffleSupportedCodec(): Set[String]
def needOutputSchemaForPlan(): Boolean = false
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/TransformerApi.scala
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/TransformerApi.scala
index 49a97a8a4..7a10dc68c 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/TransformerApi.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/TransformerApi.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute,
Expression}
import org.apache.spark.sql.connector.read.InputPartition
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation,
PartitionDirectory}
-import org.apache.spark.sql.types.DecimalType
+import org.apache.spark.sql.types.{DataType, DecimalType}
import org.apache.spark.util.collection.BitSet
import com.google.protobuf.{Any, Message}
@@ -69,6 +69,7 @@ trait TransformerApi {
args: java.lang.Object,
substraitExprName: String,
childNode: ExpressionNode,
+ childResultType: DataType,
dataType: DecimalType,
nullable: Boolean,
nullOnOverflow: Boolean): ExpressionNode
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
index 495fbf8d5..b7b946268 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
@@ -101,6 +101,28 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
}
}
+ private def genRescaleDecimalTransformer(
+ substraitName: String,
+ b: BinaryArithmetic,
+ attributeSeq: Seq[Attribute],
+ expressionsMap: Map[Class[_], String]):
DecimalArithmeticExpressionTransformer = {
+ val rescaleBinary = DecimalArithmeticUtil.rescaleLiteral(b)
+ val (left, right) = DecimalArithmeticUtil.rescaleCastForDecimal(
+ DecimalArithmeticUtil.removeCastForDecimal(rescaleBinary.left),
+ DecimalArithmeticUtil.removeCastForDecimal(rescaleBinary.right))
+ val resultType = DecimalArithmeticUtil.getResultType(
+ b,
+ left.dataType.asInstanceOf[DecimalType],
+ right.dataType.asInstanceOf[DecimalType]
+ )
+
+ val leftChild =
+ replaceWithExpressionTransformerInternal(left, attributeSeq,
expressionsMap)
+ val rightChild =
+ replaceWithExpressionTransformerInternal(right, attributeSeq,
expressionsMap)
+ DecimalArithmeticExpressionTransformer(substraitName, leftChild,
rightChild, resultType, b)
+ }
+
private def replaceWithExpressionTransformerInternal(
expr: Expression,
attributeSeq: Seq[Attribute],
@@ -492,7 +514,6 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
expr.children.map(
replaceWithExpressionTransformerInternal(_, attributeSeq,
expressionsMap)),
expr)
-
case CheckOverflow(b: BinaryArithmetic, decimalType, _)
if !BackendsApiManager.getSettings.transformCheckOverflow &&
DecimalArithmeticUtil.isDecimalArithmetic(b) =>
@@ -507,55 +528,25 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
rightChild,
decimalType,
b)
-
case c: CheckOverflow =>
CheckOverflowTransformer(
substraitExprName,
replaceWithExpressionTransformerInternal(c.child, attributeSeq,
expressionsMap),
+ c.child.dataType,
c)
-
case b: BinaryArithmetic if DecimalArithmeticUtil.isDecimalArithmetic(b)
=>
DecimalArithmeticUtil.checkAllowDecimalArithmetic()
if (!BackendsApiManager.getSettings.transformCheckOverflow) {
- val leftChild =
- replaceWithExpressionTransformerInternal(b.left, attributeSeq,
expressionsMap)
- val rightChild =
- replaceWithExpressionTransformerInternal(b.right, attributeSeq,
expressionsMap)
- DecimalArithmeticExpressionTransformer(
+ GenericExpressionTransformer(
substraitExprName,
- leftChild,
- rightChild,
- b.dataType.asInstanceOf[DecimalType],
- b)
- } else {
- val rescaleBinary = if
(BackendsApiManager.getSettings.rescaleDecimalLiteral) {
- DecimalArithmeticUtil.rescaleLiteral(b)
- } else {
- b
- }
- val (left, right) = DecimalArithmeticUtil.rescaleCastForDecimal(
- DecimalArithmeticUtil.removeCastForDecimal(rescaleBinary.left),
- DecimalArithmeticUtil.removeCastForDecimal(rescaleBinary.right))
- val leftChild =
- replaceWithExpressionTransformerInternal(left, attributeSeq,
expressionsMap)
- val rightChild =
- replaceWithExpressionTransformerInternal(right, attributeSeq,
expressionsMap)
-
- val resultType = DecimalArithmeticUtil.getResultTypeForOperation(
- DecimalArithmeticUtil.getOperationType(b),
- DecimalArithmeticUtil
- .getResultType(leftChild)
- .getOrElse(left.dataType.asInstanceOf[DecimalType]),
- DecimalArithmeticUtil
- .getResultType(rightChild)
- .getOrElse(right.dataType.asInstanceOf[DecimalType])
+ expr.children.map(
+ replaceWithExpressionTransformerInternal(_, attributeSeq,
expressionsMap)),
+ expr
)
- DecimalArithmeticExpressionTransformer(
- substraitExprName,
- leftChild,
- rightChild,
- resultType,
- b)
+ } else {
+        // Without the rescale and cast removal, the result is still correct on
+        // newer Spark versions, but it causes a performance regression in Velox.
+ genRescaleDecimalTransformer(substraitExprName, b, attributeSeq,
expressionsMap)
}
case n: NaNvl =>
BackendsApiManager.getSparkPlanExecApiInstance.genNaNvlTransformer(
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/UnaryExpressionTransformer.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/UnaryExpressionTransformer.scala
index 88df12b84..2d3840ce4 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/UnaryExpressionTransformer.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/expression/UnaryExpressionTransformer.scala
@@ -152,6 +152,7 @@ case class PosExplodeTransformer(
case class CheckOverflowTransformer(
substraitExprName: String,
child: ExpressionTransformer,
+ childResultType: DataType,
original: CheckOverflow)
extends ExpressionTransformer {
@@ -160,6 +161,7 @@ case class CheckOverflowTransformer(
args,
substraitExprName,
child.doTransform(args),
+ childResultType,
original.dataType,
original.nullable,
original.nullOnOverflow)
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/utils/DecimalArithmeticUtil.scala
b/gluten-core/src/main/scala/org/apache/gluten/utils/DecimalArithmeticUtil.scala
index 148cc4e60..479eb8bb5 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/utils/DecimalArithmeticUtil.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/utils/DecimalArithmeticUtil.scala
@@ -18,69 +18,40 @@ package org.apache.gluten.utils
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.exception.GlutenNotSupportException
-import org.apache.gluten.expression.{CheckOverflowTransformer,
ChildTransformer, DecimalArithmeticExpressionTransformer, ExpressionTransformer}
import org.apache.gluten.expression.ExpressionConverter.conf
import org.apache.spark.sql.catalyst.analysis.DecimalPrecision
import org.apache.spark.sql.catalyst.expressions.{Add, BinaryArithmetic, Cast,
Divide, Expression, Literal, Multiply, Pmod, PromotePrecision, Remainder,
Subtract}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ByteType, Decimal, DecimalType,
IntegerType, LongType, ShortType}
-
-import scala.annotation.tailrec
+import org.apache.spark.sql.utils.DecimalTypeUtil
object DecimalArithmeticUtil {
- object OperationType extends Enumeration {
- type Config = Value
- val ADD, SUBTRACT, MULTIPLY, DIVIDE, MOD = Value
- }
-
- private val MIN_ADJUSTED_SCALE = 6
- val MAX_PRECISION = 38
-
// Returns the result decimal type of a decimal arithmetic computing.
- def getResultTypeForOperation(
- operationType: OperationType.Config,
- type1: DecimalType,
- type2: DecimalType): DecimalType = {
+ def getResultType(expr: BinaryArithmetic, type1: DecimalType, type2:
DecimalType): DecimalType = {
var resultScale = 0
var resultPrecision = 0
- operationType match {
- case OperationType.ADD =>
+ expr match {
+ case _: Add =>
resultScale = Math.max(type1.scale, type2.scale)
resultPrecision =
resultScale + Math.max(type1.precision - type1.scale,
type2.precision - type2.scale) + 1
- case OperationType.SUBTRACT =>
+ case _: Subtract =>
resultScale = Math.max(type1.scale, type2.scale)
resultPrecision =
resultScale + Math.max(type1.precision - type1.scale,
type2.precision - type2.scale) + 1
- case OperationType.MULTIPLY =>
+ case _: Multiply =>
resultScale = type1.scale + type2.scale
resultPrecision = type1.precision + type2.precision + 1
- case OperationType.DIVIDE =>
- resultScale = Math.max(MIN_ADJUSTED_SCALE, type1.scale +
type2.precision + 1)
+ case _: Divide =>
+ resultScale =
+ Math.max(DecimalType.MINIMUM_ADJUSTED_SCALE, type1.scale +
type2.precision + 1)
resultPrecision = type1.precision - type1.scale + type2.scale +
resultScale
- case OperationType.MOD =>
- resultScale = Math.max(type1.scale, type2.scale)
- resultPrecision =
- Math.min(type1.precision - type1.scale, type2.precision -
type2.scale + resultScale)
case other =>
throw new GlutenNotSupportException(s"$other is not supported.")
}
- adjustScaleIfNeeded(resultPrecision, resultScale)
- }
-
- // Returns the adjusted decimal type when the precision is larger the
maximum.
- private def adjustScaleIfNeeded(precision: Int, scale: Int): DecimalType = {
- var typePrecision = precision
- var typeScale = scale
- if (precision > MAX_PRECISION) {
- val minScale = Math.min(scale, MIN_ADJUSTED_SCALE)
- val delta = precision - MAX_PRECISION
- typePrecision = MAX_PRECISION
- typeScale = Math.max(scale - delta, minScale)
- }
- DecimalType(typePrecision, typeScale)
+ DecimalTypeUtil.adjustPrecisionScale(resultPrecision, resultScale)
}
// If casting between DecimalType, unnecessary cast is skipped to avoid data
loss,
@@ -98,18 +69,6 @@ object DecimalArithmeticUtil {
} else false
}
- // Returns the operation type of a binary arithmetic expression.
- def getOperationType(b: BinaryArithmetic): OperationType.Config = {
- b match {
- case _: Add => OperationType.ADD
- case _: Subtract => OperationType.SUBTRACT
- case _: Multiply => OperationType.MULTIPLY
- case _: Divide => OperationType.DIVIDE
- case other =>
- throw new GlutenNotSupportException(s"$other is not supported.")
- }
- }
-
// For decimal * 10 case, dec will be Decimal(38, 18), then the result
precision is wrong,
// so here we will get the real precision and scale of the literal.
private def getNewPrecisionScale(dec: Decimal): (Integer, Integer) = {
@@ -179,9 +138,7 @@ object DecimalArithmeticUtil {
if (isWiderType) (e1, newE2) else (e1, e2)
}
- if (!BackendsApiManager.getSettings.rescaleDecimalIntegralExpression()) {
- (left, right)
- } else if (!isPromoteCast(left) && isPromoteCastIntegral(right)) {
+ if (!isPromoteCast(left) && isPromoteCastIntegral(right)) {
// Have removed PromotePrecision(Cast(DecimalType)).
// Decimal * cast int.
doScale(left, right)
@@ -202,66 +159,32 @@ object DecimalArithmeticUtil {
* @return
*   the expression with its child PromotePrecision->Cast removed
*/
- def removeCastForDecimal(arithmeticExpr: Expression): Expression = {
- arithmeticExpr match {
- case precision: PromotePrecision =>
- precision.child match {
- case cast: Cast
- if cast.dataType.isInstanceOf[DecimalType]
- && cast.child.dataType.isInstanceOf[DecimalType] =>
- cast.child
- case _ => arithmeticExpr
- }
- case _ => arithmeticExpr
- }
+ def removeCastForDecimal(arithmeticExpr: Expression): Expression =
arithmeticExpr match {
+ case PromotePrecision(_ @Cast(child, _: DecimalType, _, _))
+ if child.dataType.isInstanceOf[DecimalType] =>
+ child
+ case _ => arithmeticExpr
}
- @tailrec
- def getResultType(transformer: ExpressionTransformer): Option[DecimalType] =
{
- transformer match {
- case ChildTransformer(child) =>
- getResultType(child)
- case CheckOverflowTransformer(_, _, original) =>
- Some(original.dataType)
- case DecimalArithmeticExpressionTransformer(_, _, _, resultType, _) =>
- Some(resultType)
- case _ => None
- }
- }
-
- private def isPromoteCastIntegral(expr: Expression): Boolean = {
- expr match {
- case precision: PromotePrecision =>
- precision.child match {
- case cast: Cast if cast.dataType.isInstanceOf[DecimalType] =>
- cast.child.dataType match {
- case IntegerType | ByteType | ShortType | LongType => true
- case _ => false
- }
- case _ => false
- }
- case _ => false
- }
+ private def isPromoteCastIntegral(expr: Expression): Boolean = expr match {
+ case PromotePrecision(_ @Cast(child, _: DecimalType, _, _)) =>
+ child.dataType match {
+ case IntegerType | ByteType | ShortType | LongType => true
+ case _ => false
+ }
+ case _ => false
}
- private def rescaleCastForOneSide(expr: Expression): Expression = {
- expr match {
- case precision: PromotePrecision =>
- precision.child match {
- case castInt: Cast
- if castInt.dataType.isInstanceOf[DecimalType] &&
-
BackendsApiManager.getSettings.rescaleDecimalIntegralExpression() =>
- castInt.child.dataType match {
- case IntegerType | ByteType | ShortType =>
- precision.withNewChildren(Seq(Cast(castInt.child,
DecimalType(10, 0))))
- case LongType =>
- precision.withNewChildren(Seq(Cast(castInt.child,
DecimalType(20, 0))))
- case _ => expr
- }
- case _ => expr
- }
- case _ => expr
- }
+ private def rescaleCastForOneSide(expr: Expression): Expression = expr match
{
+ case precision @ PromotePrecision(_ @Cast(child, _: DecimalType, _, _)) =>
+ child.dataType match {
+ case IntegerType | ByteType | ShortType =>
+ precision.withNewChildren(Seq(Cast(child, DecimalType(10, 0))))
+ case LongType =>
+ precision.withNewChildren(Seq(Cast(child, DecimalType(20, 0))))
+ case _ => expr
+ }
+ case _ => expr
}
private def checkIsWiderType(
diff --git
a/gluten-core/src/main/scala/org/apache/spark/sql/utils/DecimalTypeUtil.scala
b/gluten-core/src/main/scala/org/apache/spark/sql/utils/DecimalTypeUtil.scala
new file mode 100644
index 000000000..f7334bcb2
--- /dev/null
+++
b/gluten-core/src/main/scala/org/apache/spark/sql/utils/DecimalTypeUtil.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.utils
+
+import org.apache.spark.sql.types.DecimalType
+
+object DecimalTypeUtil {
+ def adjustPrecisionScale(precision: Int, scale: Int): DecimalType = {
+ DecimalType.adjustPrecisionScale(precision, scale)
+ }
+
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]