This is an automated email from the ASF dual-hosted git repository.
yiguolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 5ec8ae27730 [feature](agg-func) support corr function #30822
5ec8ae27730 is described below
commit 5ec8ae2773013512dd4a5aa389b850f3494041a0
Author: nanfeng <[email protected]>
AuthorDate: Wed Feb 7 08:32:06 2024 +0800
[feature](agg-func) support corr function #30822
---
.../aggregate_function_binary.h | 130 +++++++++++++++++++++
.../aggregate_function_corr.cpp | 92 +++++++++++++++
.../aggregate_function_simple_factory.cpp | 3 +
.../sql-functions/aggregate-functions/corr.md | 49 ++++++++
.../sql-functions/aggregate-functions/corr.md | 50 ++++++++
.../doris/catalog/BuiltinAggregateFunctions.java | 2 +
.../java/org/apache/doris/catalog/FunctionSet.java | 25 ++++
.../trees/expressions/functions/agg/Corr.java | 85 ++++++++++++++
.../visitor/AggregateFunctionVisitor.java | 5 +
.../nereids_function_p0/agg_function/test_corr.out | 13 +++
.../agg_function/test_corr.groovy | 85 ++++++++++++++
11 files changed, 539 insertions(+)
diff --git a/be/src/vec/aggregate_functions/aggregate_function_binary.h
b/be/src/vec/aggregate_functions/aggregate_function_binary.h
new file mode 100644
index 00000000000..422919c52af
--- /dev/null
+++ b/be/src/vec/aggregate_functions/aggregate_function_binary.h
@@ -0,0 +1,130 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <glog/logging.h>
+
+#include <cmath>
+
+#include "common/status.h"
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/aggregate_functions/factory_helpers.h"
+#include "vec/aggregate_functions/helpers.h"
+#include "vec/columns/column_decimal.h"
+#include "vec/columns/column_vector.h"
+#include "vec/common/arithmetic_overflow.h"
+#include "vec/common/string_buffer.hpp"
+#include "vec/core/types.h"
+#include "vec/data_types/data_type_decimal.h"
+#include "vec/data_types/data_type_nullable.h"
+#include "vec/data_types/data_type_number.h"
+#include "vec/io/io_helper.h"
+
+namespace doris::vectorized {
+
+template <typename T1, typename T2, template <typename> typename Moments>
+struct StatFunc {
+ using Type1 = T1;
+ using Type2 = T2;
+ using ResultType = std::conditional_t<std::is_same_v<T1, T2> &&
std::is_same_v<T1, Float32>,
+ Float32, Float64>;
+ using Data = Moments<ResultType>;
+};
+
+template <typename StatFunc>
+struct AggregateFunctionBinary
+ : public IAggregateFunctionDataHelper<typename StatFunc::Data,
+
AggregateFunctionBinary<StatFunc>> {
+ using ResultType = typename StatFunc::ResultType;
+
+ using ColVecT1 = ColumnVectorOrDecimal<typename StatFunc::Type1>;
+ using ColVecT2 = ColumnVectorOrDecimal<typename StatFunc::Type2>;
+ using ColVecResult = ColumnVector<ResultType>;
+ static constexpr UInt32 num_args = 2;
+
+ AggregateFunctionBinary(const DataTypes& argument_types_)
+ : IAggregateFunctionDataHelper<typename StatFunc::Data,
+
AggregateFunctionBinary<StatFunc>>(argument_types_) {}
+
+ String get_name() const override { return StatFunc::Data::name(); }
+
+ DataTypePtr get_return_type() const override {
+ return std::make_shared<DataTypeNumber<ResultType>>();
+ }
+
+ bool allocates_memory_in_arena() const override { return false; }
+
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ Arena*) const override {
+ this->data(place).add(
+ static_cast<ResultType>(
+ static_cast<const
ColVecT1&>(*columns[0]).get_data()[row_num]),
+ static_cast<ResultType>(
+ static_cast<const
ColVecT2&>(*columns[1]).get_data()[row_num]));
+ }
+
+ void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
+ Arena*) const override {
+ this->data(place).merge(this->data(rhs));
+ }
+
+ void serialize(ConstAggregateDataPtr __restrict place, BufferWritable&
buf) const override {
+ this->data(place).write(buf);
+ }
+
+ void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
+ Arena*) const override {
+ this->data(place).read(buf);
+ }
+
+ void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn&
to) const override {
+ const auto& data = this->data(place);
+ auto& dst = static_cast<ColVecResult&>(to).get_data();
+ dst.push_back(data.get());
+ }
+};
+
+template <template <typename> typename Moments, typename FirstType,
typename... TArgs>
+AggregateFunctionPtr create_with_two_basic_numeric_types_second(const
DataTypePtr& second_type,
+ TArgs&&...
args) {
+ WhichDataType which(remove_nullable(second_type));
+#define DISPATCH(TYPE) \
+ if (which.idx == TypeIndex::TYPE) \
+ return creator_without_type::create< \
+ AggregateFunctionBinary<StatFunc<FirstType, TYPE, Moments>>>( \
+ std::forward<TArgs>(args)...);
+ FOR_NUMERIC_TYPES(DISPATCH)
+#undef DISPATCH
+ return nullptr;
+}
+
+template <template <typename> typename Moments, typename... TArgs>
+AggregateFunctionPtr create_with_two_basic_numeric_types(const DataTypePtr&
first_type,
+ const DataTypePtr&
second_type,
+ TArgs&&... args) {
+ WhichDataType which(remove_nullable(first_type));
+#define DISPATCH(TYPE) \
+ if (which.idx == TypeIndex::TYPE) \
+ return create_with_two_basic_numeric_types_second<Moments, TYPE>( \
+ second_type, std::forward<TArgs>(args)...);
+ FOR_NUMERIC_TYPES(DISPATCH)
+#undef DISPATCH
+ return nullptr;
+}
+
+} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_corr.cpp
b/be/src/vec/aggregate_functions/aggregate_function_corr.cpp
new file mode 100644
index 00000000000..fb84e92e0e6
--- /dev/null
+++ b/be/src/vec/aggregate_functions/aggregate_function_corr.cpp
@@ -0,0 +1,92 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/aggregate_functions/aggregate_function_binary.h"
+#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
+#include "vec/core/types.h"
+
+namespace doris::vectorized {
+
+template <typename T>
+struct CorrMoment {
+ T m0 {};
+ T x1 {};
+ T y1 {};
+ T xy {};
+ T x2 {};
+ T y2 {};
+
+ void add(T x, T y) {
+ ++m0;
+ x1 += x;
+ y1 += y;
+ xy += x * y;
+ x2 += x * x;
+ y2 += y * y;
+ }
+
+ void merge(const CorrMoment& rhs) {
+ m0 += rhs.m0;
+ x1 += rhs.x1;
+ y1 += rhs.y1;
+ xy += rhs.xy;
+ x2 += rhs.x2;
+ y2 += rhs.y2;
+ }
+
+ void write(BufferWritable& buf) const {
+ write_binary(m0, buf);
+ write_binary(x1, buf);
+ write_binary(y1, buf);
+ write_binary(xy, buf);
+ write_binary(x2, buf);
+ write_binary(y2, buf);
+ }
+
+ void read(BufferReadable& buf) {
+ read_binary(m0, buf);
+ read_binary(x1, buf);
+ read_binary(y1, buf);
+ read_binary(xy, buf);
+ read_binary(x2, buf);
+ read_binary(y2, buf);
+ }
+
+ T get() const {
+ if ((m0 * x2 - x1 * x1) * (m0 * y2 - y1 * y1) == 0) [[unlikely]] {
+ return 0;
+ }
+ return (m0 * xy - x1 * y1) / sqrt((m0 * x2 - x1 * x1) * (m0 * y2 - y1
* y1));
+ }
+
+ static String name() { return "corr"; }
+};
+
+AggregateFunctionPtr create_aggregate_corr_function(const std::string& name,
+ const DataTypes&
argument_types,
+ const bool
result_is_nullable) {
+ assert_binary(name, argument_types);
+ return create_with_two_basic_numeric_types<CorrMoment>(argument_types[0],
argument_types[1],
+ argument_types,
result_is_nullable);
+}
+
+void register_aggregate_functions_corr(AggregateFunctionSimpleFactory&
factory) {
+ factory.register_function_both("corr", create_aggregate_corr_function);
+}
+
+} // namespace doris::vectorized
diff --git
a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
index 068e2efaac4..9f99a64f2bc 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
@@ -60,6 +60,7 @@ void
register_aggregate_function_avg_weighted(AggregateFunctionSimpleFactory& fa
void register_aggregate_function_histogram(AggregateFunctionSimpleFactory&
factory);
void register_aggregate_function_map_agg(AggregateFunctionSimpleFactory&
factory);
void register_aggregate_function_bitmap_agg(AggregateFunctionSimpleFactory&
factory);
+void register_aggregate_functions_corr(AggregateFunctionSimpleFactory&
factory);
AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() {
static std::once_flag oc;
@@ -100,6 +101,8 @@ AggregateFunctionSimpleFactory&
AggregateFunctionSimpleFactory::instance() {
register_aggregate_function_replace_reader_load(instance);
register_aggregate_function_window_lead_lag_first_last(instance);
register_aggregate_function_HLL_union_agg(instance);
+
+ register_aggregate_functions_corr(instance);
});
return instance;
}
diff --git a/docs/en/docs/sql-manual/sql-functions/aggregate-functions/corr.md
b/docs/en/docs/sql-manual/sql-functions/aggregate-functions/corr.md
new file mode 100644
index 00000000000..862dbad02b1
--- /dev/null
+++ b/docs/en/docs/sql-manual/sql-functions/aggregate-functions/corr.md
@@ -0,0 +1,49 @@
+---
+{
+ "title": "CORR",
+ "language": "en"
+}
+---
+
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+## CORR
+### Description
+#### Syntax
+
+`double corr(x, y)`
+
+Calculates the Pearson correlation coefficient of x and y: the covariance of x and y divided by the product of their standard deviations.
+If the standard deviation of x or y is 0, the result is 0.
+
+### example
+
+```
+mysql> select corr(x,y) from baseall;
++---------------------+
+| corr(x, y) |
++---------------------+
+| 0.89442719099991586 |
++---------------------+
+1 row in set (0.21 sec)
+
+```
+### keywords
+CORR
diff --git
a/docs/zh-CN/docs/sql-manual/sql-functions/aggregate-functions/corr.md
b/docs/zh-CN/docs/sql-manual/sql-functions/aggregate-functions/corr.md
new file mode 100644
index 00000000000..0437d5e9d8f
--- /dev/null
+++ b/docs/zh-CN/docs/sql-manual/sql-functions/aggregate-functions/corr.md
@@ -0,0 +1,50 @@
+---
+{
+ "title": "CORR",
+ "language": "zh-CN"
+}
+---
+
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+## CORR
+### Description
+#### Syntax
+
+`double corr(x, y)`
+
+计算皮尔逊相关系数，即 x 和 y 的协方差除以 x 和 y 的标准差之积。
+如果 x 或 y 的标准差为 0，将返回 0。
+
+
+### example
+
+```
+mysql> select corr(x,y) from baseall;
++---------------------+
+| corr(x, y) |
++---------------------+
+| 0.89442719099991586 |
++---------------------+
+1 row in set (0.21 sec)
+
+```
+### keywords
+CORR
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java
b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java
index 5a101e71014..a8fc246d239 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java
@@ -28,6 +28,7 @@ import
org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionInt;
import org.apache.doris.nereids.trees.expressions.functions.agg.CollectList;
import org.apache.doris.nereids.trees.expressions.functions.agg.CollectSet;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Corr;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.CountByEnum;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupBitAnd;
@@ -93,6 +94,7 @@ public class BuiltinAggregateFunctions implements
FunctionHelper {
agg(BitmapUnionInt.class, "bitmap_union_int"),
agg(CollectList.class, "collect_list", "group_array"),
agg(CollectSet.class, "collect_set", "group_uniq_array"),
+ agg(Corr.class, "corr"),
agg(Count.class, "count"),
agg(CountByEnum.class, "count_by_enum"),
agg(GroupBitAnd.class, "group_bit_and"),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
index c589cbbf505..629e4556df2 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
@@ -1710,6 +1710,31 @@ public class FunctionSet<T> {
"",
false, true, false, true));
+ // corr
+ addBuiltin(AggregateFunction.createBuiltin("corr",
+ Lists.<Type>newArrayList(Type.TINYINT, Type.TINYINT),
Type.DOUBLE, Type.DOUBLE,
+ "", "", "", "", "", "", "",
+ false, false, false, true));
+ addBuiltin(AggregateFunction.createBuiltin("corr",
+ Lists.<Type>newArrayList(Type.SMALLINT, Type.SMALLINT),
Type.DOUBLE, Type.DOUBLE,
+ "", "", "", "", "", "", "",
+ false, false, false, true));
+ addBuiltin(AggregateFunction.createBuiltin("corr",
+ Lists.<Type>newArrayList(Type.INT, Type.INT), Type.DOUBLE,
Type.DOUBLE,
+ "", "", "", "", "", "", "",
+ false, false, false, true));
+ addBuiltin(AggregateFunction.createBuiltin("corr",
+ Lists.<Type>newArrayList(Type.BIGINT, Type.BIGINT),
Type.DOUBLE, Type.DOUBLE,
+ "", "", "", "", "", "", "",
+ false, false, false, true));
+ addBuiltin(AggregateFunction.createBuiltin("corr",
+ Lists.<Type>newArrayList(Type.FLOAT, Type.FLOAT), Type.DOUBLE,
Type.DOUBLE,
+ "", "", "", "", "", "", "",
+ false, false, false, true));
+ addBuiltin(AggregateFunction.createBuiltin("corr",
+ Lists.<Type>newArrayList(Type.DOUBLE, Type.DOUBLE),
Type.DOUBLE, Type.DOUBLE,
+ "", "", "", "", "", "", "",
+ false, false, false, true));
}
public Map<String, List<Function>> getVectorizedFunctions() {
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Corr.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Corr.java
new file mode 100644
index 00000000000..26f8a720c26
--- /dev/null
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Corr.java
@@ -0,0 +1,85 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.expressions.functions.agg;
+
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.FloatType;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.SmallIntType;
import org.apache.doris.nereids.types.TinyIntType;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;

import java.util.List;
+
+/**
+ * AggregateFunction 'corr'. This class is generated by GenerateFunction.
+ */
+public class Corr extends AggregateFunction
+ implements UnaryExpression, ExplicitlyCastableSignature,
AlwaysNullable {
+
+ public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
+
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE,
TinyIntType.INSTANCE),
+
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE,
SmallIntType.INSTANCE),
+
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE,
IntegerType.INSTANCE),
+
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE,
BigIntType.INSTANCE),
+
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE,
FloatType.INSTANCE),
+
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE,
DoubleType.INSTANCE)
+ );
+
+ /**
+ * constructor with 2 argument.
+ */
+ public Corr(Expression arg1, Expression arg2) {
+ super("corr", arg1, arg2);
+ }
+
+ /**
+ * constructor with 3 arguments.
+ */
+ public Corr(boolean distinct, Expression arg1, Expression arg2) {
+ super("corr", distinct, arg1, arg2);
+ }
+
+ /**
+ * withDistinctAndChildren.
+ */
+ @Override
+ public Corr withDistinctAndChildren(boolean distinct, List<Expression>
children) {
+ Preconditions.checkArgument(children.size() == 2);
+ return new Corr(distinct, children.get(0), children.get(1));
+ }
+
+ @Override
+ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
+ return visitor.visitCorr(this, context);
+ }
+
+ @Override
+ public List<FunctionSignature> getSignatures() {
+ return SIGNATURES;
+ }
+}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java
index 14e3dc304e9..73dd6a838b9 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java
@@ -29,6 +29,7 @@ import
org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionInt;
import org.apache.doris.nereids.trees.expressions.functions.agg.CollectList;
import org.apache.doris.nereids.trees.expressions.functions.agg.CollectSet;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Corr;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.CountByEnum;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupBitAnd;
@@ -126,6 +127,10 @@ public interface AggregateFunctionVisitor<R, C> {
return visitAggregateFunction(collectSet, context);
}
/** Visits a {@link Corr} aggregate; by default delegates to visitAggregateFunction. */
default R visitCorr(Corr corr, C context) {
    return visitAggregateFunction(corr, context);
}
+
default R visitCount(Count count, C context) {
return visitAggregateFunction(count, context);
}
diff --git
a/regression-test/data/nereids_function_p0/agg_function/test_corr.out
b/regression-test/data/nereids_function_p0/agg_function/test_corr.out
new file mode 100644
index 00000000000..4fc9a9d4baa
--- /dev/null
+++ b/regression-test/data/nereids_function_p0/agg_function/test_corr.out
@@ -0,0 +1,13 @@
+-- This file is automatically generated. You should know what you did if you
want to edit this
+-- !sql --
+1.0
+
+-- !sql --
+-1.0
+
+-- !sql --
+0.0
+
+-- !sql --
+0.8944271909999159
+
diff --git
a/regression-test/suites/nereids_function_p0/agg_function/test_corr.groovy
b/regression-test/suites/nereids_function_p0/agg_function/test_corr.groovy
new file mode 100644
index 00000000000..15f27f84276
--- /dev/null
+++ b/regression-test/suites/nereids_function_p0/agg_function/test_corr.groovy
@@ -0,0 +1,85 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
// Regression suite for corr(x, y). Each data set is loaded into the empty
// table and queried once with qt_sql; results are matched in order against
// test_corr.out.
suite("test_corr") {
    sql """ DROP TABLE IF EXISTS test_corr """

    sql """ SET enable_nereids_planner=true """
    sql """ SET enable_fallback_to_original_planner=false """

    // No trailing comma after the last column: keep the DDL valid SQL instead
    // of relying on parser leniency.
    sql """
        CREATE TABLE test_corr (
            `id` int,
            `x` int,
            `y` int
        ) ENGINE=OLAP
        Duplicate KEY (`id`)
        DISTRIBUTED BY HASH(`id`) BUCKETS 4
        PROPERTIES (
            "replication_allocation" = "tag.location.default: 1"
        );
    """

    // Loads `rows` into the empty table, records corr(x, y) as the next
    // `sql` result, then empties the table for the next case.
    def checkCorr = { String rows ->
        sql """ insert into test_corr values ${rows} """
        qt_sql "select corr(x,y) from test_corr"
        sql """ truncate table test_corr """
    }

    // Perfect positive correlation
    checkCorr("(1, 1, 1), (2, 2, 2), (3, 3, 3), (4, 4, 4), (5, 5, 5)")

    // Perfect negative correlation
    checkCorr("(1, 1, 5), (2, 2, 4), (3, 3, 3), (4, 4, 2), (5, 5, 1)")

    // Zero variance in x: corr is specified to return 0
    checkCorr("(1, 1, 1), (2, 1, 2), (3, 1, 3), (4, 1, 4), (5, 1, 5)")

    // Partial linear correlation
    checkCorr("(1, 1, 1), (2, 2, 2), (3, 3, 3), (4, 4, 4), (5, 5, 10)")

    sql """ DROP TABLE IF EXISTS test_corr """
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]