This is an automated email from the ASF dual-hosted git repository.
yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-2.1 by this push:
new db0724dfe0c [Fix-2.1](function) fix function covar core for not null
input (#39943)
db0724dfe0c is described below
commit db0724dfe0c39f71ccc4b9d44b6f9f030dac721f
Author: zclllhhjj <[email protected]>
AuthorDate: Tue Aug 27 08:39:47 2024 +0800
[Fix-2.1](function) fix function covar core for not null input (#39943)
## Proposed changes
Issue Number: close #xxx
Add test cases such as:
```groovy
qt_notnull1 "select covar_samp(non_nullable(x), non_nullable(y)) from
test_covar_samp"
qt_notnull2 "select covar_samp(x, non_nullable(y)) from test_covar_samp"
qt_notnull3 "select covar_samp(non_nullable(x), y) from test_covar_samp"
```
Before this fix, all of these queries caused a core dump in 2.1.
---
.../aggregate_function_covar.cpp | 4 +-
.../aggregate_functions/aggregate_function_covar.h | 85 ++++++++++++++--------
.../trees/expressions/functions/agg/CovarSamp.java | 4 +-
.../agg_function/test_covar_samp.out | 14 +++-
.../agg_function/test_covar_samp.groovy | 4 +-
5 files changed, 76 insertions(+), 35 deletions(-)
diff --git a/be/src/vec/aggregate_functions/aggregate_function_covar.cpp
b/be/src/vec/aggregate_functions/aggregate_function_covar.cpp
index aa5bd511d90..4eb03c05fd3 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_covar.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_covar.cpp
@@ -76,13 +76,15 @@ AggregateFunctionPtr
create_aggregate_function_covariance_pop(const std::string&
name, argument_types, result_is_nullable, NOTNULLABLE);
}
+// register covar_pop for nullable/non_nullable both.
void register_aggregate_function_covar_pop(AggregateFunctionSimpleFactory&
factory) {
factory.register_function_both("covar",
create_aggregate_function_covariance_pop);
factory.register_alias("covar", "covar_pop");
}
void register_aggregate_function_covar_samp(AggregateFunctionSimpleFactory&
factory) {
- factory.register_function("covar_samp",
create_aggregate_function_covariance_samp<NOTNULLABLE>);
+ factory.register_function("covar_samp",
create_aggregate_function_covariance_samp<NOTNULLABLE>,
+ NOTNULLABLE);
factory.register_function("covar_samp",
create_aggregate_function_covariance_samp<NULLABLE>,
NULLABLE);
}
diff --git a/be/src/vec/aggregate_functions/aggregate_function_covar.h
b/be/src/vec/aggregate_functions/aggregate_function_covar.h
index 31f0d7d2830..51a07f21145 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_covar.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_covar.h
@@ -17,17 +17,16 @@
#pragma once
+#include "common/exception.h"
+#include "common/status.h"
#define POP true
#define NOTPOP false
#define NULLABLE true
#define NOTNULLABLE false
-#include <stddef.h>
-#include <stdint.h>
-
-#include <algorithm>
#include <boost/iterator/iterator_facade.hpp>
-#include <cmath>
+#include <cstddef>
+#include <cstdint>
#include <memory>
#include <type_traits>
@@ -43,8 +42,8 @@
#include "vec/data_types/data_type_number.h"
#include "vec/io/io_helper.h"
-namespace doris {
-namespace vectorized {
+namespace doris::vectorized {
+
class Arena;
class BufferReadable;
class BufferWritable;
@@ -52,10 +51,6 @@ template <typename T>
class ColumnDecimal;
template <typename>
class ColumnVector;
-} // namespace vectorized
-} // namespace doris
-
-namespace doris::vectorized {
template <typename T>
struct BaseData {
@@ -228,17 +223,30 @@ struct SampData : Data {
using ColVecResult = std::conditional_t<IsDecimalNumber<T>,
ColumnDecimal<Decimal128V2>,
ColumnVector<Float64>>;
void insert_result_into(IColumn& to) const {
- ColumnNullable& nullable_column = assert_cast<ColumnNullable&>(to);
- if (this->count == 1 || this->count == 0) {
- nullable_column.insert_default();
+ if (to.is_nullable()) {
+ auto& nullable_column = assert_cast<ColumnNullable&>(to);
+ if (this->count == 1 || this->count == 0) {
+ nullable_column.insert_default();
+ } else {
+ auto& col =
assert_cast<ColVecResult&>(nullable_column.get_nested_column());
+ if constexpr (IsDecimalNumber<T>) {
+ col.get_data().push_back(this->get_samp_result().value());
+ } else {
+ col.get_data().push_back(this->get_samp_result());
+ }
+ nullable_column.get_null_map_data().push_back(0);
+ }
} else {
- auto& col =
assert_cast<ColVecResult&>(nullable_column.get_nested_column());
- if constexpr (IsDecimalNumber<T>) {
- col.get_data().push_back(this->get_samp_result().value());
+ if (this->count == 1 || this->count == 0) {
+ to.insert_default();
} else {
- col.get_data().push_back(this->get_samp_result());
+ auto& col = assert_cast<ColVecResult&>(to);
+ if constexpr (IsDecimalNumber<T>) {
+ col.get_data().push_back(this->get_samp_result().value());
+ } else {
+ col.get_data().push_back(this->get_samp_result());
+ }
}
- nullable_column.get_null_map_data().push_back(0);
}
}
};
@@ -266,26 +274,45 @@ public:
String get_name() const override { return Data::name(); }
DataTypePtr get_return_type() const override {
- if constexpr (is_pop) {
+ if constexpr (is_pop || !is_nullable) { // covar and
covar_samp(non_nullable)
return Data::get_return_type();
- } else {
+ } else { // covar_samp
return make_nullable(Data::get_return_type());
}
}
void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena*) const override {
- if constexpr (is_pop) {
+ if constexpr (is_pop) { // covar_samp
this->data(place).add(columns[0], columns[1], row_num);
- } else {
+ } else { // covar
if constexpr (is_nullable) {
+ // nullable means at least one child is null.
+ // so here, maybe JUST ONE OF ups is null. so nullptr perhaps
in ..._x or ..._y!
const auto* nullable_column_x =
check_and_get_column<ColumnNullable>(columns[0]);
const auto* nullable_column_y =
check_and_get_column<ColumnNullable>(columns[1]);
- if (!nullable_column_x->is_null_at(row_num) &&
- !nullable_column_y->is_null_at(row_num)) {
-
this->data(place).add(&nullable_column_x->get_nested_column(),
-
&nullable_column_y->get_nested_column(), row_num);
+
+ if (nullable_column_x && nullable_column_y) { // both nullable
+ if (!nullable_column_x->is_null_at(row_num) &&
+ !nullable_column_y->is_null_at(row_num)) {
+
this->data(place).add(&nullable_column_x->get_nested_column(),
+
&nullable_column_y->get_nested_column(), row_num);
+ }
+ } else if (nullable_column_x) { // x nullable
+ if (!nullable_column_x->is_null_at(row_num)) {
+
this->data(place).add(&nullable_column_x->get_nested_column(), columns[1],
+ row_num);
+ }
+ } else if (nullable_column_y) { // y nullable
+ if (!nullable_column_y->is_null_at(row_num)) {
+ this->data(place).add(columns[0],
&nullable_column_y->get_nested_column(),
+ row_num);
+ }
+ } else {
+ throw Exception(ErrorCode::INTERNAL_ERROR,
+ "Nullable function {} get non-nullable
columns!", get_name());
}
+
} else {
this->data(place).add(columns[0], columns[1], row_num);
}
@@ -317,14 +344,14 @@ template <typename Data, bool is_nullable>
class AggregateFunctionSamp final
: public AggregateFunctionSampCovariance<NOTPOP, Data, is_nullable> {
public:
- AggregateFunctionSamp(const DataTypes& argument_types_)
+ AggregateFunctionSamp(const DataTypes& argument_types_) // covar_samp
: AggregateFunctionSampCovariance<NOTPOP, Data,
is_nullable>(argument_types_) {}
};
template <typename Data, bool is_nullable>
class AggregateFunctionPop final : public AggregateFunctionSampCovariance<POP,
Data, is_nullable> {
public:
- AggregateFunctionPop(const DataTypes& argument_types_)
+ AggregateFunctionPop(const DataTypes& argument_types_) // covar
: AggregateFunctionSampCovariance<POP, Data,
is_nullable>(argument_types_) {}
};
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/CovarSamp.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/CovarSamp.java
index 2693d7636f0..0ffe6e88af0 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/CovarSamp.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/CovarSamp.java
@@ -19,8 +19,8 @@ package
org.apache.doris.nereids.trees.expressions.functions.agg;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
-import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
import
org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
+import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
@@ -39,7 +39,7 @@ import java.util.List;
* AggregateFunction 'covar_samp'. This class is generated by GenerateFunction.
*/
public class CovarSamp extends AggregateFunction
- implements UnaryExpression, ExplicitlyCastableSignature,
AlwaysNullable {
+ implements UnaryExpression, ExplicitlyCastableSignature,
PropagateNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE,
DoubleType.INSTANCE),
diff --git
a/regression-test/data/nereids_function_p0/agg_function/test_covar_samp.out
b/regression-test/data/nereids_function_p0/agg_function/test_covar_samp.out
index 728beed6cc4..806badbb465 100644
--- a/regression-test/data/nereids_function_p0/agg_function/test_covar_samp.out
+++ b/regression-test/data/nereids_function_p0/agg_function/test_covar_samp.out
@@ -1,6 +1,6 @@
-- This file is automatically generated. You should know what you did if you
want to edit this
-- !sql --
-1
+1.0
-- !sql --
-1.5
@@ -12,4 +12,14 @@
4.5
-- !sql --
-1.666667
\ No newline at end of file
+1.666666666666666
+
+-- !notnull1 --
+1.666666666666666
+
+-- !notnull2 --
+1.666666666666666
+
+-- !notnull3 --
+1.666666666666666
+
diff --git
a/regression-test/suites/nereids_function_p0/agg_function/test_covar_samp.groovy
b/regression-test/suites/nereids_function_p0/agg_function/test_covar_samp.groovy
index a75a933e748..c9e82a86a96 100644
---
a/regression-test/suites/nereids_function_p0/agg_function/test_covar_samp.groovy
+++
b/regression-test/suites/nereids_function_p0/agg_function/test_covar_samp.groovy
@@ -86,5 +86,7 @@ suite("test_covar_samp") {
"""
qt_sql "select covar_samp(x,y) from test_covar_samp"
- sql """ DROP TABLE IF EXISTS test_covar_samp """
+ qt_notnull1 "select covar_samp(non_nullable(x), non_nullable(y)) from
test_covar_samp"
+ qt_notnull2 "select covar_samp(x, non_nullable(y)) from test_covar_samp"
+ qt_notnull3 "select covar_samp(non_nullable(x), y) from test_covar_samp"
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]