This is an automated email from the ASF dual-hosted git repository.

morningman pushed a commit to branch branch-1.2-lts
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 2ffffbc1065d788ae7ac2d3591c3fc0f27f7aa60
Author: Gabriel <[email protected]>
AuthorDate: Sat Mar 25 09:42:39 2023 +0800

    [Bug](DECIMALV3) Fix wrong precision for plus/minus (#18052)
    
    Result type for DECIMAL(x, y) plus/minus DECIMAL(m, n) should be 
DECIMAL(max(x - y, m - n) + max(y, n) + 1, max(y, n))
---
 be/src/vec/data_types/data_type_decimal.h          |  2 +-
 .../main/java/org/apache/doris/catalog/Type.java   | 23 ++++++++++++++++++++++
 .../org/apache/doris/analysis/ArithmeticExpr.java  |  6 +++---
 .../java/org/apache/doris/analysis/CastExpr.java   |  2 +-
 .../decimalv3/test_arithmetic_expressions.out      | 15 ++++++++++++++
 .../decimalv3/test_arithmetic_expressions.groovy   | 16 +++++++++++++++
 6 files changed, 59 insertions(+), 5 deletions(-)

diff --git a/be/src/vec/data_types/data_type_decimal.h 
b/be/src/vec/data_types/data_type_decimal.h
index 358fe79438..c644fbcfc0 100644
--- a/be/src/vec/data_types/data_type_decimal.h
+++ b/be/src/vec/data_types/data_type_decimal.h
@@ -236,7 +236,7 @@ DataTypePtr decimal_result_type(const DataTypeDecimal<T>& 
tx, const DataTypeDeci
         size_t divide_precision = tx.get_precision() + ty.get_scale();
         size_t plus_minus_precision =
                 std::max(tx.get_precision() - tx.get_scale(), 
ty.get_precision() - ty.get_scale()) +
-                scale;
+                scale + 1;
         if (is_multiply) {
             scale = tx.get_scale() + ty.get_scale();
             precision = std::min(multiply_precision, 
max_decimal_precision<Decimal128I>());
diff --git a/fe/fe-common/src/main/java/org/apache/doris/catalog/Type.java 
b/fe/fe-common/src/main/java/org/apache/doris/catalog/Type.java
index f5d02f5ae1..671da2c252 100644
--- a/fe/fe-common/src/main/java/org/apache/doris/catalog/Type.java
+++ b/fe/fe-common/src/main/java/org/apache/doris/catalog/Type.java
@@ -1713,6 +1713,10 @@ public abstract class Type {
 
     // Whether `type1` matches the exact type of `type2`.
     public static boolean matchExactType(Type type1, Type type2) {
+        return matchExactType(type1, type2, false);
+    }
+
+    public static boolean matchExactType(Type type1, Type type2, boolean 
ignorePrecision) {
         if (type1.matchesType(type2)) {
             if 
(PrimitiveType.typeWithPrecision.contains(type2.getPrimitiveType())) {
                 // For types which has precision and scale, we also need to 
check quality between precisions and scales
@@ -1720,6 +1724,10 @@ public abstract class Type {
                         == ((ScalarType) type1).decimalPrecision()) && 
(((ScalarType) type2).decimalScale()
                         == ((ScalarType) type1).decimalScale())) {
                     return true;
+                } else if (((ScalarType) type2).decimalScale() == 
((ScalarType) type1).decimalScale()
+                        && ignorePrecision) {
+                    return 
isSameDecimalTypeWithDifferentPrecision(((ScalarType) type2).decimalPrecision(),
+                            ((ScalarType) type1).decimalPrecision());
                 }
             } else if (type2.isArrayType()) {
                 // For types array, we also need to check contains null for 
case like
@@ -1733,5 +1741,20 @@ public abstract class Type {
         }
         return false;
     }
+
+    public static boolean isSameDecimalTypeWithDifferentPrecision(int 
precision1, int precision2) {
+        if (precision1 <= ScalarType.MAX_DECIMAL32_PRECISION && precision2 <= 
ScalarType.MAX_DECIMAL32_PRECISION) {
+            return true;
+        } else if (precision1 > ScalarType.MAX_DECIMAL32_PRECISION && 
precision2 > ScalarType.MAX_DECIMAL32_PRECISION
+                && precision1 <= ScalarType.MAX_DECIMAL64_PRECISION
+                && precision2 <= ScalarType.MAX_DECIMAL64_PRECISION) {
+            return true;
+        } else if (precision1 > ScalarType.MAX_DECIMAL64_PRECISION && 
precision2 > ScalarType.MAX_DECIMAL64_PRECISION
+                && precision1 <= ScalarType.MAX_DECIMAL128_PRECISION
+                && precision2 <= ScalarType.MAX_DECIMAL128_PRECISION) {
+            return true;
+        }
+        return false;
+    }
 }
 
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/analysis/ArithmeticExpr.java 
b/fe/fe-core/src/main/java/org/apache/doris/analysis/ArithmeticExpr.java
index 02ea5579cb..a6ab76df9c 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/ArithmeticExpr.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/ArithmeticExpr.java
@@ -532,7 +532,7 @@ public class ArithmeticExpr extends Expr {
                     // target type: DECIMALV3(max(widthOfIntPart1, 
widthOfIntPart2) + max(scale1, scale2) + 1,
                     // max(scale1, scale2))
                     scale = Math.max(t1Scale, t2Scale);
-                    precision = Math.max(widthOfIntPart1, widthOfIntPart2) + 
scale;
+                    precision = Math.max(widthOfIntPart1, widthOfIntPart2) + 
scale + 1;
                 } else {
                     scale = Math.max(t1Scale, t2Scale);
                     precision = widthOfIntPart2 + scale;
@@ -547,10 +547,10 @@ public class ArithmeticExpr extends Expr {
                 }
                 type = ScalarType.createDecimalV3Type(precision, scale);
                 if (op == Operator.ADD || op == Operator.SUBTRACT) {
-                    if (!Type.matchExactType(type, children.get(0).type)) {
+                    if (((ScalarType) type).getScalarScale() != ((ScalarType) 
children.get(0).type).getScalarScale()) {
                         castChild(type, 0);
                     }
-                    if (!Type.matchExactType(type, children.get(1).type)) {
+                    if (((ScalarType) type).getScalarScale() != ((ScalarType) 
children.get(1).type).getScalarScale()) {
                         castChild(type, 1);
                     }
                 } else if (op == Operator.DIVIDE && (t2Scale != 0) && 
t1.isDecimalV3()) {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java 
b/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java
index d2cb8ce442..e364de5279 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java
@@ -281,7 +281,7 @@ public class CastExpr extends Expr {
         Type childType = getChild(0).getType();
 
         // this cast may result in loss of precision, but the user requested it
-        noOp = Type.matchExactType(childType, type);
+        noOp = Type.matchExactType(childType, type, true);
 
         if (noOp) {
             // For decimalv2, we do not perform an actual cast between 
different precision/scale. Instead, we just
diff --git 
a/regression-test/data/datatype_p0/decimalv3/test_arithmetic_expressions.out 
b/regression-test/data/datatype_p0/decimalv3/test_arithmetic_expressions.out
index 4f68777f2a..085b844d7c 100644
--- a/regression-test/data/datatype_p0/decimalv3/test_arithmetic_expressions.out
+++ b/regression-test/data/datatype_p0/decimalv3/test_arithmetic_expressions.out
@@ -32,3 +32,18 @@
 2.0736
 3.2399999999999998
 
+-- !select_all --
+999999.999     999999.999      999999.999      999999.999      999999.999      
999999.999      999999.999      999999.999      999999.999      999999.999      
999999.999
+
+-- !select --
+2999999.997
+
+-- !select --
+2999999994000.000003
+
+-- !select --
+3.000
+
+-- !select --
+10999999.989
+
diff --git 
a/regression-test/suites/datatype_p0/decimalv3/test_arithmetic_expressions.groovy
 
b/regression-test/suites/datatype_p0/decimalv3/test_arithmetic_expressions.groovy
index 301d719b15..284cf482e4 100644
--- 
a/regression-test/suites/datatype_p0/decimalv3/test_arithmetic_expressions.groovy
+++ 
b/regression-test/suites/datatype_p0/decimalv3/test_arithmetic_expressions.groovy
@@ -49,4 +49,20 @@ suite("test_arithmetic_expressions") {
     qt_select "select k1 * k2 * k3 * k1 * k2 * k3 from ${table1} order by k1"
     qt_select "select k1 * k2 / k3 * k1 * k2 * k3 from ${table1} order by k1"
     sql "drop table if exists ${table1}"
+
+    sql """
+        CREATE TABLE IF NOT EXISTS ${table1} (             `a` DECIMALV3(9, 3) 
NOT NULL, `b` DECIMALV3(9, 3) NOT NULL, `c` DECIMALV3(9, 3) NOT NULL, `d` 
DECIMALV3(9, 3) NOT NULL, `e` DECIMALV3(9, 3) NOT NULL, `f` DECIMALV3(9, 3) NOT
+        NULL, `g` DECIMALV3(9, 3) NOT NULL , `h` DECIMALV3(9, 3) NOT NULL, `i` 
DECIMALV3(9, 3) NOT NULL, `j` DECIMALV3(9, 3) NOT NULL, `k` DECIMALV3(9, 3) NOT 
NULL)            DISTRIBUTED BY HASH(a) PROPERTIES("replication_num" = "1");
+    """
+
+    sql """
+    insert into ${table1} 
values(999999.999,999999.999,999999.999,999999.999,999999.999,999999.999,999999.999,999999.999,999999.999,999999.999,999999.999);
+    """
+    qt_select_all "select * from ${table1} order by a"
+
+    qt_select "select a + b + c from ${table1};"
+    qt_select "select (a + b + c) * d from ${table1};"
+    qt_select "select (a + b + c) / d from ${table1};"
+    qt_select "select a + b + c + d + e + f + g + h + i + j + k from 
${table1};"
+    sql "drop table if exists ${table1}"
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to