This is an automated email from the ASF dual-hosted git repository.
zclll pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 2dee657a397 [Feature](agg) add sem agg function (#57545)
2dee657a397 is described below
commit 2dee657a397fe2aec46976da4a060d58987d6105
Author: admiring_xm <[email protected]>
AuthorDate: Wed Nov 5 11:49:07 2025 +0800
[Feature](agg) add sem agg function (#57545)
add sem agg function
---
.../aggregate_functions/aggregate_function_sem.cpp | 41 ++++++
.../aggregate_functions/aggregate_function_sem.h | 149 +++++++++++++++++++++
.../aggregate_function_simple_factory.cpp | 2 +
.../doris/catalog/BuiltinAggregateFunctions.java | 2 +
.../trees/expressions/functions/agg/Sem.java | 88 ++++++++++++
.../visitor/AggregateFunctionVisitor.java | 5 +
.../test_aggregate_all_functions2.out | 54 ++++++++
.../test_aggregate_all_functions2.groovy | 22 +++
8 files changed, 363 insertions(+)
diff --git a/be/src/vec/aggregate_functions/aggregate_function_sem.cpp
b/be/src/vec/aggregate_functions/aggregate_function_sem.cpp
new file mode 100644
index 00000000000..f446bbe2e57
--- /dev/null
+++ b/be/src/vec/aggregate_functions/aggregate_function_sem.cpp
@@ -0,0 +1,41 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "vec/aggregate_functions/aggregate_function_sem.h"
+
+#include "runtime/define_primitive_type.h"
+#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
+#include "vec/aggregate_functions/helpers.h"
+#include "vec/core/field.h"
+
+namespace doris::vectorized {
+#include "common/compile_check_begin.h"
+
+using AggregateFuncSem = AggregateFunctionSem<AggregateFunctionSemData>;
+
+// Registers the "sem" aggregate function with `factory` through register_function_both.
+void register_aggregate_function_sem(AggregateFunctionSimpleFactory& factory) {
+    // The creator uses nothing from this scope, so it is written without a capture list
+    // default; it is handed to the factory and must not depend on locals of this function.
+    AggregateFunctionCreator creator = [](const std::string& name, const DataTypes& types,
+                                          const bool result_is_nullable,
+                                          const AggregateFunctionAttr& attr) {
+        return creator_without_type::creator<AggregateFuncSem>(name, types, result_is_nullable,
+                                                               attr);
+    };
+    factory.register_function_both("sem", creator);
+}
+
+#include "common/compile_check_end.h"
+} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_sem.h
b/be/src/vec/aggregate_functions/aggregate_function_sem.h
new file mode 100644
index 00000000000..333c4e9c233
--- /dev/null
+++ b/be/src/vec/aggregate_functions/aggregate_function_sem.h
@@ -0,0 +1,149 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <math.h>
+
+#include "runtime/decimalv2_value.h"
+#include "runtime/define_primitive_type.h"
+#include "runtime/primitive_type.h"
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/columns/column.h"
+#include "vec/common/assert_cast.h"
+#include "vec/core/types.h"
+#include "vec/data_types/data_type.h"
+#include "vec/data_types/data_type_decimal.h"
+
+namespace doris::vectorized {
+#include "common/compile_check_begin.h"
+
+class Arena;
+class BufferReadable;
+class BufferWritable;
+
+/**
+ * Running state for SEM (standard error of the mean):
+ *   SEM = sqrt(variance / count) = sqrt(m2 / (count * (count - 1)))
+ * The mean and m2 are maintained with Welford's method, which gives a numerically stable
+ * one-pass computation of the mean and variance.
+ */
+struct AggregateFunctionSemData {
+    double mean {};
+    double m2 {}; // Sum of squared deviations from the running mean (Welford's M2)
+    UInt64 count = 0;
+
+    // Folds one value into the state.
+    // Let the old mean be mean_{n-1} and the new value be x_n:
+    //   mean_n = mean_{n-1} + (x_n - mean_{n-1}) / n
+    //   M2_n   = M2_{n-1} + (x_n - mean_{n-1}) * (x_n - mean_n)
+    void add(const double& value) {
+        count++;
+        double delta = value - mean;
+        mean += delta / static_cast<double>(count);
+        double delta2 = value - mean;
+        m2 += delta * delta2;
+    }
+
+    // Combines the partial state B = rhs (count_b, mean_b, M2_b) into this state A:
+    //   count = count_a + count_b
+    //   mean  = (mean_a * count_a + mean_b * count_b) / count
+    //   M2    = M2_a + M2_b + delta^2 * count_a * count_b / count, with delta = mean_b - mean_a
+    void merge(const AggregateFunctionSemData& rhs) {
+        // An empty rhs carries no observations, so there is nothing to combine. Returning here
+        // also avoids 0 / 0 when both sides are empty, which would store NaN in mean and m2
+        // and propagate through every later add() or merge().
+        if (rhs.count == 0) {
+            return;
+        }
+        UInt64 total_count = count + rhs.count;
+        double delta = rhs.mean - mean;
+        mean = (mean * static_cast<double>(count) + rhs.mean * static_cast<double>(rhs.count)) /
+               static_cast<double>(total_count);
+        m2 += rhs.m2 + delta * delta * static_cast<double>(count) * static_cast<double>(rhs.count) /
+                               static_cast<double>(total_count);
+        count = total_count;
+    }
+
+    // Writes mean, m2 and count to buf, in that order.
+    void write(BufferWritable& buf) const {
+        buf.write_binary(mean);
+        buf.write_binary(m2);
+        buf.write_binary(count);
+    }
+
+    // Reads mean, m2 and count from buf, in the same order write() emits them.
+    void read(BufferReadable& buf) {
+        buf.read_binary(mean);
+        buf.read_binary(m2);
+        buf.read_binary(count);
+    }
+
+    // Returns the state to its initial, empty value.
+    void reset() {
+        mean = {};
+        m2 = {};
+        count = 0;
+    }
+
+    // Returns sqrt(m2 / (count * (count - 1))), or 0 when fewer than two values were seen.
+    double result() const {
+        if (count < 2) {
+            return 0;
+        }
+        const auto n = static_cast<double>(count);
+        return std::sqrt(m2 / (n * (n - 1.0)));
+    }
+};
+
+// Aggregate function "sem". It reads its argument as a ColumnFloat64 and appends one Float64
+// result per state; all arithmetic is delegated to Data (add / merge / write / read / reset /
+// result).
+template <typename Data>
+class AggregateFunctionSem final
+        : public IAggregateFunctionDataHelper<Data, AggregateFunctionSem<Data>>,
+          UnaryExpression,
+          NullableAggregateFunction {
+public:
+    AggregateFunctionSem(const DataTypes& argument_types_)
+            : IAggregateFunctionDataHelper<Data, AggregateFunctionSem<Data>>(argument_types_) {}
+
+    String get_name() const override { return "sem"; }
+
+    DataTypePtr get_return_type() const override { return std::make_shared<DataTypeFloat64>(); }
+
+    // Feeds row `row_num` of the first argument column into the state at `place`.
+    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
+             Arena&) const override {
+        const auto& column =
+                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[0]);
+        this->data(place).add(static_cast<double>(column.get_data()[row_num]));
+    }
+
+    void reset(AggregateDataPtr place) const override { this->data(place).reset(); }
+
+    // Folds the partial state at `rhs` into the state at `place`.
+    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
+               Arena&) const override {
+        this->data(place).merge(this->data(rhs));
+    }
+
+    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
+        this->data(place).write(buf);
+    }
+
+    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
+                     Arena&) const override {
+        this->data(place).read(buf);
+    }
+
+    // Appends the final value of the state at `place` to `to`, which must be a ColumnFloat64.
+    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
+        auto& column = assert_cast<ColumnFloat64&>(to);
+        column.get_data().push_back(this->data(place).result());
+    }
+};
+
+#include "common/compile_check_end.h"
+} // namespace doris::vectorized
\ No newline at end of file
diff --git
a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
index 2fd696f8415..aa4410d347e 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
@@ -79,6 +79,7 @@ void
register_aggregate_function_kurtosis(AggregateFunctionSimpleFactory& factor
void
register_aggregate_function_percentile_reservoir(AggregateFunctionSimpleFactory&
factory);
void register_aggregate_function_ai_agg(AggregateFunctionSimpleFactory&
factory);
void register_aggregate_function_bool_union(AggregateFunctionSimpleFactory&
factory);
+void register_aggregate_function_sem(AggregateFunctionSimpleFactory& factory);
AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() {
static std::once_flag oc;
@@ -135,6 +136,7 @@ AggregateFunctionSimpleFactory&
AggregateFunctionSimpleFactory::instance() {
register_aggregate_function_percentile_reservoir(instance);
register_aggregate_function_ai_agg(instance);
register_aggregate_function_bool_union(instance);
+ register_aggregate_function_sem(instance);
// Register foreach and foreachv2 functions
register_aggregate_function_combinator_foreach(instance);
register_aggregate_function_combinator_foreachv2(instance);
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java
b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java
index eb15275058a..9afebbfa73f 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java
@@ -77,6 +77,7 @@ import
org.apache.doris.nereids.trees.expressions.functions.agg.QuantileUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.RegrIntercept;
import org.apache.doris.nereids.trees.expressions.functions.agg.RegrSlope;
import org.apache.doris.nereids.trees.expressions.functions.agg.Retention;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sem;
import org.apache.doris.nereids.trees.expressions.functions.agg.SequenceCount;
import org.apache.doris.nereids.trees.expressions.functions.agg.SequenceMatch;
import org.apache.doris.nereids.trees.expressions.functions.agg.Skew;
@@ -167,6 +168,7 @@ public class BuiltinAggregateFunctions implements
FunctionHelper {
agg(RegrIntercept.class, "regr_intercept"),
agg(RegrSlope.class, "regr_slope"),
agg(Retention.class, "retention"),
+ agg(Sem.class, "sem"),
agg(SequenceCount.class, "sequence_count"),
agg(SequenceMatch.class, "sequence_match"),
agg(Skew.class, "skew", "skew_pop", "skewness"),
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sem.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sem.java
new file mode 100644
index 00000000000..7e0dbd53787
--- /dev/null
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sem.java
@@ -0,0 +1,88 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.expressions.functions.agg;
+
+import org.apache.doris.catalog.FunctionSignature;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import
org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
+import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
+import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
+import org.apache.doris.nereids.types.DoubleType;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+
+import java.util.List;
+
+/**
+ * Nereids expression for the aggregate function 'sem'.
+ * It declares a single signature: one DOUBLE argument producing a DOUBLE result.
+ */
+
+public class Sem extends NullableAggregateFunction
+        implements UnaryExpression, ExplicitlyCastableSignature {
+
+    public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
+            FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE));
+
+    /**
+     * Builds a non-distinct sem over the given child.
+     */
+    public Sem(Expression child) {
+        this(false, child);
+    }
+
+    /**
+     * Builds a sem over the given argument with an explicit distinct flag.
+     */
+    public Sem(boolean distinct, Expression arg) {
+        this(distinct, false, arg);
+    }
+
+    private Sem(boolean distinct, boolean alwaysNullable, Expression arg) {
+        super("sem", distinct, alwaysNullable, arg);
+    }
+
+    /** Used by the with* methods below to rebuild the function from prepared params. */
+    private Sem(NullableAggregateFunctionParams functionParams) {
+        super(functionParams);
+    }
+
+    @Override
+    public List<FunctionSignature> getSignatures() {
+        return SIGNATURES;
+    }
+
+    /**
+     * Returns a copy with the given distinct flag and children; exactly one child is required.
+     */
+    @Override
+    public Sem withDistinctAndChildren(boolean distinct, List<Expression> children) {
+        Preconditions.checkArgument(children.size() == 1);
+        return new Sem(getFunctionParams(distinct, children));
+    }
+
+    @Override
+    public NullableAggregateFunction withAlwaysNullable(boolean alwaysNullable) {
+        return new Sem(getAlwaysNullableFunctionParams(alwaysNullable));
+    }
+
+    @Override
+    public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
+        return visitor.visitSem(this, context);
+    }
+}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java
index 28af6ecda11..5493bdc44f9 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java
@@ -76,6 +76,7 @@ import
org.apache.doris.nereids.trees.expressions.functions.agg.QuantileUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.RegrIntercept;
import org.apache.doris.nereids.trees.expressions.functions.agg.RegrSlope;
import org.apache.doris.nereids.trees.expressions.functions.agg.Retention;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sem;
import org.apache.doris.nereids.trees.expressions.functions.agg.SequenceCount;
import org.apache.doris.nereids.trees.expressions.functions.agg.SequenceMatch;
import org.apache.doris.nereids.trees.expressions.functions.agg.Skew;
@@ -331,6 +332,10 @@ public interface AggregateFunctionVisitor<R, C> {
return visitNullableAggregateFunction(retention, context);
}
+ default R visitSem(Sem sem, C context) {
+ return visitNullableAggregateFunction(sem, context); // default: handled like any nullable aggregate
+ }
+
default R visitSequenceCount(SequenceCount sequenceCount, C context) {
return visitAggregateFunction(sequenceCount, context);
}
diff --git
a/regression-test/data/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions2.out
b/regression-test/data/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions2.out
index 380729fd459..aad0d7f0286 100644
---
a/regression-test/data/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions2.out
+++
b/regression-test/data/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions2.out
@@ -232,6 +232,60 @@ false
false
true
+-- !sem_bool --
+0.1333333333333333
+
+-- !sem_tinyint --
+1.154700538379251
+
+-- !sem_smallint --
+4529.188787732372
+
+-- !sem_int --
+296380900.8138039
+
+-- !sem_bigint --
+1.272946276726964e+18
+
+-- !sem_largeint --
+1.660407933506217e+37
+
+-- !sem_float --
+5350.392064205735
+
+-- !sem_double --
+66813.32467716311
+
+-- !sem_distinct_double --
+83605.72943255708
+
+-- !sem_window_double --
+\N
+141088.649482848
+19466.9385013329
+19466.9385013329
+19466.9385013329
+141088.649482848
+141088.649482848
+19466.9385013329
+141088.649482848
+141088.649482848
+19466.9385013329
+141088.649482848
+19466.9385013329
+19466.9385013329
+19466.9385013329
+141088.649482848
+
+-- !sem_null_double --
+66813.32467716312
+
+-- !sem_literal_single --
+0
+
+-- !sem_literal_multi --
+0.5773502691896257
+
-- !select_minmax1 --
20200622 1 \N
20200622 2 \N
diff --git
a/regression-test/suites/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions2.groovy
b/regression-test/suites/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions2.groovy
index 9a7c0b570da..5da30a1ae04 100644
---
a/regression-test/suites/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions2.groovy
+++
b/regression-test/suites/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions2.groovy
@@ -124,6 +124,28 @@ suite("test_aggregate_all_functions2") {
exception "requires a boolean or numeric argument"
}
+    // sem() over each numeric column of baseall. The trailing comments give the value recorded
+    // in test_aggregate_all_functions2.out, rounded to 10 significant digits.
+    qt_sem_bool """SELECT sem(k0) FROM baseall;""" // 0.1333333333
+    qt_sem_tinyint """SELECT sem(k1) FROM baseall;""" // 1.154700538
+    qt_sem_smallint """SELECT sem(k2) FROM baseall;""" // 4529.188788
+    qt_sem_int """SELECT sem(k3) FROM baseall;""" // 296380900.8
+    qt_sem_bigint """SELECT sem(k4) FROM baseall;""" // 1272946277000000000
+    qt_sem_largeint """SELECT sem(k13) FROM baseall;""" // 1.660407933e+37
+    qt_sem_float """SELECT sem(k9) FROM baseall;""" // 5350.392064
+    qt_sem_double """SELECT sem(k8) FROM baseall;""" // 66813.32468
+
+    qt_sem_distinct_double """SELECT sem(DISTINCT k8) FROM baseall;"""
+
+    qt_sem_window_double """SELECT sem(k8) OVER(PARTITION BY k6) FROM baseall ORDER BY k1;"""
+
+    qt_sem_null_double """SELECT sem(k8) FROM baseall WHERE k8 IS NULL OR k8 IS NOT NULL;"""
+
+    // Fewer than two input rows yields 0; three rows 1, 2, 3 yield 1 / sqrt(3).
+    qt_sem_literal_single """SELECT sem(CAST(1.0 AS DOUBLE));""" // 0.0
+    qt_sem_literal_multi """SELECT sem(CAST(x AS DOUBLE)) FROM (
+        SELECT 1.0 AS x
+        UNION ALL SELECT 2.0
+        UNION ALL SELECT 3.0
+    ) t;""" // 0.5773502692
+
sql "DROP DATABASE IF EXISTS metric_table"
sql """
CREATE TABLE `metric_table` (
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]