This is an automated email from the ASF dual-hosted git repository.
zclll pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 9d79ab8548e [Feature](function) Support uniform function (#55789)
9d79ab8548e is described below
commit 9d79ab8548e2f363a50947edbf1c83aa50afe791
Author: zclllyybb <[email protected]>
AuthorDate: Tue Sep 9 18:04:47 2025 +0800
[Feature](function) Support uniform function (#55789)
Generate a uniformly distributed random number within the specified range,
using the given value as the seed.
```sql
mysql> select uniform(1, 100, random() * 10000) as result from
numbers("number" = "10");
+--------+
| result |
+--------+
| 92 |
| 89 |
| 64 |
| 36 |
| 34 |
| 30 |
| 15 |
| 82 |
| 49 |
| 78 |
+--------+
```
---
be/src/vec/functions/simple_function_factory.h | 2 +
be/src/vec/functions/uniform.cpp | 187 +++++++++++++++++++++
.../doris/catalog/BuiltinScalarFunctions.java | 4 +-
.../expressions/functions/scalar/Uniform.java | 117 +++++++++++++
.../expressions/visitor/ScalarFunctionVisitor.java | 5 +
.../nereids_function_p0/scalar_function/U.groovy | 47 ++++++
6 files changed, 361 insertions(+), 1 deletion(-)
diff --git a/be/src/vec/functions/simple_function_factory.h
b/be/src/vec/functions/simple_function_factory.h
index bd08c7af81f..906a294d3db 100644
--- a/be/src/vec/functions/simple_function_factory.h
+++ b/be/src/vec/functions/simple_function_factory.h
@@ -82,6 +82,7 @@ void register_function_ifnull(SimpleFunctionFactory& factory);
void register_function_like(SimpleFunctionFactory& factory);
void register_function_regexp(SimpleFunctionFactory& factory);
void register_function_random(SimpleFunctionFactory& factory);
+void register_function_uniform(SimpleFunctionFactory& factory);
void register_function_uuid(SimpleFunctionFactory& factory);
void register_function_uuid_numeric(SimpleFunctionFactory& factory);
void register_function_uuid_transforms(SimpleFunctionFactory& factory);
@@ -300,6 +301,7 @@ public:
register_function_like(instance);
register_function_regexp(instance);
register_function_random(instance);
+ register_function_uniform(instance);
register_function_uuid(instance);
register_function_uuid_numeric(instance);
register_function_uuid_transforms(instance);
diff --git a/be/src/vec/functions/uniform.cpp b/be/src/vec/functions/uniform.cpp
new file mode 100644
index 00000000000..56d85b26199
--- /dev/null
+++ b/be/src/vec/functions/uniform.cpp
@@ -0,0 +1,187 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <fmt/format.h>
+#include <glog/logging.h>
+
+#include <boost/iterator/iterator_facade.hpp>
+#include <cmath>
+#include <cstdint>
+#include <cstdlib>
+#include <memory>
+#include <random>
+#include <utility>
+
+#include "common/status.h"
+#include "runtime/primitive_type.h"
+#include "udf/udf.h"
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/columns/column.h"
+#include "vec/columns/column_const.h"
+#include "vec/columns/column_vector.h"
+#include "vec/common/assert_cast.h"
+#include "vec/core/block.h"
+#include "vec/core/column_numbers.h"
+#include "vec/core/types.h"
+#include "vec/data_types/data_type_number.h" // IWYU pragma: keep
+#include "vec/functions/function.h"
+#include "vec/functions/simple_function_factory.h"
+
+namespace doris::vectorized {
+#include "common/compile_check_begin.h"
+
+// Integer uniform implementation: uniform(BIGINT min, BIGINT max, BIGINT gen) -> BIGINT.
+// Row i gets a value from the closed range [min, max] (std::uniform_int_distribution is
+// inclusive on both ends), drawn from a std::mt19937_64 seeded with gen[i], so equal gen
+// values give equal results. The mapping from engine output to the range is
+// implementation-defined by the standard library, so values may differ between builds
+// that use different standard libraries.
+struct UniformIntImpl {
+    static DataTypes get_variadic_argument_types() {
+        return {std::make_shared<DataTypeInt64>(), std::make_shared<DataTypeInt64>(),
+                std::make_shared<DataTypeInt64>()};
+    }
+
+    static DataTypePtr get_return_type_impl(const ColumnsWithTypeAndName& arguments) {
+        return std::make_shared<DataTypeInt64>();
+    }
+
+    // Writes one draw per row into the result column.
+    // Returns InvalidArgument when min >= max (an empty or single-point range is rejected).
+    static Status execute_impl(FunctionContext* context, Block& block,
+                               const ColumnNumbers& arguments, uint32_t result,
+                               size_t input_rows_count) {
+        // min and max are literals (FunctionUniform::open checks this). NOTE(review): they are
+        // not necessarily wrapped in ColumnConst here - when every argument is constant,
+        // IFunction's default constant handling presumably passes the bare one-row data
+        // columns - so unwrap instead of assert_cast'ing to ColumnConst.
+        const ColumnPtr min_column =
+                unpack_if_const(block.get_by_position(arguments[0]).column).first;
+        const ColumnPtr max_column =
+                unpack_if_const(block.get_by_position(arguments[1]).column).first;
+        const Int64 min = assert_cast<const ColumnInt64&>(*min_column).get_element(0);
+        const Int64 max = assert_cast<const ColumnInt64&>(*max_column).get_element(0);
+
+        if (min >= max) {
+            // The integer range is closed, so report it as [min, max].
+            return Status::InvalidArgument(
+                    "uniform's min should be less than max, but got [{}, {}]", min, max);
+        }
+
+        // gen (the per-row seed) may be a full column or a constant; read the raw values
+        // instead of materialising a Field for every row.
+        const auto& [gen_column, gen_is_const] =
+                unpack_if_const(block.get_by_position(arguments[2]).column);
+        const auto& seeds = assert_cast<const ColumnInt64&>(*gen_column).get_data();
+
+        auto res_column = ColumnInt64::create(input_rows_count);
+        auto& res_data = res_column->get_data();
+        for (size_t i = 0; i < input_rows_count; ++i) {
+            const Int64 seed = seeds[gen_is_const ? 0 : i];
+            // A fresh engine and distribution per row keep each result a pure function of
+            // (min, max, seed), independent of the other rows in the block.
+            std::mt19937_64 generator(static_cast<std::mt19937_64::result_type>(seed));
+            std::uniform_int_distribution<int64_t> distribution(min, max);
+            res_data[i] = distribution(generator);
+        }
+
+        block.replace_by_position(result, std::move(res_column));
+        return Status::OK();
+    }
+};
+
+// Double uniform implementation: uniform(DOUBLE min, DOUBLE max, BIGINT gen) -> DOUBLE.
+// Row i gets a value from the half-open range [min, max) (std::uniform_real_distribution),
+// drawn from a std::mt19937_64 seeded with gen[i], so equal gen values give equal results.
+struct UniformDoubleImpl {
+    static DataTypes get_variadic_argument_types() {
+        return {std::make_shared<DataTypeFloat64>(), std::make_shared<DataTypeFloat64>(),
+                std::make_shared<DataTypeInt64>()};
+    }
+
+    static DataTypePtr get_return_type_impl(const ColumnsWithTypeAndName& arguments) {
+        return std::make_shared<DataTypeFloat64>();
+    }
+
+    // Writes one draw per row into the result column.
+    // Returns InvalidArgument when a bound is NaN/infinite, when max - min overflows, or when
+    // min >= max.
+    static Status execute_impl(FunctionContext* context, Block& block,
+                               const ColumnNumbers& arguments, uint32_t result,
+                               size_t input_rows_count) {
+        // min and max are literals (FunctionUniform::open checks this). NOTE(review): they are
+        // not necessarily wrapped in ColumnConst here - when every argument is constant,
+        // IFunction's default constant handling presumably passes the bare one-row data
+        // columns - so unwrap instead of assert_cast'ing to ColumnConst.
+        const ColumnPtr min_column =
+                unpack_if_const(block.get_by_position(arguments[0]).column).first;
+        const ColumnPtr max_column =
+                unpack_if_const(block.get_by_position(arguments[1]).column).first;
+        const double min = assert_cast<const ColumnFloat64&>(*min_column).get_element(0);
+        const double max = assert_cast<const ColumnFloat64&>(*max_column).get_element(0);
+
+        // std::uniform_real_distribution requires a <= b and b - a <= DBL_MAX. NaN passes the
+        // `min >= max` test below, and infinite bounds or an overflowing width would violate
+        // those preconditions, so reject them up front.
+        if (!std::isfinite(min) || !std::isfinite(max) || !std::isfinite(max - min)) {
+            return Status::InvalidArgument(
+                    "uniform's min and max should be finite and max - min should not overflow, "
+                    "but got [{}, {})",
+                    min, max);
+        }
+        if (min >= max) {
+            return Status::InvalidArgument(
+                    "uniform's min should be less than max, but got [{}, {})", min, max);
+        }
+
+        // gen (the per-row seed) may be a full column or a constant; read the raw values
+        // instead of materialising a Field for every row.
+        const auto& [gen_column, gen_is_const] =
+                unpack_if_const(block.get_by_position(arguments[2]).column);
+        const auto& seeds = assert_cast<const ColumnInt64&>(*gen_column).get_data();
+
+        auto res_column = ColumnFloat64::create(input_rows_count);
+        auto& res_data = res_column->get_data();
+        for (size_t i = 0; i < input_rows_count; ++i) {
+            const Int64 seed = seeds[gen_is_const ? 0 : i];
+            // A fresh engine and distribution per row keep each result a pure function of
+            // (min, max, seed), independent of the other rows in the block.
+            std::mt19937_64 generator(static_cast<std::mt19937_64::result_type>(seed));
+            std::uniform_real_distribution<double> distribution(min, max);
+            res_data[i] = distribution(generator);
+        }
+
+        block.replace_by_position(result, std::move(res_column));
+        return Status::OK();
+    }
+};
+
+// SQL function uniform(min, max, gen). Impl supplies the argument types, the return type and
+// the per-row computation of one overload (see UniformIntImpl / UniformDoubleImpl).
+template <typename Impl>
+class FunctionUniform : public IFunction {
+public:
+    static constexpr auto name = "uniform";
+
+    static FunctionPtr create() { return std::make_shared<FunctionUniform<Impl>>(); }
+
+    String get_name() const override { return name; }
+
+    // The overload's argument types, as declared by Impl.
+    DataTypes get_variadic_argument_types_impl() const override {
+        return Impl::get_variadic_argument_types();
+    }
+
+    // Arity is whatever Impl declares (three for both overloads).
+    size_t get_number_of_arguments() const override {
+        const DataTypes declared = get_variadic_argument_types_impl();
+        return declared.size();
+    }
+
+    DataTypePtr get_return_type_impl(const ColumnsWithTypeAndName& arguments) const override {
+        return Impl::get_return_type_impl(arguments);
+    }
+
+    // Rejects non-constant min/max. The check only runs for FRAGMENT_LOCAL, the scope in which
+    // init_function_context records which argument columns are constant.
+    Status open(FunctionContext* context, FunctionContext::FunctionStateScope scope) override {
+        if (scope != FunctionContext::FRAGMENT_LOCAL) {
+            return Status::OK();
+        }
+        if (!context->is_col_constant(0)) {
+            return Status::InvalidArgument(
+                    "The first parameter (min) of uniform function must be literal");
+        }
+        if (!context->is_col_constant(1)) {
+            return Status::InvalidArgument(
+                    "The second parameter (max) of uniform function must be literal");
+        }
+        return Status::OK();
+    }
+
+    // All row-level work is delegated to the overload implementation.
+    Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
+                        uint32_t result, size_t input_rows_count) const override {
+        return Impl::execute_impl(context, block, arguments, result, input_rows_count);
+    }
+};
+
+// Registers both overloads of uniform(). The two instantiations share the name "uniform" and
+// differ in the argument types reported by get_variadic_argument_types_impl():
+// (BIGINT, BIGINT, BIGINT) and (DOUBLE, DOUBLE, BIGINT).
+void register_function_uniform(SimpleFunctionFactory& factory) {
+    using UniformOverBigInt = FunctionUniform<UniformIntImpl>;
+    using UniformOverDouble = FunctionUniform<UniformDoubleImpl>;
+    factory.register_function<UniformOverBigInt>();
+    factory.register_function<UniformOverDouble>();
+}
+#include "common/compile_check_end.h"
+} // namespace doris::vectorized
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java
b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java
index 2116a6d3821..dbad709fff0 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java
@@ -487,6 +487,7 @@ import
org.apache.doris.nereids.trees.expressions.functions.scalar.Truncate;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Uncompress;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Unhex;
import org.apache.doris.nereids.trees.expressions.functions.scalar.UnhexNull;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.Uniform;
import
org.apache.doris.nereids.trees.expressions.functions.scalar.UnixTimestamp;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Upper;
import org.apache.doris.nereids.trees.expressions.functions.scalar.UrlDecode;
@@ -1035,7 +1036,8 @@ public class BuiltinScalarFunctions implements
FunctionHelper {
scalar(AIMask.class, "ai_mask"),
scalar(AISummarize.class, "ai_summarize"),
scalar(AISimilarity.class, "ai_similarity"),
- scalar(Embed.class, "embed"));
+ scalar(Embed.class, "embed"),
+ scalar(Uniform.class, "uniform"));
public static final BuiltinScalarFunctions INSTANCE = new
BuiltinScalarFunctions();
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Uniform.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Uniform.java
new file mode 100644
index 00000000000..fbae666c55f
--- /dev/null
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Uniform.java
@@ -0,0 +1,117 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.expressions.functions.scalar;
+
+import org.apache.doris.catalog.FunctionSignature;
+import org.apache.doris.nereids.exceptions.AnalysisException;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import
org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
+import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
+import org.apache.doris.nereids.types.BigIntType;
+import org.apache.doris.nereids.types.DoubleType;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+
+import java.util.List;
+
+/**
+ * ScalarFunction 'uniform'. This function generates uniform random numbers.
+ * Signature: UNIFORM(min, max, gen)
+ * - min, max: literal values defining the range [min, max]
+ * - gen: expression used as seed for random generation
+ * - If min/max are both integers, returns integer; otherwise returns double
+ */
+public class Uniform extends ScalarFunction
+ implements ExplicitlyCastableSignature {
+
+ public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
+
FunctionSignature.ret(BigIntType.INSTANCE).args(BigIntType.INSTANCE,
BigIntType.INSTANCE,
+ BigIntType.INSTANCE),
+
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE,
DoubleType.INSTANCE,
+ BigIntType.INSTANCE));
+
+ /**
+ * constructor with 3 arguments.
+ */
+ public Uniform(Expression min, Expression max, Expression gen) {
+ super("uniform", min, max, gen);
+ }
+
+ /** constructor for withChildren and reuse signature */
+ private Uniform(ScalarFunctionParams functionParams) {
+ super(functionParams);
+ }
+
+ @Override
+ public void checkLegalityBeforeTypeCoercion() {
+ if (!child(0).isLiteral()) {
+ throw new AnalysisException("The first parameter (min) of uniform
function must be literal");
+ }
+ if (!child(1).isLiteral()) {
+ throw new AnalysisException("The second parameter (max) of uniform
function must be literal");
+ }
+ // if do folding on BE, will before checkLegalityAfterRewrite, so we
need it here too.
+ checkLegalityAfterRewrite();
+ }
+
+ @Override
+ public void checkLegalityAfterRewrite() {
+ if (child(2).isLiteral()) {
+ throw new AnalysisException("The third parameter (gen) of uniform
function must not be literal");
+ }
+ }
+
+ @Override
+ public FunctionSignature computeSignature(FunctionSignature signature) {
+ if (child(0).getDataType().isIntegralType() &&
child(1).getDataType().isIntegralType()) {
+ // both integer, prefer integer return type
+ return SIGNATURES.get(0);
+ } else {
+ // otherwise, prefer double return type
+ return SIGNATURES.get(1);
+ }
+ }
+
+ /**
+ * custom compute nullable.
+ */
+ @Override
+ public boolean nullable() {
+ return children().stream().anyMatch(Expression::nullable);
+ }
+
+ /**
+ * withChildren.
+ */
+ @Override
+ public Uniform withChildren(List<Expression> children) {
+ Preconditions.checkArgument(children.size() == 3);
+ return new Uniform(getFunctionParams(children));
+ }
+
+ @Override
+ public List<FunctionSignature> getSignatures() {
+ return SIGNATURES;
+ }
+
+ @Override
+ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
+ return visitor.visitUniform(this, context);
+ }
+}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java
index 7ed4adb87ea..d78d5bb7ff8 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java
@@ -488,6 +488,7 @@ import
org.apache.doris.nereids.trees.expressions.functions.scalar.Truncate;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Uncompress;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Unhex;
import org.apache.doris.nereids.trees.expressions.functions.scalar.UnhexNull;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.Uniform;
import
org.apache.doris.nereids.trees.expressions.functions.scalar.UnixTimestamp;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Upper;
import org.apache.doris.nereids.trees.expressions.functions.scalar.UrlDecode;
@@ -2545,4 +2546,8 @@ public interface ScalarFunctionVisitor<R, C> {
default R visitEmbed(Embed embed, C context) {
return visitScalarFunction(embed, context);
}
+
+    /**
+     * Visits a {@link Uniform} call. There is no dedicated handling by default: it is
+     * forwarded to {@code visitScalarFunction}.
+     */
+    default R visitUniform(Uniform uniform, C context) {
+        return visitScalarFunction(uniform, context);
+    }
}
diff --git
a/regression-test/suites/nereids_function_p0/scalar_function/U.groovy
b/regression-test/suites/nereids_function_p0/scalar_function/U.groovy
index 47133a0a08b..4105389bc6f 100644
--- a/regression-test/suites/nereids_function_p0/scalar_function/U.groovy
+++ b/regression-test/suites/nereids_function_p0/scalar_function/U.groovy
@@ -47,4 +47,51 @@ suite("nereids_scalar_fn_U") {
qt_sql_url_decode_empty "select url_decode('');"
qt_sql_url_decode_null "select url_decode(null);"
qt_sql_url_decode_invalid_url "select url_decode('This is not a url');"
+
+    // BIGINT overload: one value per row, each inside the closed range [1, 100]
+    def result = sql """select uniform(1, 100, random()*10000) from numbers("number" = "10");"""
+    assertTrue(result.size() == 10)
+    result.each { row -> assertTrue(row[0] >= 1 && row[0] <= 100) }
+
+    // DOUBLE overload (a non-integral bound): each value inside [1.5, 2.5]
+    def doubleResult = sql """select uniform(1.5, 2.5, number) from numbers("number" = "10");"""
+    assertTrue(doubleResult.size() == 10)
+    doubleResult.each { row -> assertTrue(row[0] >= 1.5 && row[0] <= 2.5) }
+
+    // the result depends only on (min, max, gen), so identical seeds give identical values
+    def firstRun = sql """select uniform(1, 100, number) from numbers("number" = "10") order by number;"""
+    def secondRun = sql """select uniform(1, 100, number) from numbers("number" = "10") order by number;"""
+    assertEquals(firstRun, secondRun)
+
+    test {
+        sql """select uniform(100, 1, random()*10000) from numbers("number" = "10");"""
+        exception "uniform's min should be less than max"
+    }
+    // a literal gen is rejected by FE whether constants are folded by FE or by BE;
+    // min < max here so no other check can be the source of the error
+    test {
+        sql """select uniform(1, 100, 1) from numbers("number" = "10");"""
+        exception "The third parameter (gen) of uniform function must not be literal"
+    }
+    sql "set enable_fold_constant_by_be=true;"
+    test {
+        sql """select uniform(1, 100, 1) from numbers("number" = "10");"""
+        exception "The third parameter (gen) of uniform function must not be literal"
+    }
+    sql "set enable_fold_constant_by_be=false;"
+    test {
+        sql """select uniform(ksint, 1, random()) from fn_test;"""
+        exception "The first parameter (min) of uniform function must be literal"
+    }
+    test {
+        sql """select uniform(1, kint, random()) from fn_test;"""
+        exception "The second parameter (max) of uniform function must be literal"
+    }
+    // gen may be any non-literal expression, including a column produced by a subquery
+    sql """ select uniform(1, 100, v.x) from (select random() * 10000 as x from numbers("number" = "10")) v; """
+    sql """ select uniform(1, 100, kdbl) from (select kdbl from fn_test) v; """
+    explain {
+        sql """select uniform(1, 100.100, random()*10000) as result from numbers("number" = "10");"""
+        checkSlotTypeOf("result", "double")
+    }
+    explain {
+        sql """select uniform(1.23, 100.100, random()*10000) as result from numbers("number" = "10");"""
+        checkSlotTypeOf("result", "double")
+    }
+    explain {
+        sql """select uniform(1, 100, random()*10000) as result from numbers("number" = "10");"""
+        checkSlotTypeOf("result", "bigint")
+    }
}
\ No newline at end of file
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]