This is an automated email from the ASF dual-hosted git repository. kxiao pushed a commit to branch branch-2.0 in repository https://gitbox.apache.org/repos/asf/doris.git
commit 30dd599479866ca15449e40cc66855d3ecf11b19 Author: starocean999 <[email protected]> AuthorDate: Thu Sep 14 15:53:23 2023 +0800 [fix](nereids)the common type of decimalv2 and decimalv3 should be decimalv3 in BinaryArithmetic operator (#24215) the common type of decimalv2 and decimalv3 should be decimalv3 in BinaryArithmetic operator --- .../doris/nereids/util/TypeCoercionUtils.java | 52 +++++++++++++--------- .../doris/nereids/util/TypeCoercionUtilsTest.java | 35 +++++++++++++++ 2 files changed, 67 insertions(+), 20 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java index 2efa0d36f8..dc3bbf78ae 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java @@ -561,6 +561,11 @@ public class TypeCoercionUtils { commonType = DoubleType.INSTANCE; } + if (t1.isDecimalV3Type() && t2.isDecimalV2Type() + || t1.isDecimalV2Type() && t2.isDecimalV3Type()) { + return processDecimalV3BinaryArithmetic(binaryArithmetic, left, right); + } + if (t1.isDecimalV2Type() || t2.isDecimalV2Type()) { // to be consitent with old planner // see findCommonType() method in ArithmeticExpr.java @@ -599,26 +604,7 @@ public class TypeCoercionUtils { // double and float already process, we only process decimalv2 and fixed point number. if (t1 instanceof DecimalV3Type || t2 instanceof DecimalV3Type) { - DecimalV3Type dt1 = DecimalV3Type.forType(t1); - DecimalV3Type dt2 = DecimalV3Type.forType(t2); - - // check return type whether overflow, if true, turn to double - DecimalV3Type retType; - try { - retType = binaryArithmetic.getDataTypeForDecimalV3(dt1, dt2); - } catch (Exception e) { - // exception means overflow. 
- return castChildren(binaryArithmetic, left, right, DoubleType.INSTANCE); - } - - // add, subtract and mod should cast children to exactly same type as return type - if (binaryArithmetic instanceof Add - || binaryArithmetic instanceof Subtract - || binaryArithmetic instanceof Mod) { - return castChildren(binaryArithmetic, left, right, retType); - } - // multiply do not need to cast children to same type - return binaryArithmetic.withChildren(castIfNotSameType(left, dt1), castIfNotSameType(right, dt2)); + return processDecimalV3BinaryArithmetic(binaryArithmetic, left, right); } // double, float and decimalv3 already process, we only process fixed point number @@ -1219,4 +1205,30 @@ public class TypeCoercionUtils { throw new AnalysisException(t.getMessage()); } } + + private static Expression processDecimalV3BinaryArithmetic(BinaryArithmetic binaryArithmetic, + Expression left, Expression right) { + DecimalV3Type dt1 = + DecimalV3Type.forType(TypeCoercionUtils.getNumResultType(left.getDataType())); + DecimalV3Type dt2 = + DecimalV3Type.forType(TypeCoercionUtils.getNumResultType(right.getDataType())); + + // check return type whether overflow, if true, turn to double + DecimalV3Type retType; + try { + retType = binaryArithmetic.getDataTypeForDecimalV3(dt1, dt2); + } catch (Exception e) { + // exception means overflow. 
+ return castChildren(binaryArithmetic, left, right, DoubleType.INSTANCE); + } + + // add, subtract and mod should cast children to exactly same type as return type + if (binaryArithmetic instanceof Add || binaryArithmetic instanceof Subtract + || binaryArithmetic instanceof Mod) { + return castChildren(binaryArithmetic, left, right, retType); + } + // multiply do not need to cast children to same type + return binaryArithmetic.withChildren(castIfNotSameType(left, dt1), + castIfNotSameType(right, dt2)); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/TypeCoercionUtilsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/TypeCoercionUtilsTest.java index 8a11a4bf0a..fe9d6b5363 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/TypeCoercionUtilsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/TypeCoercionUtilsTest.java @@ -17,7 +17,14 @@ package org.apache.doris.nereids.util; +import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.Cast; +import org.apache.doris.nereids.trees.expressions.Divide; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Multiply; +import org.apache.doris.nereids.trees.expressions.Subtract; +import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral; +import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal; import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral; import org.apache.doris.nereids.types.ArrayType; import org.apache.doris.nereids.types.BigIntType; @@ -49,6 +56,7 @@ import org.apache.doris.nereids.types.coercion.IntegralType; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import java.math.BigDecimal; import java.util.Optional; public class TypeCoercionUtilsTest { @@ -688,4 +696,31 @@ public class TypeCoercionUtilsTest { Assertions.assertEquals(new Cast(new 
DoubleLiteral(5L), BooleanType.INSTANCE), TypeCoercionUtils.castIfNotMatchType(new DoubleLiteral(5L), BooleanType.INSTANCE)); } + + @Test + public void testDecimalArithmetic() { + Multiply multiply = new Multiply(new DecimalLiteral(new BigDecimal("987654.321")), + new DecimalV3Literal(new BigDecimal("123.45"))); + Expression expression = TypeCoercionUtils.processBinaryArithmetic(multiply); + Assertions.assertEquals(expression.child(0), + new Cast(multiply.child(0), DecimalV3Type.createDecimalV3Type(9, 3))); + + Divide divide = new Divide(new DecimalLiteral(new BigDecimal("987654.321")), + new DecimalV3Literal(new BigDecimal("123.45"))); + expression = TypeCoercionUtils.processBinaryArithmetic(divide); + Assertions.assertEquals(expression.child(0), + new Cast(multiply.child(0), DecimalV3Type.createDecimalV3Type(9, 3))); + + Add add = new Add(new DecimalLiteral(new BigDecimal("987654.321")), + new DecimalV3Literal(new BigDecimal("123.45"))); + expression = TypeCoercionUtils.processBinaryArithmetic(add); + Assertions.assertEquals(expression.child(0), + new Cast(multiply.child(0), DecimalV3Type.createDecimalV3Type(9, 3))); + + Subtract sub = new Subtract(new DecimalLiteral(new BigDecimal("987654.321")), + new DecimalV3Literal(new BigDecimal("123.45"))); + expression = TypeCoercionUtils.processBinaryArithmetic(sub); + Assertions.assertEquals(expression.child(0), + new Cast(multiply.child(0), DecimalV3Type.createDecimalV3Type(9, 3))); + } } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
