This is an automated email from the ASF dual-hosted git repository.

zclll pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 9d79ab8548e [Feature](function) Support uniform function (#55789)
9d79ab8548e is described below

commit 9d79ab8548e2f363a50947edbf1c83aa50afe791
Author: zclllyybb <[email protected]>
AuthorDate: Tue Sep 9 18:04:47 2025 +0800

    [Feature](function) Support uniform function (#55789)
    
    Generate a uniformly distributed random number within a given range,
    seeded per row by the given expression.
    
    ```sql
    mysql> select uniform(1, 100, random() * 10000) as result from numbers("number" = "10");
    +--------+
    | result |
    +--------+
    |     92 |
    |     89 |
    |     64 |
    |     36 |
    |     34 |
    |     30 |
    |     15 |
    |     82 |
    |     49 |
    |     78 |
    +--------+
    ```
---
 be/src/vec/functions/simple_function_factory.h     |   2 +
 be/src/vec/functions/uniform.cpp                   | 187 +++++++++++++++++++++
 .../doris/catalog/BuiltinScalarFunctions.java      |   4 +-
 .../expressions/functions/scalar/Uniform.java      | 117 +++++++++++++
 .../expressions/visitor/ScalarFunctionVisitor.java |   5 +
 .../nereids_function_p0/scalar_function/U.groovy   |  47 ++++++
 6 files changed, 361 insertions(+), 1 deletion(-)

diff --git a/be/src/vec/functions/simple_function_factory.h 
b/be/src/vec/functions/simple_function_factory.h
index bd08c7af81f..906a294d3db 100644
--- a/be/src/vec/functions/simple_function_factory.h
+++ b/be/src/vec/functions/simple_function_factory.h
@@ -82,6 +82,7 @@ void register_function_ifnull(SimpleFunctionFactory& factory);
 void register_function_like(SimpleFunctionFactory& factory);
 void register_function_regexp(SimpleFunctionFactory& factory);
 void register_function_random(SimpleFunctionFactory& factory);
+void register_function_uniform(SimpleFunctionFactory& factory);
 void register_function_uuid(SimpleFunctionFactory& factory);
 void register_function_uuid_numeric(SimpleFunctionFactory& factory);
 void register_function_uuid_transforms(SimpleFunctionFactory& factory);
@@ -300,6 +301,7 @@ public:
             register_function_like(instance);
             register_function_regexp(instance);
             register_function_random(instance);
+            register_function_uniform(instance);
             register_function_uuid(instance);
             register_function_uuid_numeric(instance);
             register_function_uuid_transforms(instance);
diff --git a/be/src/vec/functions/uniform.cpp b/be/src/vec/functions/uniform.cpp
new file mode 100644
index 00000000000..56d85b26199
--- /dev/null
+++ b/be/src/vec/functions/uniform.cpp
@@ -0,0 +1,187 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <fmt/format.h>
+#include <glog/logging.h>
+
+#include <boost/iterator/iterator_facade.hpp>
+#include <cstdint>
+#include <cstdlib>
+#include <memory>
+#include <random>
+#include <utility>
+
+#include "common/status.h"
+#include "runtime/primitive_type.h"
+#include "udf/udf.h"
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/columns/column.h"
+#include "vec/columns/column_vector.h"
+#include "vec/common/assert_cast.h"
+#include "vec/core/block.h"
+#include "vec/core/column_numbers.h"
+#include "vec/core/types.h"
+#include "vec/data_types/data_type_number.h" // IWYU pragma: keep
+#include "vec/functions/function.h"
+#include "vec/functions/simple_function_factory.h"
+
+namespace doris::vectorized {
+#include "common/compile_check_begin.h"
+
+// Integer overload of uniform(min, max, gen): (Int64, Int64, Int64) -> Int64.
+//
+// `min` and `max` are read from constant columns and must satisfy min < max.
+// Each row seeds its own std::mt19937_64 with that row's `gen` value, so the
+// result for a row is fully determined by (min, max, gen).
+struct UniformIntImpl {
+    // Argument types that select this overload.
+    static DataTypes get_variadic_argument_types() {
+        return {std::make_shared<DataTypeInt64>(), std::make_shared<DataTypeInt64>(),
+                std::make_shared<DataTypeInt64>()};
+    }
+
+    // The result type is always Int64; `arguments` is not inspected.
+    static DataTypePtr get_return_type_impl(const ColumnsWithTypeAndName& arguments) {
+        return std::make_shared<DataTypeInt64>();
+    }
+
+    // Writes one Int64 per input row into block[result].
+    // Returns InvalidArgument when min >= max, OK otherwise.
+    static Status execute_impl(FunctionContext* context, Block& block,
+                               const ColumnNumbers& arguments, uint32_t result,
+                               size_t input_rows_count) {
+        auto res_column = ColumnInt64::create(input_rows_count);
+        auto& res_data = res_column->get_data();
+
+        // min and max arrive as ColumnConst; unwrap them and read the single stored value.
+        const auto& left =
+                assert_cast<const ColumnConst&>(*block.get_by_position(arguments[0]).column)
+                        .get_data_column();
+        const auto& right =
+                assert_cast<const ColumnConst&>(*block.get_by_position(arguments[1]).column)
+                        .get_data_column();
+        const Int64 min = assert_cast<const ColumnInt64&>(left).get_element(0);
+        const Int64 max = assert_cast<const ColumnInt64&>(right).get_element(0);
+
+        if (min >= max) {
+            return Status::InvalidArgument(
+                    "uniform's min should be less than max, but got [{}, {})", min, max);
+        }
+
+        // Per-row seed values.
+        const auto& gen_column = block.get_by_position(arguments[2]).column;
+
+        // The distribution depends only on (min, max), so it is built once for all rows.
+        std::uniform_int_distribution<int64_t> distribution(min, max);
+        // size_t index: matches the type of input_rows_count (no signed/unsigned comparison).
+        for (size_t i = 0; i < input_rows_count; ++i) {
+            // A fresh engine per row, seeded with that row's gen value; the signed seed is
+            // converted to the engine's unsigned result type explicitly.
+            const auto seed = static_cast<uint64_t>((*gen_column)[i].get<Int64>());
+            std::mt19937_64 generator(seed);
+            res_data[i] = distribution(generator);
+        }
+
+        block.replace_by_position(result, std::move(res_column));
+        return Status::OK();
+    }
+};
+
+// Floating-point overload of uniform(min, max, gen): (Float64, Float64, Int64) -> Float64.
+//
+// `min` and `max` are read from constant columns and must satisfy min < max.
+// Each row seeds its own std::mt19937_64 with that row's `gen` value and draws one
+// value from std::uniform_real_distribution<double>(min, max).
+struct UniformDoubleImpl {
+    // Argument types that select this overload.
+    static DataTypes get_variadic_argument_types() {
+        return {std::make_shared<DataTypeFloat64>(), std::make_shared<DataTypeFloat64>(),
+                std::make_shared<DataTypeInt64>()};
+    }
+
+    // The result type is always Float64; `arguments` is not inspected.
+    static DataTypePtr get_return_type_impl(const ColumnsWithTypeAndName& arguments) {
+        return std::make_shared<DataTypeFloat64>();
+    }
+
+    // Writes one Float64 per input row into block[result].
+    // Returns InvalidArgument unless min < max (this also rejects NaN bounds).
+    static Status execute_impl(FunctionContext* context, Block& block,
+                               const ColumnNumbers& arguments, uint32_t result,
+                               size_t input_rows_count) {
+        auto res_column = ColumnFloat64::create(input_rows_count);
+        auto& res_data = res_column->get_data();
+
+        // min and max arrive as ColumnConst; unwrap them and read the single stored value.
+        const auto& left =
+                assert_cast<const ColumnConst&>(*block.get_by_position(arguments[0]).column)
+                        .get_data_column();
+        const auto& right =
+                assert_cast<const ColumnConst&>(*block.get_by_position(arguments[1]).column)
+                        .get_data_column();
+        const double min = assert_cast<const ColumnFloat64&>(left).get_element(0);
+        const double max = assert_cast<const ColumnFloat64&>(right).get_element(0);
+
+        // Written as !(min < max) rather than (min >= max): every ordered comparison with
+        // NaN is false, so the old form let a NaN bound through to the distribution.
+        // NOTE(review): infinite bounds, or a span (max - min) that overflows double, are
+        // not checked here although the distribution requires a finite span -- confirm
+        // whether such literals can reach this point.
+        if (!(min < max)) {
+            return Status::InvalidArgument(
+                    "uniform's min should be less than max, but got [{}, {})", min, max);
+        }
+
+        // Per-row seed values.
+        const auto& gen_column = block.get_by_position(arguments[2]).column;
+
+        // The distribution depends only on (min, max), so it is built once for all rows.
+        std::uniform_real_distribution<double> distribution(min, max);
+        // size_t index: matches the type of input_rows_count (no signed/unsigned comparison).
+        for (size_t i = 0; i < input_rows_count; ++i) {
+            // A fresh engine per row, seeded with that row's gen value; the signed seed is
+            // converted to the engine's unsigned result type explicitly.
+            const auto seed = static_cast<uint64_t>((*gen_column)[i].get<Int64>());
+            std::mt19937_64 generator(seed);
+            res_data[i] = distribution(generator);
+        }
+
+        block.replace_by_position(result, std::move(res_column));
+        return Status::OK();
+    }
+};
+
+// IFunction adapter for "uniform". Everything type-specific (argument types, return type,
+// row evaluation) is forwarded to the static members of `Impl`.
+template <typename Impl>
+class FunctionUniform : public IFunction {
+public:
+    static constexpr auto name = "uniform";
+
+    static FunctionPtr create() { return std::make_shared<FunctionUniform>(); }
+
+    String get_name() const override { return name; }
+
+    // The arity equals the length of Impl's variadic argument type list.
+    size_t get_number_of_arguments() const override {
+        return Impl::get_variadic_argument_types().size();
+    }
+
+    DataTypes get_variadic_argument_types_impl() const override {
+        return Impl::get_variadic_argument_types();
+    }
+
+    DataTypePtr get_return_type_impl(const ColumnsWithTypeAndName& arguments) const override {
+        return Impl::get_return_type_impl(arguments);
+    }
+
+    // Rejects non-constant min/max. The check runs only for the FRAGMENT_LOCAL scope,
+    // because that is the scope for which init_function_context sets the constant columns.
+    Status open(FunctionContext* context, FunctionContext::FunctionStateScope scope) override {
+        if (scope != FunctionContext::FRAGMENT_LOCAL) {
+            return Status::OK();
+        }
+        if (!context->is_col_constant(0)) {
+            return Status::InvalidArgument(
+                    "The first parameter (min) of uniform function must be literal");
+        }
+        if (!context->is_col_constant(1)) {
+            return Status::InvalidArgument(
+                    "The second parameter (max) of uniform function must be literal");
+        }
+        return Status::OK();
+    }
+
+    // Row evaluation is delegated to Impl unchanged.
+    Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
+                        uint32_t result, size_t input_rows_count) const override {
+        return Impl::execute_impl(context, block, arguments, result, input_rows_count);
+    }
+};
+
+void register_function_uniform(SimpleFunctionFactory& factory) { // registers both "uniform" overloads
+    factory.register_function<FunctionUniform<UniformIntImpl>>(); // (Int64, Int64, Int64) -> Int64
+    factory.register_function<FunctionUniform<UniformDoubleImpl>>(); // (Float64, Float64, Int64) -> Float64
+}
+#include "common/compile_check_end.h"
+} // namespace doris::vectorized
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java 
b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java
index 2116a6d3821..dbad709fff0 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java
@@ -487,6 +487,7 @@ import 
org.apache.doris.nereids.trees.expressions.functions.scalar.Truncate;
 import org.apache.doris.nereids.trees.expressions.functions.scalar.Uncompress;
 import org.apache.doris.nereids.trees.expressions.functions.scalar.Unhex;
 import org.apache.doris.nereids.trees.expressions.functions.scalar.UnhexNull;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.Uniform;
 import 
org.apache.doris.nereids.trees.expressions.functions.scalar.UnixTimestamp;
 import org.apache.doris.nereids.trees.expressions.functions.scalar.Upper;
 import org.apache.doris.nereids.trees.expressions.functions.scalar.UrlDecode;
@@ -1035,7 +1036,8 @@ public class BuiltinScalarFunctions implements 
FunctionHelper {
             scalar(AIMask.class, "ai_mask"),
             scalar(AISummarize.class, "ai_summarize"),
             scalar(AISimilarity.class, "ai_similarity"),
-            scalar(Embed.class, "embed"));
+            scalar(Embed.class, "embed"),
+            scalar(Uniform.class, "uniform"));
 
     public static final BuiltinScalarFunctions INSTANCE = new 
BuiltinScalarFunctions();
 
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Uniform.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Uniform.java
new file mode 100644
index 00000000000..fbae666c55f
--- /dev/null
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Uniform.java
@@ -0,0 +1,117 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.expressions.functions.scalar;
+
+import org.apache.doris.catalog.FunctionSignature;
+import org.apache.doris.nereids.exceptions.AnalysisException;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import 
org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
+import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
+import org.apache.doris.nereids.types.BigIntType;
+import org.apache.doris.nereids.types.DoubleType;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+
+import java.util.List;
+
+/**
+ * ScalarFunction 'uniform'. Generates a uniformly distributed random number.
+ * Signature: UNIFORM(min, max, gen)
+ * - min, max: literal bounds of the range; the backend requires min to be less than max
+ * - gen: non-literal expression whose per-row value is used as the random seed
+ * - If min and max are both integral, the BIGINT signature is used (integer result drawn from
+ *   [min, max]); otherwise the DOUBLE signature is used (double result drawn from [min, max))
+ */
+public class Uniform extends ScalarFunction
+        implements ExplicitlyCastableSignature {
+
+    /** Index 0 is the all-BIGINT signature, index 1 the DOUBLE one; computeSignature relies on this order. */
+    public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
+            FunctionSignature.ret(BigIntType.INSTANCE).args(BigIntType.INSTANCE, BigIntType.INSTANCE,
+                    BigIntType.INSTANCE),
+            FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE,
+                    BigIntType.INSTANCE));
+
+    /**
+     * constructor with 3 arguments.
+     */
+    public Uniform(Expression min, Expression max, Expression gen) {
+        super("uniform", min, max, gen);
+    }
+
+    /** constructor for withChildren and reuse signature */
+    private Uniform(ScalarFunctionParams functionParams) {
+        super(functionParams);
+    }
+
+    /**
+     * Requires min (child 0) and max (child 1) to be literals, then also applies the gen check.
+     *
+     * @throws AnalysisException if min or max is not a literal, or if gen is a literal
+     */
+    @Override
+    public void checkLegalityBeforeTypeCoercion() {
+        if (!child(0).isLiteral()) {
+            throw new AnalysisException("The first parameter (min) of uniform function must be literal");
+        }
+        if (!child(1).isLiteral()) {
+            throw new AnalysisException("The second parameter (max) of uniform function must be literal");
+        }
+        // Constant folding on BE happens before checkLegalityAfterRewrite is reached,
+        // so the gen check has to run here as well.
+        checkLegalityAfterRewrite();
+    }
+
+    /**
+     * Requires gen (child 2) to be a non-literal expression.
+     *
+     * @throws AnalysisException if gen is a literal
+     */
+    @Override
+    public void checkLegalityAfterRewrite() {
+        if (child(2).isLiteral()) {
+            throw new AnalysisException("The third parameter (gen) of uniform function must not be literal");
+        }
+    }
+
+    /**
+     * Chooses the signature from the types of min and max only; the passed-in candidate is ignored.
+     */
+    @Override
+    public FunctionSignature computeSignature(FunctionSignature signature) {
+        if (child(0).getDataType().isIntegralType() && child(1).getDataType().isIntegralType()) {
+            // both integer, prefer integer return type
+            return SIGNATURES.get(0);
+        } else {
+            // otherwise, prefer double return type
+            return SIGNATURES.get(1);
+        }
+    }
+
+    /**
+     * custom compute nullable: the result is nullable when any argument is nullable.
+     */
+    @Override
+    public boolean nullable() {
+        return children().stream().anyMatch(Expression::nullable);
+    }
+
+    /**
+     * withChildren. Exactly three children are required.
+     */
+    @Override
+    public Uniform withChildren(List<Expression> children) {
+        Preconditions.checkArgument(children.size() == 3);
+        return new Uniform(getFunctionParams(children));
+    }
+
+    @Override
+    public List<FunctionSignature> getSignatures() {
+        return SIGNATURES;
+    }
+
+    @Override
+    public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
+        return visitor.visitUniform(this, context);
+    }
+}
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java
index 7ed4adb87ea..d78d5bb7ff8 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java
@@ -488,6 +488,7 @@ import 
org.apache.doris.nereids.trees.expressions.functions.scalar.Truncate;
 import org.apache.doris.nereids.trees.expressions.functions.scalar.Uncompress;
 import org.apache.doris.nereids.trees.expressions.functions.scalar.Unhex;
 import org.apache.doris.nereids.trees.expressions.functions.scalar.UnhexNull;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.Uniform;
 import 
org.apache.doris.nereids.trees.expressions.functions.scalar.UnixTimestamp;
 import org.apache.doris.nereids.trees.expressions.functions.scalar.Upper;
 import org.apache.doris.nereids.trees.expressions.functions.scalar.UrlDecode;
@@ -2545,4 +2546,8 @@ public interface ScalarFunctionVisitor<R, C> {
     default R visitEmbed(Embed embed, C context) {
         return visitScalarFunction(embed, context);
     }
+
+    default R visitUniform(Uniform uniform, C context) { // default: no Uniform-specific handling
+        return visitScalarFunction(uniform, context);
+    }
 }
diff --git 
a/regression-test/suites/nereids_function_p0/scalar_function/U.groovy 
b/regression-test/suites/nereids_function_p0/scalar_function/U.groovy
index 47133a0a08b..4105389bc6f 100644
--- a/regression-test/suites/nereids_function_p0/scalar_function/U.groovy
+++ b/regression-test/suites/nereids_function_p0/scalar_function/U.groovy
@@ -47,4 +47,51 @@ suite("nereids_scalar_fn_U") {
     qt_sql_url_decode_empty "select url_decode('');"
     qt_sql_url_decode_null "select url_decode(null);"
     qt_sql_url_decode_invalid_url "select url_decode('This is not a url');"
+
+    // uniform(min, max, gen): one row per input row, every integer result inside [1, 100]
+    def result = sql """select uniform(1, 100, random()*10000) from numbers("number" = "10");"""
+    assertTrue(result.size() == 10)
+    result.each { row -> assertTrue(row[0] >= 1 && row[0] <= 100) }
+
+    // min must be strictly less than max
+    test {
+        sql """select uniform(100, 1, random()*10000) from numbers("number" = "10");"""
+        exception "uniform's min should be less than max"
+    }
+    test {
+        sql """select uniform(1, 1, random()*10000) from numbers("number" = "10");"""
+        exception "uniform's min should be less than max"
+    }
+
+    // gen must not be a literal, with the session default and with BE constant folding enabled
+    test {
+        sql """select uniform(100, 1, 1) from numbers("number" = "10");"""
+        exception "The third parameter (gen) of uniform function must not be literal"
+    }
+    sql "set enable_fold_constant_by_be=true;"
+    test {
+        sql """select uniform(100, 1, 1) from numbers("number" = "10");"""
+        exception "The third parameter (gen) of uniform function must not be literal"
+    }
+    sql "set enable_fold_constant_by_be=false;"
+
+    // min and max must be literals
+    test {
+        sql """select uniform(ksint, 1, random()) from fn_test;"""
+        exception "The first parameter (min) of uniform function must be literal"
+    }
+    test {
+        sql """select uniform(1, kint, random()) from fn_test;"""
+        exception "The second parameter (max) of uniform function must be literal"
+    }
+
+    // gen may come from a subquery column or a table column
+    sql """ select uniform(1, 100, v.x) from (select random() * 10000 as x from numbers("number" = "10")) v; """
+    sql """ select uniform(1, 100, kdbl) from (select kdbl from fn_test) v; """
+
+    // result type: double when either bound is non-integral, bigint when both are integral
+    explain {
+        sql """select uniform(1, 100.100, random()*10000) as result from numbers("number" = "10");"""
+        checkSlotTypeOf("result", "double")
+    }
+    explain {
+        sql """select uniform(1.23, 100.100, random()*10000) as result from numbers("number" = "10");"""
+        checkSlotTypeOf("result", "double")
+    }
+    explain {
+        sql """select uniform(1, 100, random()*10000) as result from numbers("number" = "10");"""
+        checkSlotTypeOf("result", "bigint")
+    }
 }
\ No newline at end of file


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to