This is an automated email from the ASF dual-hosted git repository.
zclll pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new e06abedb120 [fix](regr) Use Youngs-Cramer for REGR_SLOPE/INTERCEPT to
align with PG (#55940)
e06abedb120 is described below
commit e06abedb120dd77469abfb1cdf5c972b952148ef
Author: Jover <[email protected]>
AuthorDate: Wed Dec 10 23:16:10 2025 +0800
[fix](regr) Use Youngs-Cramer for REGR_SLOPE/INTERCEPT to align with PG
(#55940)
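This PR reimplements `REGR_SLOPE` and `REGR_INTERCEPT` using the
Youngs–Cramer algorithm to align with PostgreSQL.
The previous implementation accumulated the raw sums `sum(x*y)` and
`sum(x*x)` and evaluated `n*sum(x*x) - sum(x)^2`, which loses precision
through cancellation on large inputs (see the corrected `regr_slope_int`
and `regr_slope_largeint` expected results in the diff below).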
It also extends `AggregateFunctionRegrData<T>` so it can be reused by
all `REGR_*` functions (`SXX`, `SYY`, `SXY`, `R2`, etc.).
```sql
-- Copied from
-- `regression-test/suites/query_p0/aggregate/support_type/regr_slope/regr_slope.groovy`
-- dataset (PostgreSQL)
drop table if exists d_table;
create table d_table (
k1 int,
k2 int not null,
k3 bigint,
col_tinyint smallint,
col_smallint smallint,
col_int int,
col_bigint bigint,
col_largeint numeric(38,0),
col_float real,
col_double double precision
);
insert into d_table values
(1, 1, 1, 100, 10000, 1000000, 10000000000, 100000000000000000000,
3.14, 2.718281828),
(2, 2, 2, 101, 10001, 1000001, 10000000001, 100000000000000000001,
6.28, 3.141592653),
(3, 3, 3, 102, 10002, 1000002, 10000000002, 100000000000000000002,
9.42, 1.618033988);
select regr_slope(col_tinyint, col_smallint) from d_table;
-- 1.0
select regr_slope(col_smallint, col_int) from d_table;
-- 1.0
select regr_slope(col_int, col_bigint) from d_table;
-- 1.0
select regr_slope(col_bigint, col_largeint) from d_table;
-- <null>
select regr_slope(col_largeint, col_float) from d_table;
-- 0.0
select regr_slope(col_float, col_double) from d_table;
-- -2.7928921351549283
select regr_slope(col_double, col_tinyint) from d_table;
-- -0.5501239200000003
select regr_intercept(col_tinyint, col_smallint) from d_table;
-- -9900.0
select regr_intercept(col_smallint, col_int) from d_table;
-- -990000.0
select regr_intercept(col_int, col_bigint) from d_table;
-- -9999000000.0
select regr_intercept(col_bigint, col_largeint) from d_table;
-- <null>
select regr_intercept(col_largeint, col_float) from d_table;
-- 1e+20
select regr_intercept(col_float, col_double) from d_table;
-- 13.241664047161668
select regr_intercept(col_double, col_tinyint) from d_table;
-- 58.055152076333364
```
---
.../aggregate_function_regr_union.h | 250 ++++++++++++++++-----
.../support_type/regr_intercept/regr_intercept.out | 8 +-
.../support_type/regr_slope/regr_slope.out | 8 +-
.../query_p0/aggregate/test_regr_intercept.groovy | 18 +-
.../query_p0/aggregate/test_regr_slope.groovy | 18 +-
5 files changed, 226 insertions(+), 76 deletions(-)
diff --git a/be/src/vec/aggregate_functions/aggregate_function_regr_union.h
b/be/src/vec/aggregate_functions/aggregate_function_regr_union.h
index 1cc30b8c430..dde9fc5e48f 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_regr_union.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_regr_union.h
@@ -32,88 +32,238 @@
namespace doris::vectorized {
#include "common/compile_check_begin.h"
-template <PrimitiveType T>
+template <PrimitiveType T,
+ // requires Sx and Sy
+ bool NeedSxy,
+ // level 1: Sx
+ // level 2: Sxx
+ size_t SxLevel = size_t {NeedSxy},
+ // level 1: Sy
+ // level 2: Syy
+ size_t SyLevel = size_t {NeedSxy}>
struct AggregateFunctionRegrData {
static constexpr PrimitiveType Type = T;
- UInt64 count = 0;
- Float64 sum_x {};
- Float64 sum_y {};
- Float64 sum_of_x_mul_y {};
- Float64 sum_of_x_squared {};
+
+ static_assert(!NeedSxy || (SxLevel > 0 && SyLevel > 0),
+ "NeedSxy requires SxLevel > 0 and SyLevel > 0");
+ static_assert(SxLevel <= 2 && SyLevel <= 2, "Sx/Sy level must be <= 2");
+
+ static constexpr bool need_sx = SxLevel > 0;
+ static constexpr bool need_sy = SyLevel > 0;
+ static constexpr bool need_sxx = SxLevel > 1;
+ static constexpr bool need_syy = SyLevel > 1;
+ static constexpr bool need_sxy = NeedSxy;
+
+ static constexpr size_t kMomentSize = SxLevel + SyLevel + size_t
{need_sxy};
+ static_assert(kMomentSize > 0 && kMomentSize <= 5, "Unexpected size of
regr moment array");
+
+ /**
+ * The moments array is:
+ * Sx = sum(X)
+ * Sy = sum(Y)
+ * Sxx = sum((X-Sx/N)^2)
+ * Syy = sum((Y-Sy/N)^2)
+ * Sxy = sum((X-Sx/N)*(Y-Sy/N))
+ */
+ std::array<Float64, kMomentSize> moments {};
+ UInt64 n {};
+
+ static constexpr size_t idx_sx() {
+ static_assert(need_sx, "sx not enabled");
+ return 0;
+ }
+ static constexpr size_t idx_sy() {
+ static_assert(need_sy, "sy not enabled");
+ return size_t {need_sx};
+ }
+ static constexpr size_t idx_sxx() {
+ static_assert(need_sxx, "sxx not enabled");
+ return size_t {need_sx + need_sy};
+ }
+ static constexpr size_t idx_syy() {
+ static_assert(need_syy, "syy not enabled");
+ return size_t {need_sx + need_sy + need_sxx};
+ }
+ static constexpr size_t idx_sxy() {
+ static_assert(need_sxy, "sxy not enabled");
+ return size_t {need_sx + need_sy + need_sxx + need_syy};
+ }
+
+ Float64& sx() { return moments[idx_sx()]; }
+ Float64& sy() { return moments[idx_sy()]; }
+ Float64& sxx() { return moments[idx_sxx()]; }
+ Float64& syy() { return moments[idx_syy()]; }
+ Float64& sxy() { return moments[idx_sxy()]; }
+
+ const Float64& sx() const { return moments[idx_sx()]; }
+ const Float64& sy() const { return moments[idx_sy()]; }
+ const Float64& sxx() const { return moments[idx_sxx()]; }
+ const Float64& syy() const { return moments[idx_syy()]; }
+ const Float64& sxy() const { return moments[idx_sxy()]; }
void write(BufferWritable& buf) const {
- buf.write_binary(sum_x);
- buf.write_binary(sum_y);
- buf.write_binary(sum_of_x_mul_y);
- buf.write_binary(sum_of_x_squared);
- buf.write_binary(count);
+ if constexpr (need_sx) {
+ buf.write_binary(sx());
+ }
+ if constexpr (need_sy) {
+ buf.write_binary(sy());
+ }
+ if constexpr (need_sxx) {
+ buf.write_binary(sxx());
+ }
+ if constexpr (need_syy) {
+ buf.write_binary(syy());
+ }
+ if constexpr (need_sxy) {
+ buf.write_binary(sxy());
+ }
+ buf.write_binary(n);
}
void read(BufferReadable& buf) {
- buf.read_binary(sum_x);
- buf.read_binary(sum_y);
- buf.read_binary(sum_of_x_mul_y);
- buf.read_binary(sum_of_x_squared);
- buf.read_binary(count);
+ if constexpr (need_sx) {
+ buf.read_binary(sx());
+ }
+ if constexpr (need_sy) {
+ buf.read_binary(sy());
+ }
+ if constexpr (need_sxx) {
+ buf.read_binary(sxx());
+ }
+ if constexpr (need_syy) {
+ buf.read_binary(syy());
+ }
+ if constexpr (need_sxy) {
+ buf.read_binary(sxy());
+ }
+ buf.read_binary(n);
}
void reset() {
- sum_x = {};
- sum_y = {};
- sum_of_x_mul_y = {};
- sum_of_x_squared = {};
- count = 0;
+ moments.fill({});
+ n = {};
}
+ /**
+ * The merge function uses the Youngs–Cramer algorithm:
+ * N = N1 + N2
+ * Sx = Sx1 + Sx2
+ * Sy = Sy1 + Sy2
+ * Sxx = Sxx1 + Sxx2 + N1 * N2 * (Sx1/N1 - Sx2/N2)^2 / N
+ * Syy = Syy1 + Syy2 + N1 * N2 * (Sy1/N1 - Sy2/N2)^2 / N
+ * Sxy = Sxy1 + Sxy2 + N1 * N2 * (Sx1/N1 - Sx2/N2) * (Sy1/N1 - Sy2/N2) / N
+ */
void merge(const AggregateFunctionRegrData& rhs) {
- if (rhs.count == 0) {
+ if (rhs.n == 0) {
+ return;
+ }
+ if (n == 0) {
+ *this = rhs;
return;
}
- sum_x += rhs.sum_x;
- sum_y += rhs.sum_y;
- sum_of_x_mul_y += rhs.sum_of_x_mul_y;
- sum_of_x_squared += rhs.sum_of_x_squared;
- count += rhs.count;
+ const auto n1 = static_cast<Float64>(n);
+ const auto n2 = static_cast<Float64>(rhs.n);
+ const auto nsum = n1 + n2;
+
+ Float64 dx {};
+ Float64 dy {};
+ if constexpr (need_sxx || need_sxy) {
+ dx = sx() / n1 - rhs.sx() / n2;
+ }
+ if constexpr (need_syy || need_sxy) {
+ dy = sy() / n1 - rhs.sy() / n2;
+ }
+
+ n += rhs.n;
+ if constexpr (need_sx) {
+ sx() += rhs.sx();
+ }
+ if constexpr (need_sy) {
+ sy() += rhs.sy();
+ }
+ if constexpr (need_sxx) {
+ sxx() += rhs.sxx() + n1 * n2 * dx * dx / nsum;
+ }
+ if constexpr (need_syy) {
+ syy() += rhs.syy() + n1 * n2 * dy * dy / nsum;
+ }
+ if constexpr (need_sxy) {
+ sxy() += rhs.sxy() + n1 * n2 * dx * dy / nsum;
+ }
}
+ /**
+ * N
+ * Sx = sum(X)
+ * Sy = sum(Y)
+ * Sxx = sum((X-Sx/N)^2)
+ * Syy = sum((Y-Sy/N)^2)
+ * Sxy = sum((X-Sx/N)*(Y-Sy/N))
+ */
void add(typename PrimitiveTypeTraits<T>::ColumnItemType value_y,
typename PrimitiveTypeTraits<T>::ColumnItemType value_x) {
- sum_x += (double)value_x;
- sum_y += (double)value_y;
- sum_of_x_mul_y += (double)value_x * (double)value_y;
- sum_of_x_squared += (double)value_x * (double)value_x;
- count += 1;
- }
+ const auto x = static_cast<Float64>(value_x);
+ const auto y = static_cast<Float64>(value_y);
- Float64 get_slope() const {
- Float64 denominator = (double)count * sum_of_x_squared - sum_x * sum_x;
- if (count < 2 || denominator == 0.0) {
- return std::numeric_limits<Float64>::quiet_NaN();
+ if constexpr (need_sx) {
+ sx() += x;
+ }
+ if constexpr (need_sy) {
+ sy() += y;
+ }
+
+ if (n == 0) [[unlikely]] {
+ n = 1;
+ return;
+ }
+ const auto n_old = static_cast<Float64>(n);
+ const auto n_new = n_old + 1;
+ const auto scale = 1.0 / (n_new * n_old);
+ n += 1;
+
+ Float64 tmp_x {};
+ Float64 tmp_y {};
+ if constexpr (need_sxx || need_sxy) {
+ tmp_x = x * n_new - sx();
+ }
+ if constexpr (need_syy || need_sxy) {
+ tmp_y = y * n_new - sy();
+ }
+
+ if constexpr (need_sxx) {
+ sxx() += tmp_x * tmp_x * scale;
+ }
+ if constexpr (need_syy) {
+ syy() += tmp_y * tmp_y * scale;
+ }
+ if constexpr (need_sxy) {
+ sxy() += tmp_x * tmp_y * scale;
}
- Float64 slope = ((double)count * sum_of_x_mul_y - sum_x * sum_y) /
denominator;
- return slope;
}
};
template <PrimitiveType T>
-struct RegrSlopeFunc : AggregateFunctionRegrData<T> {
+struct RegrSlopeFunc : AggregateFunctionRegrData<T, true, 2, 1> {
static constexpr const char* name = "regr_slope";
- Float64 get_result() const { return this->get_slope(); }
+ Float64 get_result() const {
+ if (this->n < 1 || this->sxx() == 0.0) {
+ return std::numeric_limits<Float64>::quiet_NaN();
+ }
+ return this->sxy() / this->sxx();
+ }
};
template <PrimitiveType T>
-struct RegrInterceptFunc : AggregateFunctionRegrData<T> {
+struct RegrInterceptFunc : AggregateFunctionRegrData<T, true, 2, 2> {
static constexpr const char* name = "regr_intercept";
Float64 get_result() const {
- auto slope = this->get_slope();
- if (std::isnan(slope)) {
- return slope;
- } else {
- Float64 intercept = (this->sum_y - slope * this->sum_x) /
(double)this->count;
- return intercept;
+ if (this->n < 1 || this->sxx() == 0.0) {
+ return std::numeric_limits<Float64>::quiet_NaN();
}
+ return (this->sy() - this->sx() * this->sxy() / this->sxx()) /
+ static_cast<Float64>(this->n);
}
};
@@ -147,7 +297,7 @@ public:
const XInputCol* x_nested_column = nullptr;
if constexpr (y_nullable) {
- const ColumnNullable& y_column_nullable =
+ const auto& y_column_nullable =
assert_cast<const ColumnNullable&,
TypeCheckOnRelease::DISABLE>(*columns[0]);
y_null = y_column_nullable.is_null_at(row_num);
y_nested_column = assert_cast<const YInputCol*,
TypeCheckOnRelease::DISABLE>(
@@ -158,7 +308,7 @@ public:
}
if constexpr (x_nullable) {
- const ColumnNullable& x_column_nullable =
+ const auto& x_column_nullable =
assert_cast<const ColumnNullable&,
TypeCheckOnRelease::DISABLE>(*columns[1]);
x_null = x_column_nullable.is_null_at(row_num);
x_nested_column = assert_cast<const XInputCol*,
TypeCheckOnRelease::DISABLE>(
diff --git
a/regression-test/data/query_p0/aggregate/support_type/regr_intercept/regr_intercept.out
b/regression-test/data/query_p0/aggregate/support_type/regr_intercept/regr_intercept.out
index 88a91371f5f..f58aaf4a55f 100644
---
a/regression-test/data/query_p0/aggregate/support_type/regr_intercept/regr_intercept.out
+++
b/regression-test/data/query_p0/aggregate/support_type/regr_intercept/regr_intercept.out
@@ -6,17 +6,17 @@
-990000.0
-- !regr_intercept_int --
-1000001.0
+-9.999E9
-- !regr_intercept_bigint --
\N
-- !regr_intercept_largeint --
-9.999999999999989E19
+1.0E20
-- !regr_intercept_float --
-13.241664047161644
+13.24166404716167
-- !regr_intercept_double --
-58.05515207632899
+58.05515207633332
diff --git
a/regression-test/data/query_p0/aggregate/support_type/regr_slope/regr_slope.out
b/regression-test/data/query_p0/aggregate/support_type/regr_slope/regr_slope.out
index 77140f0d1d3..0e9d13ae71d 100644
---
a/regression-test/data/query_p0/aggregate/support_type/regr_slope/regr_slope.out
+++
b/regression-test/data/query_p0/aggregate/support_type/regr_slope/regr_slope.out
@@ -6,17 +6,17 @@
1.0
-- !regr_slope_int --
--0.0
+1.0
-- !regr_slope_bigint --
\N
-- !regr_slope_largeint --
-17725.127617654194
+0.0
-- !regr_slope_float --
--2.79289213515492
+-2.792892135154929
-- !regr_slope_double --
--0.5501239199999569
+-0.5501239199999999
diff --git
a/regression-test/suites/query_p0/aggregate/test_regr_intercept.groovy
b/regression-test/suites/query_p0/aggregate/test_regr_intercept.groovy
index f7c44642427..10683585309 100644
--- a/regression-test/suites/query_p0/aggregate/test_regr_intercept.groovy
+++ b/regression-test/suites/query_p0/aggregate/test_regr_intercept.groovy
@@ -51,9 +51,9 @@ suite("test_regr_intercept") {
// no value
// agg function without group by should return null
- qt_sql_empty_1 "select regr_intercept(y,x) from test_regr_intercept_int"
+ qt_sql_empty_1 "select regr_intercept(y, x) from test_regr_intercept_int"
// agg function with group by should return empty set
- qt_sql_empty_2 "select regr_intercept(y,x) from test_regr_intercept_int
group by id"
+ qt_sql_empty_2 "select regr_intercept(y, x) from test_regr_intercept_int
group by id"
sql """ TRUNCATE TABLE test_regr_intercept_int """
@@ -83,7 +83,7 @@ suite("test_regr_intercept") {
qt_sql_int_2 "select regr_intercept(x, 4) from test_regr_intercept_int"
// int value
- qt_sql_int_3 "select regr_intercept(y,x) from test_regr_intercept_int"
+ qt_sql_int_3 "select regr_intercept(y, x) from test_regr_intercept_int"
// qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test
non-Nullable input column
qt_sql_int_4 "select regr_intercept(non_nullable(y), non_nullable(x)) from
test_regr_intercept_int"
@@ -122,8 +122,8 @@ suite("test_regr_intercept") {
qt_sql_int_7 "select regr_intercept(x, 4) from test_regr_intercept_int"
// int value
- qt_sql_int_8 "select regr_intercept(y,x) from test_regr_intercept_int"
- qt_sql_int_8 "select regr_intercept(y,x) from test_regr_intercept_int
group by id order by id"
+ qt_sql_int_8 "select regr_intercept(y, x) from test_regr_intercept_int"
+ qt_sql_int_8 "select regr_intercept(y, x) from test_regr_intercept_int
group by id order by id"
// qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test
non-Nullable input column
qt_sql_int_9 "select regr_intercept(non_nullable(y), non_nullable(x)) from
test_regr_intercept_int where id >= 3"
@@ -142,8 +142,8 @@ suite("test_regr_intercept") {
qt_sql_double_2 "select regr_intercept(x, 4) from
test_regr_intercept_double"
// int value
- qt_sql_double_3 "select regr_intercept(y,x) from
test_regr_intercept_double"
- qt_sql_double_3 "select regr_intercept(y,x) from
test_regr_intercept_double group by id order by id"
+ qt_sql_double_3 "select regr_intercept(y, x) from
test_regr_intercept_double"
+ qt_sql_double_3 "select regr_intercept(y, x) from
test_regr_intercept_double group by id order by id"
// qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test
non-Nullable input column
qt_sql_double_4 "select regr_intercept(non_nullable(y), non_nullable(x))
from test_regr_intercept_double"
@@ -183,8 +183,8 @@ suite("test_regr_intercept") {
qt_sql_double_7 "select regr_intercept(x, 4) from
test_regr_intercept_double"
// int value
- qt_sql_double_8 "select regr_intercept(y,x) from
test_regr_intercept_double"
- qt_sql_double_8 "select regr_intercept(y,x) from
test_regr_intercept_double group by id order by id"
+ qt_sql_double_8 "select regr_intercept(y, x) from
test_regr_intercept_double"
+ qt_sql_double_8 "select regr_intercept(y, x) from
test_regr_intercept_double group by id order by id"
// qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test
non-Nullable input column
qt_sql_double_9 "select regr_intercept(non_nullable(y), non_nullable(x))
from test_regr_intercept_double where id >= 3"
diff --git a/regression-test/suites/query_p0/aggregate/test_regr_slope.groovy
b/regression-test/suites/query_p0/aggregate/test_regr_slope.groovy
index 19397036234..0c600710367 100644
--- a/regression-test/suites/query_p0/aggregate/test_regr_slope.groovy
+++ b/regression-test/suites/query_p0/aggregate/test_regr_slope.groovy
@@ -51,9 +51,9 @@ suite("test_regr_slope") {
// no value
// agg function without group by should return null
- qt_sql_empty_1 "select regr_slope(y,x) from test_regr_slope_int"
+ qt_sql_empty_1 "select regr_slope(y, x) from test_regr_slope_int"
// agg function with group by should return empty set
- qt_sql_empty_2 "select regr_slope(y,x) from test_regr_slope_int group by
id"
+ qt_sql_empty_2 "select regr_slope(y, x) from test_regr_slope_int group by
id"
sql """ TRUNCATE TABLE test_regr_slope_int """
@@ -83,7 +83,7 @@ suite("test_regr_slope") {
qt_sql_int_2 "select regr_slope(x, 4) from test_regr_slope_int"
// int value
- qt_sql_int_3 "select regr_slope(y,x) from test_regr_slope_int"
+ qt_sql_int_3 "select regr_slope(y, x) from test_regr_slope_int"
// qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test
non-Nullable input column
qt_sql_int_4 "select regr_slope(non_nullable(y), non_nullable(x)) from
test_regr_slope_int"
@@ -122,8 +122,8 @@ suite("test_regr_slope") {
qt_sql_int_7 "select regr_slope(x, 4) from test_regr_slope_int"
// int value
- qt_sql_int_8 "select regr_slope(y,x) from test_regr_slope_int"
- qt_sql_int_8 "select regr_slope(y,x) from test_regr_slope_int group by id
order by id"
+ qt_sql_int_8 "select regr_slope(y, x) from test_regr_slope_int"
+ qt_sql_int_8 "select regr_slope(y, x) from test_regr_slope_int group by id
order by id"
// qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test
non-Nullable input column
qt_sql_int_9 "select regr_slope(non_nullable(y), non_nullable(x)) from
test_regr_slope_int where id >= 3"
@@ -142,8 +142,8 @@ suite("test_regr_slope") {
qt_sql_double_2 "select regr_slope(x, 4) from test_regr_slope_double"
// int value
- qt_sql_double_3 "select regr_slope(y,x) from test_regr_slope_double"
- qt_sql_double_3 "select regr_slope(y,x) from test_regr_slope_double group
by id order by id"
+ qt_sql_double_3 "select regr_slope(y, x) from test_regr_slope_double"
+ qt_sql_double_3 "select regr_slope(y, x) from test_regr_slope_double group
by id order by id"
// qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test
non-Nullable input column
qt_sql_double_4 "select regr_slope(non_nullable(y), non_nullable(x)) from
test_regr_slope_double"
@@ -183,8 +183,8 @@ suite("test_regr_slope") {
qt_sql_double_7 "select regr_slope(x, 4) from test_regr_slope_double"
// int value
- qt_sql_double_8 "select regr_slope(y,x) from test_regr_slope_double"
- qt_sql_double_8 "select regr_slope(y,x) from test_regr_slope_double group
by id order by id"
+ qt_sql_double_8 "select regr_slope(y, x) from test_regr_slope_double"
+ qt_sql_double_8 "select regr_slope(y, x) from test_regr_slope_double group
by id order by id"
// qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test
non-Nullable input column
qt_sql_double_9 "select regr_slope(non_nullable(y), non_nullable(x)) from
test_regr_slope_double where id >= 3"
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]