baibaichen commented on code in PR #5621:
URL: https://github.com/apache/incubator-gluten/pull/5621#discussion_r1593308395


##########
gluten-core/src/main/scala/org/apache/gluten/utils/DecimalArithmeticUtil.scala:
##########
@@ -159,56 +163,30 @@ object DecimalArithmeticUtil {
   }
 
   // Returns whether the input expression is a combination of PromotePrecision(Cast as DecimalType).
-  private def isPromoteCast(expr: Expression): Boolean = {
-    expr match {
-      case precision: PromotePrecision =>
-        precision.child match {
-          case cast: Cast if cast.dataType.isInstanceOf[DecimalType] => true
-          case _ => false
-        }
-      case _ => false
-    }
+  private def isPromoteCast(expr: Expression): Boolean = expr match {
+    case PromotePrecision(Cast(_, _: DecimalType, _, _)) => true
+    case _ => false
   }
 
   def rescaleCastForDecimal(left: Expression, right: Expression): (Expression, Expression) = {
-    if (!BackendsApiManager.getSettings.rescaleDecimalIntegralExpression()) {
-      return (left, right)
+
+    def doScale(e1: Expression, e2: Expression): (Expression, Expression) = {
+      val newE2 = rescaleCastForOneSide(e2)
+      val isWiderType = checkIsWiderType(
+        e1.dataType.asInstanceOf[DecimalType],
+        newE2.dataType.asInstanceOf[DecimalType],
+        e2.dataType.asInstanceOf[DecimalType])
+      if (isWiderType) (e1, newE2) else (e1, e2)
     }
-    // Decimal * cast int.
-    if (!isPromoteCast(left)) {
-      // Have removed PromotePrecision(Cast(DecimalType)).
-      if (isPromoteCastIntegral(right)) {
-        val newRight = rescaleCastForOneSide(right)
-        val isWiderType = checkIsWiderType(
-          left.dataType.asInstanceOf[DecimalType],
-          newRight.dataType.asInstanceOf[DecimalType],
-          right.dataType.asInstanceOf[DecimalType])
-        if (isWiderType) {
-          (left, newRight)
-        } else {
-          (left, right)
-        }
-      } else {
-        (left, right)
-      }
-      // Cast int * decimal.
-    } else if (!isPromoteCast(right)) {
-      if (isPromoteCastIntegral(left)) {
-        val newLeft = rescaleCastForOneSide(left)
-        val isWiderType = checkIsWiderType(
-          newLeft.dataType.asInstanceOf[DecimalType],
-          right.dataType.asInstanceOf[DecimalType],
-          left.dataType.asInstanceOf[DecimalType])
-        if (isWiderType) {
-          (newLeft, right)
-        } else {
-          (left, right)
-        }
-      } else {
-        (left, right)
-      }
+
+    if (!BackendsApiManager.getSettings.rescaleDecimalIntegralExpression()) {
+      (left, right)
+    } else if (!isPromoteCast(left) && isPromoteCastIntegral(right)) {

Review Comment:
   done.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to