This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-2.1 by this push:
     new db0724dfe0c [Fix-2.1](function) fix function covar core for not null 
input (#39943)
db0724dfe0c is described below

commit db0724dfe0c39f71ccc4b9d44b6f9f030dac721f
Author: zclllhhjj <[email protected]>
AuthorDate: Tue Aug 27 08:39:47 2024 +0800

    [Fix-2.1](function) fix function covar core for not null input (#39943)
    
    ## Proposed changes
    
    Issue Number: close #xxx
    
    Add test cases such as:
    ```groovy
        qt_notnull1 "select covar_samp(non_nullable(x), non_nullable(y)) from 
test_covar_samp"
        qt_notnull2 "select covar_samp(x, non_nullable(y)) from test_covar_samp"
        qt_notnull3 "select covar_samp(non_nullable(x), y) from test_covar_samp"
    ```
    
    Before this fix, all of these queries caused a core dump on branch-2.1.
---
 .../aggregate_function_covar.cpp                   |  4 +-
 .../aggregate_functions/aggregate_function_covar.h | 85 ++++++++++++++--------
 .../trees/expressions/functions/agg/CovarSamp.java |  4 +-
 .../agg_function/test_covar_samp.out               | 14 +++-
 .../agg_function/test_covar_samp.groovy            |  4 +-
 5 files changed, 76 insertions(+), 35 deletions(-)

diff --git a/be/src/vec/aggregate_functions/aggregate_function_covar.cpp 
b/be/src/vec/aggregate_functions/aggregate_function_covar.cpp
index aa5bd511d90..4eb03c05fd3 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_covar.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_covar.cpp
@@ -76,13 +76,15 @@ AggregateFunctionPtr 
create_aggregate_function_covariance_pop(const std::string&
             name, argument_types, result_is_nullable, NOTNULLABLE);
 }
 
+// Register covar_pop for both nullable and non-nullable inputs.
 void register_aggregate_function_covar_pop(AggregateFunctionSimpleFactory& 
factory) {
     factory.register_function_both("covar", 
create_aggregate_function_covariance_pop);
     factory.register_alias("covar", "covar_pop");
 }
 
 void register_aggregate_function_covar_samp(AggregateFunctionSimpleFactory& 
factory) {
-    factory.register_function("covar_samp", 
create_aggregate_function_covariance_samp<NOTNULLABLE>);
+    factory.register_function("covar_samp", 
create_aggregate_function_covariance_samp<NOTNULLABLE>,
+                              NOTNULLABLE);
     factory.register_function("covar_samp", 
create_aggregate_function_covariance_samp<NULLABLE>,
                               NULLABLE);
 }
diff --git a/be/src/vec/aggregate_functions/aggregate_function_covar.h 
b/be/src/vec/aggregate_functions/aggregate_function_covar.h
index 31f0d7d2830..51a07f21145 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_covar.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_covar.h
@@ -17,17 +17,16 @@
 
 #pragma once
 
+#include "common/exception.h"
+#include "common/status.h"
 #define POP true
 #define NOTPOP false
 #define NULLABLE true
 #define NOTNULLABLE false
 
-#include <stddef.h>
-#include <stdint.h>
-
-#include <algorithm>
 #include <boost/iterator/iterator_facade.hpp>
-#include <cmath>
+#include <cstddef>
+#include <cstdint>
 #include <memory>
 #include <type_traits>
 
@@ -43,8 +42,8 @@
 #include "vec/data_types/data_type_number.h"
 #include "vec/io/io_helper.h"
 
-namespace doris {
-namespace vectorized {
+namespace doris::vectorized {
+
 class Arena;
 class BufferReadable;
 class BufferWritable;
@@ -52,10 +51,6 @@ template <typename T>
 class ColumnDecimal;
 template <typename>
 class ColumnVector;
-} // namespace vectorized
-} // namespace doris
-
-namespace doris::vectorized {
 
 template <typename T>
 struct BaseData {
@@ -228,17 +223,30 @@ struct SampData : Data {
     using ColVecResult = std::conditional_t<IsDecimalNumber<T>, 
ColumnDecimal<Decimal128V2>,
                                             ColumnVector<Float64>>;
     void insert_result_into(IColumn& to) const {
-        ColumnNullable& nullable_column = assert_cast<ColumnNullable&>(to);
-        if (this->count == 1 || this->count == 0) {
-            nullable_column.insert_default();
+        if (to.is_nullable()) {
+            auto& nullable_column = assert_cast<ColumnNullable&>(to);
+            if (this->count == 1 || this->count == 0) {
+                nullable_column.insert_default();
+            } else {
+                auto& col = 
assert_cast<ColVecResult&>(nullable_column.get_nested_column());
+                if constexpr (IsDecimalNumber<T>) {
+                    col.get_data().push_back(this->get_samp_result().value());
+                } else {
+                    col.get_data().push_back(this->get_samp_result());
+                }
+                nullable_column.get_null_map_data().push_back(0);
+            }
         } else {
-            auto& col = 
assert_cast<ColVecResult&>(nullable_column.get_nested_column());
-            if constexpr (IsDecimalNumber<T>) {
-                col.get_data().push_back(this->get_samp_result().value());
+            if (this->count == 1 || this->count == 0) {
+                to.insert_default();
             } else {
-                col.get_data().push_back(this->get_samp_result());
+                auto& col = assert_cast<ColVecResult&>(to);
+                if constexpr (IsDecimalNumber<T>) {
+                    col.get_data().push_back(this->get_samp_result().value());
+                } else {
+                    col.get_data().push_back(this->get_samp_result());
+                }
             }
-            nullable_column.get_null_map_data().push_back(0);
         }
     }
 };
@@ -266,26 +274,45 @@ public:
     String get_name() const override { return Data::name(); }
 
     DataTypePtr get_return_type() const override {
-        if constexpr (is_pop) {
+        if constexpr (is_pop || !is_nullable) { // covar and 
covar_samp(non_nullable)
             return Data::get_return_type();
-        } else {
+        } else { // covar_samp
             return make_nullable(Data::get_return_type());
         }
     }
 
     void add(AggregateDataPtr __restrict place, const IColumn** columns, 
ssize_t row_num,
              Arena*) const override {
-        if constexpr (is_pop) {
+        if constexpr (is_pop) { // covar (covar_pop)
             this->data(place).add(columns[0], columns[1], row_num);
-        } else {
+        } else { // covar_samp
             if constexpr (is_nullable) {
+                // is_nullable means at least one argument column is nullable,
+                // but possibly only one of them, so ..._x or ..._y may be nullptr!
                 const auto* nullable_column_x = 
check_and_get_column<ColumnNullable>(columns[0]);
                 const auto* nullable_column_y = 
check_and_get_column<ColumnNullable>(columns[1]);
-                if (!nullable_column_x->is_null_at(row_num) &&
-                    !nullable_column_y->is_null_at(row_num)) {
-                    
this->data(place).add(&nullable_column_x->get_nested_column(),
-                                          
&nullable_column_y->get_nested_column(), row_num);
+
+                if (nullable_column_x && nullable_column_y) { // both nullable
+                    if (!nullable_column_x->is_null_at(row_num) &&
+                        !nullable_column_y->is_null_at(row_num)) {
+                        
this->data(place).add(&nullable_column_x->get_nested_column(),
+                                              
&nullable_column_y->get_nested_column(), row_num);
+                    }
+                } else if (nullable_column_x) { // x nullable
+                    if (!nullable_column_x->is_null_at(row_num)) {
+                        
this->data(place).add(&nullable_column_x->get_nested_column(), columns[1],
+                                              row_num);
+                    }
+                } else if (nullable_column_y) { // y nullable
+                    if (!nullable_column_y->is_null_at(row_num)) {
+                        this->data(place).add(columns[0], 
&nullable_column_y->get_nested_column(),
+                                              row_num);
+                    }
+                } else {
+                    throw Exception(ErrorCode::INTERNAL_ERROR,
+                                    "Nullable function {} get non-nullable 
columns!", get_name());
                 }
+
             } else {
                 this->data(place).add(columns[0], columns[1], row_num);
             }
@@ -317,14 +344,14 @@ template <typename Data, bool is_nullable>
 class AggregateFunctionSamp final
         : public AggregateFunctionSampCovariance<NOTPOP, Data, is_nullable> {
 public:
-    AggregateFunctionSamp(const DataTypes& argument_types_)
+    AggregateFunctionSamp(const DataTypes& argument_types_) // covar_samp
             : AggregateFunctionSampCovariance<NOTPOP, Data, 
is_nullable>(argument_types_) {}
 };
 
 template <typename Data, bool is_nullable>
 class AggregateFunctionPop final : public AggregateFunctionSampCovariance<POP, 
Data, is_nullable> {
 public:
-    AggregateFunctionPop(const DataTypes& argument_types_)
+    AggregateFunctionPop(const DataTypes& argument_types_) // covar
             : AggregateFunctionSampCovariance<POP, Data, 
is_nullable>(argument_types_) {}
 };
 
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/CovarSamp.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/CovarSamp.java
index 2693d7636f0..0ffe6e88af0 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/CovarSamp.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/CovarSamp.java
@@ -19,8 +19,8 @@ package 
org.apache.doris.nereids.trees.expressions.functions.agg;
 
 import org.apache.doris.catalog.FunctionSignature;
 import org.apache.doris.nereids.trees.expressions.Expression;
-import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
 import 
org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
+import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
 import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
 import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
 import org.apache.doris.nereids.types.BigIntType;
@@ -39,7 +39,7 @@ import java.util.List;
  * AggregateFunction 'covar_samp'. This class is generated by GenerateFunction.
  */
 public class CovarSamp extends AggregateFunction
-        implements UnaryExpression, ExplicitlyCastableSignature, 
AlwaysNullable {
+        implements UnaryExpression, ExplicitlyCastableSignature, 
PropagateNullable {
 
     public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
             
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, 
DoubleType.INSTANCE),
diff --git 
a/regression-test/data/nereids_function_p0/agg_function/test_covar_samp.out 
b/regression-test/data/nereids_function_p0/agg_function/test_covar_samp.out
index 728beed6cc4..806badbb465 100644
--- a/regression-test/data/nereids_function_p0/agg_function/test_covar_samp.out
+++ b/regression-test/data/nereids_function_p0/agg_function/test_covar_samp.out
@@ -1,6 +1,6 @@
 -- This file is automatically generated. You should know what you did if you 
want to edit this
 -- !sql --
-1
+1.0
 
 -- !sql --
 -1.5
@@ -12,4 +12,14 @@
 4.5
 
 -- !sql --
-1.666667
\ No newline at end of file
+1.666666666666666
+
+-- !notnull1 --
+1.666666666666666
+
+-- !notnull2 --
+1.666666666666666
+
+-- !notnull3 --
+1.666666666666666
+
diff --git 
a/regression-test/suites/nereids_function_p0/agg_function/test_covar_samp.groovy
 
b/regression-test/suites/nereids_function_p0/agg_function/test_covar_samp.groovy
index a75a933e748..c9e82a86a96 100644
--- 
a/regression-test/suites/nereids_function_p0/agg_function/test_covar_samp.groovy
+++ 
b/regression-test/suites/nereids_function_p0/agg_function/test_covar_samp.groovy
@@ -86,5 +86,7 @@ suite("test_covar_samp") {
     """
     qt_sql "select covar_samp(x,y) from test_covar_samp"
     
-    sql """ DROP TABLE IF EXISTS test_covar_samp """
+    qt_notnull1 "select covar_samp(non_nullable(x), non_nullable(y)) from 
test_covar_samp"
+    qt_notnull2 "select covar_samp(x, non_nullable(y)) from test_covar_samp"
+    qt_notnull3 "select covar_samp(non_nullable(x), y) from test_covar_samp"
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to