This is an automated email from the ASF dual-hosted git repository.

morningman pushed a commit to branch dev-1.0.0
in repository https://gitbox.apache.org/repos/asf/incubator-doris.git


The following commit(s) were added to refs/heads/dev-1.0.0 by this push:
     new 4e7ea50  [fix](vectorized) fix arithmetic calculate get wrong result(#8226)
4e7ea50 is described below

commit 4e7ea50988ecba28011669540212eb1ff7ce801c
Author: Pxl <[email protected]>
AuthorDate: Wed Mar 9 13:03:57 2022 +0800

    [fix](vectorized) fix arithmetic calculate get wrong result(#8226)
---
 be/src/exprs/arithmetic_expr.cpp       | 66 ++++++++++++++++++++++++++++++++++
 be/src/exprs/arithmetic_expr.h         | 11 ++++++
 be/src/exprs/expr.cpp                  |  2 ++
 be/src/vec/functions/minus.cpp         |  1 +
 be/src/vec/functions/modulo.cpp        |  3 ++
 be/src/vec/functions/multiply.cpp      |  1 +
 be/src/vec/functions/plus.cpp          |  1 +
 fe/fe-core/src/main/cup/sql_parser.cup |  2 ++
 8 files changed, 87 insertions(+)

diff --git a/be/src/exprs/arithmetic_expr.cpp b/be/src/exprs/arithmetic_expr.cpp
index dc4399b..2366880 100644
--- a/be/src/exprs/arithmetic_expr.cpp
+++ b/be/src/exprs/arithmetic_expr.cpp
@@ -21,6 +21,10 @@
 
 namespace doris {
 
+std::set<std::string> ArithmeticExpr::_s_valid_fn_names = {
+        "add", "subtract", "multiply", "divide", "int_divide",
+        "mod", "bitand",   "bitor",    "bitxor", "bitnot"};
+
 Expr* ArithmeticExpr::from_thrift(const TExprNode& node) {
     switch (node.opcode) {
     case TExprOpcode::ADD:
@@ -48,6 +52,31 @@ Expr* ArithmeticExpr::from_thrift(const TExprNode& node) {
     return nullptr;
 }
 
+Expr* ArithmeticExpr::from_fn_name(const TExprNode& node) {
+    std::string fn_name = node.fn.name.function_name;
+    if (fn_name == "add") {
+        return new AddExpr(node);
+    } else if (fn_name == "subtract") {
+        return new SubExpr(node);
+    } else if (fn_name == "multiply") {
+        return new MulExpr(node);
+    } else if (fn_name == "divide" || fn_name == "int_divide") {
+        return new DivExpr(node);
+    } else if (fn_name == "mod") {
+        return new ModExpr(node);
+    } else if (fn_name == "bitand") {
+        return new BitAndExpr(node);
+    } else if (fn_name == "bitor") {
+        return new BitOrExpr(node);
+    } else if (fn_name == "bitxor") {
+        return new BitXorExpr(node);
+    } else if (fn_name == "bitnot") {
+        return new BitNotExpr(node);
+    }
+
+    return nullptr;
+}
+
 #define BINARY_OP_CHECK_ZERO_FN(TYPE, CLASS, FN, OP)      \
     TYPE CLASS::FN(ExprContext* context, TupleRow* row) { \
         TYPE v1 = _children[0]->FN(context, row);         \
@@ -159,4 +188,41 @@ BINARY_BIT_FNS(BitXorExpr, ^)
     BITNOT_OP_FN(LargeIntVal, get_large_int_val)
 
 BITNOT_FNS()
+
+#define DECIMAL_ARITHMETIC_OP(EXPR_NAME, OP)                                         \
+    DecimalV2Val EXPR_NAME::get_decimalv2_val(ExprContext* context, TupleRow* row) { \
+        DecimalV2Val v1 = _children[0]->get_decimalv2_val(context, row);             \
+        DecimalV2Val v2 = _children[1]->get_decimalv2_val(context, row);             \
+        if (v1.is_null || v2.is_null) {                                              \
+            return DecimalV2Val::null();                                             \
+        }                                                                            \
+        DecimalV2Value iv1 = DecimalV2Value::from_decimal_val(v1);                   \
+        DecimalV2Value iv2 = DecimalV2Value::from_decimal_val(v2);                   \
+        DecimalV2Value ir = iv1 OP iv2;                                              \
+        DecimalV2Val result;                                                         \
+        ir.to_decimal_val(&result);                                                  \
+        return result;                                                               \
+    }
+
+#define DECIMAL_ARITHMETIC_OP_DIVIDE(EXPR_NAME, OP)                                  \
+    DecimalV2Val EXPR_NAME::get_decimalv2_val(ExprContext* context, TupleRow* row) { \
+        DecimalV2Val v1 = _children[0]->get_decimalv2_val(context, row);             \
+        DecimalV2Val v2 = _children[1]->get_decimalv2_val(context, row);             \
+        if (v1.is_null || v2.is_null || v2.value() == 0) {                           \
+            return DecimalV2Val::null();                                             \
+        }                                                                            \
+        DecimalV2Value iv1 = DecimalV2Value::from_decimal_val(v1);                   \
+        DecimalV2Value iv2 = DecimalV2Value::from_decimal_val(v2);                   \
+        DecimalV2Value ir = iv1 OP iv2;                                              \
+        DecimalV2Val result;                                                         \
+        ir.to_decimal_val(&result);                                                  \
+        return result;                                                               \
+    }
+
+DECIMAL_ARITHMETIC_OP(AddExpr, +);
+DECIMAL_ARITHMETIC_OP(SubExpr, -);
+DECIMAL_ARITHMETIC_OP(MulExpr, *);
+DECIMAL_ARITHMETIC_OP_DIVIDE(DivExpr, /);
+DECIMAL_ARITHMETIC_OP_DIVIDE(ModExpr, %);
+
 } // namespace doris
diff --git a/be/src/exprs/arithmetic_expr.h b/be/src/exprs/arithmetic_expr.h
index 5fd56a9..4062847 100644
--- a/be/src/exprs/arithmetic_expr.h
+++ b/be/src/exprs/arithmetic_expr.h
@@ -18,6 +18,8 @@
 #ifndef DORIS_BE_SRC_EXPRS_ARITHMETIC_EXPR_H
 #define DORIS_BE_SRC_EXPRS_ARITHMETIC_EXPR_H
 
+#include <set>
+
 #include "common/object_pool.h"
 #include "exprs/expr.h"
 
@@ -25,7 +27,9 @@ namespace doris {
 
 class ArithmeticExpr : public Expr {
 public:
+    static bool is_valid(std::string fn_name) { return _s_valid_fn_names.count(fn_name); }
     static Expr* from_thrift(const TExprNode& node);
+    static Expr* from_fn_name(const TExprNode& node);
 
 protected:
     enum BinaryOpType {
@@ -42,6 +46,8 @@ protected:
 
     ArithmeticExpr(const TExprNode& node) : Expr(node) {}
     virtual ~ArithmeticExpr() {}
+
+    static std::set<std::string> _s_valid_fn_names;
 };
 
 class AddExpr : public ArithmeticExpr {
@@ -56,6 +62,7 @@ public:
     virtual LargeIntVal get_large_int_val(ExprContext* context, TupleRow*) override;
     virtual FloatVal get_float_val(ExprContext* context, TupleRow*) override;
     virtual DoubleVal get_double_val(ExprContext* context, TupleRow*) override;
+    virtual DecimalV2Val get_decimalv2_val(ExprContext* context, TupleRow*) override;
 };
 
 class SubExpr : public ArithmeticExpr {
@@ -70,6 +77,7 @@ public:
     virtual LargeIntVal get_large_int_val(ExprContext* context, TupleRow*) override;
     virtual FloatVal get_float_val(ExprContext* context, TupleRow*) override;
     virtual DoubleVal get_double_val(ExprContext* context, TupleRow*) override;
+    virtual DecimalV2Val get_decimalv2_val(ExprContext* context, TupleRow*) override;
 };
 
 class MulExpr : public ArithmeticExpr {
@@ -84,6 +92,7 @@ public:
     virtual LargeIntVal get_large_int_val(ExprContext* context, TupleRow*) override;
     virtual FloatVal get_float_val(ExprContext* context, TupleRow*) override;
     virtual DoubleVal get_double_val(ExprContext* context, TupleRow*) override;
+    virtual DecimalV2Val get_decimalv2_val(ExprContext* context, TupleRow*) override;
 };
 
 class DivExpr : public ArithmeticExpr {
@@ -98,6 +107,7 @@ public:
     virtual LargeIntVal get_large_int_val(ExprContext* context, TupleRow*) override;
     virtual FloatVal get_float_val(ExprContext* context, TupleRow*) override;
     virtual DoubleVal get_double_val(ExprContext* context, TupleRow*) override;
+    virtual DecimalV2Val get_decimalv2_val(ExprContext* context, TupleRow*) override;
 };
 
 class ModExpr : public ArithmeticExpr {
@@ -112,6 +122,7 @@ public:
     virtual LargeIntVal get_large_int_val(ExprContext* context, TupleRow*) override;
     virtual FloatVal get_float_val(ExprContext* context, TupleRow*) override;
     virtual DoubleVal get_double_val(ExprContext* context, TupleRow*) override;
+    virtual DecimalV2Val get_decimalv2_val(ExprContext* context, TupleRow*) override;
 };
 
 class BitAndExpr : public ArithmeticExpr {
diff --git a/be/src/exprs/expr.cpp b/be/src/exprs/expr.cpp
index 73f3775..7429068 100644
--- a/be/src/exprs/expr.cpp
+++ b/be/src/exprs/expr.cpp
@@ -360,6 +360,8 @@ Status Expr::create_expr(ObjectPool* pool, const TExprNode& texpr_node, Expr** e
             *expr = pool->add(new CoalesceExpr(texpr_node));
         } else if (texpr_node.fn.binary_type == TFunctionBinaryType::RPC) {
             *expr = pool->add(new RPCFnCall(texpr_node));
+        } else if (ArithmeticExpr::is_valid(texpr_node.fn.name.function_name)) {
+            *expr = pool->add(ArithmeticExpr::from_fn_name(texpr_node));
         } else {
             *expr = pool->add(new ScalarFnCall(texpr_node));
         }
diff --git a/be/src/vec/functions/minus.cpp b/be/src/vec/functions/minus.cpp
index e215a52..52b86a1 100644
--- a/be/src/vec/functions/minus.cpp
+++ b/be/src/vec/functions/minus.cpp
@@ -34,6 +34,7 @@ struct MinusImpl {
         return static_cast<Result>(a) - b;
     }
 
+    template <typename Result = DecimalV2Value>
     static inline DecimalV2Value apply(DecimalV2Value a, DecimalV2Value b) {
         return a - b;
     }
diff --git a/be/src/vec/functions/modulo.cpp b/be/src/vec/functions/modulo.cpp
index 0e8bf49..28b0ec7 100644
--- a/be/src/vec/functions/modulo.cpp
+++ b/be/src/vec/functions/modulo.cpp
@@ -18,6 +18,7 @@
 // https://github.com/ClickHouse/ClickHouse/blob/master/src/Functions/Modulo.cpp
 // and modified by Doris
 
+#include "runtime/decimalv2_value.h"
 #ifdef __SSE2__
 #define LIBDIVIDE_SSE2 1
 #endif
@@ -47,6 +48,7 @@ struct ModuloImpl {
         }
     }
 
+    template <typename Result = DecimalV2Value>
     static inline DecimalV2Value apply(DecimalV2Value a, DecimalV2Value b, NullMap& null_map,
                                        size_t index) {
         null_map[index] = b == DecimalV2Value(0);
@@ -72,6 +74,7 @@ struct PModuloImpl {
         }
     }
 
+    template <typename Result = DecimalV2Value>
     static inline DecimalV2Value apply(DecimalV2Value a, DecimalV2Value b, NullMap& null_map,
                                        size_t index) {
         null_map[index] = b == DecimalV2Value(0);
diff --git a/be/src/vec/functions/multiply.cpp b/be/src/vec/functions/multiply.cpp
index 73cc854..b2840e4 100644
--- a/be/src/vec/functions/multiply.cpp
+++ b/be/src/vec/functions/multiply.cpp
@@ -34,6 +34,7 @@ struct MultiplyImpl {
         return static_cast<Result>(a) * b;
     }
 
+    template <typename Result = DecimalV2Value>
     static inline DecimalV2Value apply(DecimalV2Value a, DecimalV2Value b) {
         return a * b;
     }
diff --git a/be/src/vec/functions/plus.cpp b/be/src/vec/functions/plus.cpp
index 9f64070..0da30df 100644
--- a/be/src/vec/functions/plus.cpp
+++ b/be/src/vec/functions/plus.cpp
@@ -35,6 +35,7 @@ struct PlusImpl {
         return static_cast<Result>(a) + b;
     }
 
+    template <typename Result = DecimalV2Value>
     static inline DecimalV2Value apply(DecimalV2Value a, DecimalV2Value b) {
         return a + b;
     }
diff --git a/fe/fe-core/src/main/cup/sql_parser.cup b/fe/fe-core/src/main/cup/sql_parser.cup
index f92d4eb..01c863b 100644
--- a/fe/fe-core/src/main/cup/sql_parser.cup
+++ b/fe/fe-core/src/main/cup/sql_parser.cup
@@ -4735,6 +4735,8 @@ expr ::=
 function_call_expr ::=
   function_name:fn_name LPAREN RPAREN
   {: RESULT = new FunctionCallExpr(fn_name, new ArrayList<Expr>()); :}
+  | KW_ADD LPAREN function_params:params RPAREN
+  {: RESULT = new FunctionCallExpr("add", params); :}
   | function_name:fn_name LPAREN function_params:params RPAREN
   {:
     if ("grouping".equalsIgnoreCase(fn_name.getFunction())) {

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to