This is an automated email from the ASF dual-hosted git repository.

kxiao pushed a commit to branch branch-2.0
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 30dd599479866ca15449e40cc66855d3ecf11b19
Author: starocean999 <[email protected]>
AuthorDate: Thu Sep 14 15:53:23 2023 +0800

    [fix](nereids)the common type of decimalv2 and decimalv3 should be decimalv3 
in BinaryArithmetic operator (#24215)
    
    the common type of decimalv2 and decimalv3 should be decimalv3 in 
BinaryArithmetic operator
---
 .../doris/nereids/util/TypeCoercionUtils.java      | 52 +++++++++++++---------
 .../doris/nereids/util/TypeCoercionUtilsTest.java  | 35 +++++++++++++++
 2 files changed, 67 insertions(+), 20 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java
index 2efa0d36f8..dc3bbf78ae 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java
@@ -561,6 +561,11 @@ public class TypeCoercionUtils {
             commonType = DoubleType.INSTANCE;
         }
 
+        if (t1.isDecimalV3Type() && t2.isDecimalV2Type()
+                || t1.isDecimalV2Type() && t2.isDecimalV3Type()) {
+            return processDecimalV3BinaryArithmetic(binaryArithmetic, left, 
right);
+        }
+
         if (t1.isDecimalV2Type() || t2.isDecimalV2Type()) {
             // to be consitent with old planner
             // see findCommonType() method in ArithmeticExpr.java
@@ -599,26 +604,7 @@ public class TypeCoercionUtils {
 
         // double and float already process, we only process decimalv2 and 
fixed point number.
         if (t1 instanceof DecimalV3Type || t2 instanceof DecimalV3Type) {
-            DecimalV3Type dt1 = DecimalV3Type.forType(t1);
-            DecimalV3Type dt2 = DecimalV3Type.forType(t2);
-
-            // check return type whether overflow, if true, turn to double
-            DecimalV3Type retType;
-            try {
-                retType = binaryArithmetic.getDataTypeForDecimalV3(dt1, dt2);
-            } catch (Exception e) {
-                // exception means overflow.
-                return castChildren(binaryArithmetic, left, right, 
DoubleType.INSTANCE);
-            }
-
-            // add, subtract and mod should cast children to exactly same type 
as return type
-            if (binaryArithmetic instanceof Add
-                    || binaryArithmetic instanceof Subtract
-                    || binaryArithmetic instanceof Mod) {
-                return castChildren(binaryArithmetic, left, right, retType);
-            }
-            // multiply do not need to cast children to same type
-            return binaryArithmetic.withChildren(castIfNotSameType(left, dt1), 
castIfNotSameType(right, dt2));
+            return processDecimalV3BinaryArithmetic(binaryArithmetic, left, 
right);
         }
 
         // double, float and decimalv3 already process, we only process fixed 
point number
@@ -1219,4 +1205,30 @@ public class TypeCoercionUtils {
             throw new AnalysisException(t.getMessage());
         }
     }
+
+    private static Expression 
processDecimalV3BinaryArithmetic(BinaryArithmetic binaryArithmetic,
+            Expression left, Expression right) {
+        DecimalV3Type dt1 =
+                
DecimalV3Type.forType(TypeCoercionUtils.getNumResultType(left.getDataType()));
+        DecimalV3Type dt2 =
+                
DecimalV3Type.forType(TypeCoercionUtils.getNumResultType(right.getDataType()));
+
+        // check return type whether overflow, if true, turn to double
+        DecimalV3Type retType;
+        try {
+            retType = binaryArithmetic.getDataTypeForDecimalV3(dt1, dt2);
+        } catch (Exception e) {
+            // exception means overflow.
+            return castChildren(binaryArithmetic, left, right, 
DoubleType.INSTANCE);
+        }
+
+        // add, subtract and mod should cast children to exactly same type as 
return type
+        if (binaryArithmetic instanceof Add || binaryArithmetic instanceof 
Subtract
+                || binaryArithmetic instanceof Mod) {
+            return castChildren(binaryArithmetic, left, right, retType);
+        }
+        // multiply do not need to cast children to same type
+        return binaryArithmetic.withChildren(castIfNotSameType(left, dt1),
+                castIfNotSameType(right, dt2));
+    }
 }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/TypeCoercionUtilsTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/TypeCoercionUtilsTest.java
index 8a11a4bf0a..fe9d6b5363 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/TypeCoercionUtilsTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/TypeCoercionUtilsTest.java
@@ -17,7 +17,14 @@
 
 package org.apache.doris.nereids.util;
 
+import org.apache.doris.nereids.trees.expressions.Add;
 import org.apache.doris.nereids.trees.expressions.Cast;
+import org.apache.doris.nereids.trees.expressions.Divide;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Multiply;
+import org.apache.doris.nereids.trees.expressions.Subtract;
+import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
 import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
 import org.apache.doris.nereids.types.ArrayType;
 import org.apache.doris.nereids.types.BigIntType;
@@ -49,6 +56,7 @@ import org.apache.doris.nereids.types.coercion.IntegralType;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
+import java.math.BigDecimal;
 import java.util.Optional;
 
 public class TypeCoercionUtilsTest {
@@ -688,4 +696,31 @@ public class TypeCoercionUtilsTest {
         Assertions.assertEquals(new Cast(new DoubleLiteral(5L), 
BooleanType.INSTANCE),
                 TypeCoercionUtils.castIfNotMatchType(new DoubleLiteral(5L), 
BooleanType.INSTANCE));
     }
+
+    @Test
+    public void testDecimalArithmetic() {
+        Multiply multiply = new Multiply(new DecimalLiteral(new 
BigDecimal("987654.321")),
+                new DecimalV3Literal(new BigDecimal("123.45")));
+        Expression expression = 
TypeCoercionUtils.processBinaryArithmetic(multiply);
+        Assertions.assertEquals(expression.child(0),
+                new Cast(multiply.child(0), 
DecimalV3Type.createDecimalV3Type(9, 3)));
+
+        Divide divide = new Divide(new DecimalLiteral(new 
BigDecimal("987654.321")),
+                new DecimalV3Literal(new BigDecimal("123.45")));
+        expression = TypeCoercionUtils.processBinaryArithmetic(divide);
+        Assertions.assertEquals(expression.child(0),
+                new Cast(multiply.child(0), 
DecimalV3Type.createDecimalV3Type(9, 3)));
+
+        Add add = new Add(new DecimalLiteral(new BigDecimal("987654.321")),
+                new DecimalV3Literal(new BigDecimal("123.45")));
+        expression = TypeCoercionUtils.processBinaryArithmetic(add);
+        Assertions.assertEquals(expression.child(0),
+                new Cast(multiply.child(0), 
DecimalV3Type.createDecimalV3Type(9, 3)));
+
+        Subtract sub = new Subtract(new DecimalLiteral(new 
BigDecimal("987654.321")),
+                new DecimalV3Literal(new BigDecimal("123.45")));
+        expression = TypeCoercionUtils.processBinaryArithmetic(sub);
+        Assertions.assertEquals(expression.child(0),
+                new Cast(multiply.child(0), 
DecimalV3Type.createDecimalV3Type(9, 3)));
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to