This is an automated email from the ASF dual-hosted git repository.

yuanzhou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new e80785625 [GLUTEN-5620][CORE] Remove check_overflow and refactor code (#5654)
e80785625 is described below

commit e80785625c345271ecb279540468d74bdcafd394
Author: Jin Chengcheng <[email protected]>
AuthorDate: Tue May 14 17:54:57 2024 +0800

    [GLUTEN-5620][CORE] Remove check_overflow and refactor code (#5654)
---
 .../backendsapi/clickhouse/CHTransformerApi.scala  |   1 +
 .../gluten/backendsapi/velox/VeloxBackend.scala    |   4 +-
 .../backendsapi/velox/VeloxTransformerApi.scala    |   9 +-
 .../gluten/backendsapi/BackendSettingsApi.scala    |   4 +-
 .../apache/gluten/backendsapi/TransformerApi.scala |   3 +-
 .../gluten/expression/ExpressionConverter.scala    |  71 +++++-----
 .../expression/UnaryExpressionTransformer.scala    |   2 +
 .../gluten/utils/DecimalArithmeticUtil.scala       | 143 +++++----------------
 .../apache/spark/sql/utils/DecimalTypeUtil.scala   |  26 ++++
 9 files changed, 104 insertions(+), 159 deletions(-)

diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHTransformerApi.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHTransformerApi.scala
index df1ca9c68..ee46d685c 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHTransformerApi.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHTransformerApi.scala
@@ -219,6 +219,7 @@ class CHTransformerApi extends TransformerApi with Logging {
       args: java.lang.Object,
       substraitExprName: String,
       childNode: ExpressionNode,
+      childResultType: DataType,
       dataType: DecimalType,
       nullable: Boolean,
       nullOnOverflow: Boolean): ExpressionNode = {
diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
index 5509d37e8..c16b3624f 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
@@ -449,14 +449,12 @@ object VeloxBackendSettings extends BackendSettingsApi {
   override def fallbackAggregateWithEmptyOutputChild(): Boolean = true
 
   override def recreateJoinExecOnFallback(): Boolean = true
-  override def rescaleDecimalLiteral(): Boolean = true
+  override def rescaleDecimalArithmetic(): Boolean = true
 
   /** Get the config prefix for each backend */
   override def getBackendConfigPrefix(): String =
     GlutenConfig.GLUTEN_CONFIG_PREFIX + VeloxBackend.BACKEND_NAME
 
-  override def rescaleDecimalIntegralExpression(): Boolean = true
-
   override def shuffleSupportedCodec(): Set[String] = SHUFFLE_SUPPORTED_CODEC
 
   override def resolveNativeConf(nativeConf: java.util.Map[String, String]): 
Unit = {}
diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxTransformerApi.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxTransformerApi.scala
index e5aa281a8..33f612440 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxTransformerApi.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxTransformerApi.scala
@@ -80,11 +80,16 @@ class VeloxTransformerApi extends TransformerApi with 
Logging {
       args: java.lang.Object,
       substraitExprName: String,
       childNode: ExpressionNode,
+      childResultType: DataType,
       dataType: DecimalType,
       nullable: Boolean,
       nullOnOverflow: Boolean): ExpressionNode = {
-    val typeNode = ConverterUtils.getTypeNode(dataType, nullable)
-    ExpressionBuilder.makeCast(typeNode, childNode, !nullOnOverflow)
+    if (childResultType.equals(dataType)) {
+      childNode
+    } else {
+      val typeNode = ConverterUtils.getTypeNode(dataType, nullable)
+      ExpressionBuilder.makeCast(typeNode, childNode, !nullOnOverflow)
+    }
   }
 
   override def getNativePlanString(substraitPlan: Array[Byte], details: 
Boolean): String = {
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
index c8729561d..9c5c13271 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
@@ -85,7 +85,7 @@ trait BackendSettingsApi {
   def supportShuffleWithProject(outputPartitioning: Partitioning, child: 
SparkPlan): Boolean = false
   def utilizeShuffledHashJoinHint(): Boolean = false
   def excludeScanExecFromCollapsedStage(): Boolean = false
-  def rescaleDecimalLiteral: Boolean = false
+  def rescaleDecimalArithmetic: Boolean = false
 
   /**
    * Whether to replace sort agg with hash agg., e.g., sort agg will be used 
in spark's planning for
@@ -106,8 +106,6 @@ trait BackendSettingsApi {
    */
   def transformCheckOverflow: Boolean = true
 
-  def rescaleDecimalIntegralExpression(): Boolean = false
-
   def shuffleSupportedCodec(): Set[String]
 
   def needOutputSchemaForPlan(): Boolean = false
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/TransformerApi.scala 
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/TransformerApi.scala
index 49a97a8a4..7a10dc68c 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/TransformerApi.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/TransformerApi.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, 
Expression}
 import org.apache.spark.sql.connector.read.InputPartition
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, 
PartitionDirectory}
-import org.apache.spark.sql.types.DecimalType
+import org.apache.spark.sql.types.{DataType, DecimalType}
 import org.apache.spark.util.collection.BitSet
 
 import com.google.protobuf.{Any, Message}
@@ -69,6 +69,7 @@ trait TransformerApi {
       args: java.lang.Object,
       substraitExprName: String,
       childNode: ExpressionNode,
+      childResultType: DataType,
       dataType: DecimalType,
       nullable: Boolean,
       nullOnOverflow: Boolean): ExpressionNode
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
index 495fbf8d5..b7b946268 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
@@ -101,6 +101,28 @@ object ExpressionConverter extends SQLConfHelper with 
Logging {
     }
   }
 
+  private def genRescaleDecimalTransformer(
+      substraitName: String,
+      b: BinaryArithmetic,
+      attributeSeq: Seq[Attribute],
+      expressionsMap: Map[Class[_], String]): 
DecimalArithmeticExpressionTransformer = {
+    val rescaleBinary = DecimalArithmeticUtil.rescaleLiteral(b)
+    val (left, right) = DecimalArithmeticUtil.rescaleCastForDecimal(
+      DecimalArithmeticUtil.removeCastForDecimal(rescaleBinary.left),
+      DecimalArithmeticUtil.removeCastForDecimal(rescaleBinary.right))
+    val resultType = DecimalArithmeticUtil.getResultType(
+      b,
+      left.dataType.asInstanceOf[DecimalType],
+      right.dataType.asInstanceOf[DecimalType]
+    )
+
+    val leftChild =
+      replaceWithExpressionTransformerInternal(left, attributeSeq, 
expressionsMap)
+    val rightChild =
+      replaceWithExpressionTransformerInternal(right, attributeSeq, 
expressionsMap)
+    DecimalArithmeticExpressionTransformer(substraitName, leftChild, 
rightChild, resultType, b)
+  }
+
   private def replaceWithExpressionTransformerInternal(
       expr: Expression,
       attributeSeq: Seq[Attribute],
@@ -492,7 +514,6 @@ object ExpressionConverter extends SQLConfHelper with 
Logging {
           expr.children.map(
             replaceWithExpressionTransformerInternal(_, attributeSeq, 
expressionsMap)),
           expr)
-
       case CheckOverflow(b: BinaryArithmetic, decimalType, _)
           if !BackendsApiManager.getSettings.transformCheckOverflow &&
             DecimalArithmeticUtil.isDecimalArithmetic(b) =>
@@ -507,55 +528,25 @@ object ExpressionConverter extends SQLConfHelper with 
Logging {
           rightChild,
           decimalType,
           b)
-
       case c: CheckOverflow =>
         CheckOverflowTransformer(
           substraitExprName,
           replaceWithExpressionTransformerInternal(c.child, attributeSeq, 
expressionsMap),
+          c.child.dataType,
           c)
-
       case b: BinaryArithmetic if DecimalArithmeticUtil.isDecimalArithmetic(b) 
=>
         DecimalArithmeticUtil.checkAllowDecimalArithmetic()
         if (!BackendsApiManager.getSettings.transformCheckOverflow) {
-          val leftChild =
-            replaceWithExpressionTransformerInternal(b.left, attributeSeq, 
expressionsMap)
-          val rightChild =
-            replaceWithExpressionTransformerInternal(b.right, attributeSeq, 
expressionsMap)
-          DecimalArithmeticExpressionTransformer(
+          GenericExpressionTransformer(
             substraitExprName,
-            leftChild,
-            rightChild,
-            b.dataType.asInstanceOf[DecimalType],
-            b)
-        } else {
-          val rescaleBinary = if 
(BackendsApiManager.getSettings.rescaleDecimalLiteral) {
-            DecimalArithmeticUtil.rescaleLiteral(b)
-          } else {
-            b
-          }
-          val (left, right) = DecimalArithmeticUtil.rescaleCastForDecimal(
-            DecimalArithmeticUtil.removeCastForDecimal(rescaleBinary.left),
-            DecimalArithmeticUtil.removeCastForDecimal(rescaleBinary.right))
-          val leftChild =
-            replaceWithExpressionTransformerInternal(left, attributeSeq, 
expressionsMap)
-          val rightChild =
-            replaceWithExpressionTransformerInternal(right, attributeSeq, 
expressionsMap)
-
-          val resultType = DecimalArithmeticUtil.getResultTypeForOperation(
-            DecimalArithmeticUtil.getOperationType(b),
-            DecimalArithmeticUtil
-              .getResultType(leftChild)
-              .getOrElse(left.dataType.asInstanceOf[DecimalType]),
-            DecimalArithmeticUtil
-              .getResultType(rightChild)
-              .getOrElse(right.dataType.asInstanceOf[DecimalType])
+            expr.children.map(
+              replaceWithExpressionTransformerInternal(_, attributeSeq, 
expressionsMap)),
+            expr
           )
-          DecimalArithmeticExpressionTransformer(
-            substraitExprName,
-            leftChild,
-            rightChild,
-            resultType,
-            b)
+        } else {
+          // Without the rescale and remove cast, result is right for high 
version Spark,
+          // but performance regression in velox
+          genRescaleDecimalTransformer(substraitExprName, b, attributeSeq, 
expressionsMap)
         }
       case n: NaNvl =>
         BackendsApiManager.getSparkPlanExecApiInstance.genNaNvlTransformer(
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/expression/UnaryExpressionTransformer.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/expression/UnaryExpressionTransformer.scala
index 88df12b84..2d3840ce4 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/expression/UnaryExpressionTransformer.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/expression/UnaryExpressionTransformer.scala
@@ -152,6 +152,7 @@ case class PosExplodeTransformer(
 case class CheckOverflowTransformer(
     substraitExprName: String,
     child: ExpressionTransformer,
+    childResultType: DataType,
     original: CheckOverflow)
   extends ExpressionTransformer {
 
@@ -160,6 +161,7 @@ case class CheckOverflowTransformer(
       args,
       substraitExprName,
       child.doTransform(args),
+      childResultType,
       original.dataType,
       original.nullable,
       original.nullOnOverflow)
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/utils/DecimalArithmeticUtil.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/utils/DecimalArithmeticUtil.scala
index 148cc4e60..479eb8bb5 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/utils/DecimalArithmeticUtil.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/utils/DecimalArithmeticUtil.scala
@@ -18,69 +18,40 @@ package org.apache.gluten.utils
 
 import org.apache.gluten.backendsapi.BackendsApiManager
 import org.apache.gluten.exception.GlutenNotSupportException
-import org.apache.gluten.expression.{CheckOverflowTransformer, 
ChildTransformer, DecimalArithmeticExpressionTransformer, ExpressionTransformer}
 import org.apache.gluten.expression.ExpressionConverter.conf
 
 import org.apache.spark.sql.catalyst.analysis.DecimalPrecision
 import org.apache.spark.sql.catalyst.expressions.{Add, BinaryArithmetic, Cast, 
Divide, Expression, Literal, Multiply, Pmod, PromotePrecision, Remainder, 
Subtract}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{ByteType, Decimal, DecimalType, 
IntegerType, LongType, ShortType}
-
-import scala.annotation.tailrec
+import org.apache.spark.sql.utils.DecimalTypeUtil
 
 object DecimalArithmeticUtil {
 
-  object OperationType extends Enumeration {
-    type Config = Value
-    val ADD, SUBTRACT, MULTIPLY, DIVIDE, MOD = Value
-  }
-
-  private val MIN_ADJUSTED_SCALE = 6
-  val MAX_PRECISION = 38
-
   // Returns the result decimal type of a decimal arithmetic computing.
-  def getResultTypeForOperation(
-      operationType: OperationType.Config,
-      type1: DecimalType,
-      type2: DecimalType): DecimalType = {
+  def getResultType(expr: BinaryArithmetic, type1: DecimalType, type2: 
DecimalType): DecimalType = {
     var resultScale = 0
     var resultPrecision = 0
-    operationType match {
-      case OperationType.ADD =>
+    expr match {
+      case _: Add =>
         resultScale = Math.max(type1.scale, type2.scale)
         resultPrecision =
           resultScale + Math.max(type1.precision - type1.scale, 
type2.precision - type2.scale) + 1
-      case OperationType.SUBTRACT =>
+      case _: Subtract =>
         resultScale = Math.max(type1.scale, type2.scale)
         resultPrecision =
           resultScale + Math.max(type1.precision - type1.scale, 
type2.precision - type2.scale) + 1
-      case OperationType.MULTIPLY =>
+      case _: Multiply =>
         resultScale = type1.scale + type2.scale
         resultPrecision = type1.precision + type2.precision + 1
-      case OperationType.DIVIDE =>
-        resultScale = Math.max(MIN_ADJUSTED_SCALE, type1.scale + 
type2.precision + 1)
+      case _: Divide =>
+        resultScale =
+          Math.max(DecimalType.MINIMUM_ADJUSTED_SCALE, type1.scale + 
type2.precision + 1)
         resultPrecision = type1.precision - type1.scale + type2.scale + 
resultScale
-      case OperationType.MOD =>
-        resultScale = Math.max(type1.scale, type2.scale)
-        resultPrecision =
-          Math.min(type1.precision - type1.scale, type2.precision - 
type2.scale + resultScale)
       case other =>
         throw new GlutenNotSupportException(s"$other is not supported.")
     }
-    adjustScaleIfNeeded(resultPrecision, resultScale)
-  }
-
-  // Returns the adjusted decimal type when the precision is larger the 
maximum.
-  private def adjustScaleIfNeeded(precision: Int, scale: Int): DecimalType = {
-    var typePrecision = precision
-    var typeScale = scale
-    if (precision > MAX_PRECISION) {
-      val minScale = Math.min(scale, MIN_ADJUSTED_SCALE)
-      val delta = precision - MAX_PRECISION
-      typePrecision = MAX_PRECISION
-      typeScale = Math.max(scale - delta, minScale)
-    }
-    DecimalType(typePrecision, typeScale)
+    DecimalTypeUtil.adjustPrecisionScale(resultPrecision, resultScale)
   }
 
   // If casting between DecimalType, unnecessary cast is skipped to avoid data 
loss,
@@ -98,18 +69,6 @@ object DecimalArithmeticUtil {
     } else false
   }
 
-  // Returns the operation type of a binary arithmetic expression.
-  def getOperationType(b: BinaryArithmetic): OperationType.Config = {
-    b match {
-      case _: Add => OperationType.ADD
-      case _: Subtract => OperationType.SUBTRACT
-      case _: Multiply => OperationType.MULTIPLY
-      case _: Divide => OperationType.DIVIDE
-      case other =>
-        throw new GlutenNotSupportException(s"$other is not supported.")
-    }
-  }
-
   // For decimal * 10 case, dec will be Decimal(38, 18), then the result 
precision is wrong,
   // so here we will get the real precision and scale of the literal.
   private def getNewPrecisionScale(dec: Decimal): (Integer, Integer) = {
@@ -179,9 +138,7 @@ object DecimalArithmeticUtil {
       if (isWiderType) (e1, newE2) else (e1, e2)
     }
 
-    if (!BackendsApiManager.getSettings.rescaleDecimalIntegralExpression()) {
-      (left, right)
-    } else if (!isPromoteCast(left) && isPromoteCastIntegral(right)) {
+    if (!isPromoteCast(left) && isPromoteCastIntegral(right)) {
       // Have removed PromotePrecision(Cast(DecimalType)).
       // Decimal * cast int.
       doScale(left, right)
@@ -202,66 +159,32 @@ object DecimalArithmeticUtil {
    * @return
    *   expression removed child PromotePrecision->Cast
    */
-  def removeCastForDecimal(arithmeticExpr: Expression): Expression = {
-    arithmeticExpr match {
-      case precision: PromotePrecision =>
-        precision.child match {
-          case cast: Cast
-              if cast.dataType.isInstanceOf[DecimalType]
-                && cast.child.dataType.isInstanceOf[DecimalType] =>
-            cast.child
-          case _ => arithmeticExpr
-        }
-      case _ => arithmeticExpr
-    }
+  def removeCastForDecimal(arithmeticExpr: Expression): Expression = 
arithmeticExpr match {
+    case PromotePrecision(_ @Cast(child, _: DecimalType, _, _))
+        if child.dataType.isInstanceOf[DecimalType] =>
+      child
+    case _ => arithmeticExpr
   }
 
-  @tailrec
-  def getResultType(transformer: ExpressionTransformer): Option[DecimalType] = 
{
-    transformer match {
-      case ChildTransformer(child) =>
-        getResultType(child)
-      case CheckOverflowTransformer(_, _, original) =>
-        Some(original.dataType)
-      case DecimalArithmeticExpressionTransformer(_, _, _, resultType, _) =>
-        Some(resultType)
-      case _ => None
-    }
-  }
-
-  private def isPromoteCastIntegral(expr: Expression): Boolean = {
-    expr match {
-      case precision: PromotePrecision =>
-        precision.child match {
-          case cast: Cast if cast.dataType.isInstanceOf[DecimalType] =>
-            cast.child.dataType match {
-              case IntegerType | ByteType | ShortType | LongType => true
-              case _ => false
-            }
-          case _ => false
-        }
-      case _ => false
-    }
+  private def isPromoteCastIntegral(expr: Expression): Boolean = expr match {
+    case PromotePrecision(_ @Cast(child, _: DecimalType, _, _)) =>
+      child.dataType match {
+        case IntegerType | ByteType | ShortType | LongType => true
+        case _ => false
+      }
+    case _ => false
   }
 
-  private def rescaleCastForOneSide(expr: Expression): Expression = {
-    expr match {
-      case precision: PromotePrecision =>
-        precision.child match {
-          case castInt: Cast
-              if castInt.dataType.isInstanceOf[DecimalType] &&
-                
BackendsApiManager.getSettings.rescaleDecimalIntegralExpression() =>
-            castInt.child.dataType match {
-              case IntegerType | ByteType | ShortType =>
-                precision.withNewChildren(Seq(Cast(castInt.child, 
DecimalType(10, 0))))
-              case LongType =>
-                precision.withNewChildren(Seq(Cast(castInt.child, 
DecimalType(20, 0))))
-              case _ => expr
-            }
-          case _ => expr
-        }
-      case _ => expr
-    }
+  private def rescaleCastForOneSide(expr: Expression): Expression = expr match 
{
+    case precision @ PromotePrecision(_ @Cast(child, _: DecimalType, _, _)) =>
+      child.dataType match {
+        case IntegerType | ByteType | ShortType =>
+          precision.withNewChildren(Seq(Cast(child, DecimalType(10, 0))))
+        case LongType =>
+          precision.withNewChildren(Seq(Cast(child, DecimalType(20, 0))))
+        case _ => expr
+      }
+    case _ => expr
   }
 
   private def checkIsWiderType(
diff --git 
a/gluten-core/src/main/scala/org/apache/spark/sql/utils/DecimalTypeUtil.scala 
b/gluten-core/src/main/scala/org/apache/spark/sql/utils/DecimalTypeUtil.scala
new file mode 100644
index 000000000..f7334bcb2
--- /dev/null
+++ 
b/gluten-core/src/main/scala/org/apache/spark/sql/utils/DecimalTypeUtil.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.utils
+
+import org.apache.spark.sql.types.DecimalType
+
+object DecimalTypeUtil {
+  def adjustPrecisionScale(precision: Int, scale: Int): DecimalType = {
+    DecimalType.adjustPrecisionScale(precision, scale)
+  }
+
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to