This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new af8d027b969 branch-4.0: [feat](func) Support two-args version of atan 
#56561 (#56860)
af8d027b969 is described below

commit af8d027b9690489c703fb1c83d87c01144f2f6ed
Author: github-actions[bot] 
<41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Sun Oct 12 12:06:32 2025 +0800

    branch-4.0: [feat](func) Support two-args version of atan #56561 (#56860)
    
    Cherry-picked from #56561
    
    Co-authored-by: linrrarity <[email protected]>
---
 be/src/vec/functions/math.cpp                      |  80 +++++++++++-
 .../functions/executable/NumericArithmetic.java    |   8 ++
 .../trees/expressions/functions/scalar/Atan.java   |  12 +-
 .../data/function_p0/test_math_function.out        | 127 ++++++++++++++++++
 .../suites/function_p0/test_math_function.groovy   | 142 +++++++++++++++++++++
 5 files changed, 365 insertions(+), 4 deletions(-)

diff --git a/be/src/vec/functions/math.cpp b/be/src/vec/functions/math.cpp
index a8ae52fa75a..87117b80f2c 100644
--- a/be/src/vec/functions/math.cpp
+++ b/be/src/vec/functions/math.cpp
@@ -71,10 +71,86 @@ struct AsinhName {
 };
 using FunctionAsinh = FunctionMathUnary<UnaryFunctionPlain<AsinhName, 
std::asinh>>;
 
-struct AtanName {
+class FunctionAtan : public IFunction {
+public:
     static constexpr auto name = "atan";
+    static FunctionPtr create() { return std::make_shared<FunctionAtan>(); }
+
+    String get_name() const override { return name; }
+    bool is_variadic() const override { return true; }
+    size_t get_number_of_arguments() const override { return 0; }
+
+    DataTypePtr get_return_type_impl(const DataTypes& arguments) const 
override {
+        return std::make_shared<DataTypeFloat64>();
+    }
+
+    Status execute_impl(FunctionContext* context, Block& block, const 
ColumnNumbers& arguments,
+                        uint32_t result, size_t input_rows_count) const 
override {
+        if (arguments.size() == 1) {
+            return execute_unary(block, arguments, result, input_rows_count);
+        } else if (arguments.size() == 2) {
+            return execute_binary(block, arguments, result, input_rows_count);
+        } else {
+            return Status::InvalidArgument("atan function expects 1 or 2 
arguments, but got {}",
+                                           arguments.size());
+        }
+    }
+
+private:
+    Status execute_unary(Block& block, const ColumnNumbers& arguments, 
uint32_t result,
+                         size_t input_rows_count) const {
+        auto res_col = ColumnFloat64::create(input_rows_count);
+        auto& res_data = res_col->get_data();
+
+        const auto& col_data =
+                assert_cast<const 
ColumnFloat64*>(block.get_by_position(arguments[0]).column.get())
+                        ->get_data();
+        for (size_t i = 0; i < input_rows_count; ++i) {
+            res_data[i] = std::atan(col_data[i]);
+        }
+
+        block.replace_by_position(result, std::move(res_col));
+        return Status::OK();
+    }
+
+    Status execute_binary(Block& block, const ColumnNumbers& arguments, 
uint32_t result,
+                          size_t input_rows_count) const {
+        auto [col_y, is_const_y] = 
unpack_if_const(block.get_by_position(arguments[0]).column);
+        auto [col_x, is_const_x] = 
unpack_if_const(block.get_by_position(arguments[1]).column);
+
+        auto result_column = ColumnFloat64::create(input_rows_count);
+        auto& result_data = result_column->get_data();
+
+        if (is_const_y) {
+            auto y_val = assert_cast<const 
ColumnFloat64*>(col_y.get())->get_element(0);
+
+            const auto* x_col = assert_cast<const ColumnFloat64*>(col_x.get());
+            const auto& x_data = x_col->get_data();
+            for (size_t i = 0; i < input_rows_count; ++i) {
+                result_data[i] = std::atan2(y_val, x_data[i]);
+            }
+        } else if (is_const_x) {
+            auto x_val = assert_cast<const 
ColumnFloat64*>(col_x.get())->get_element(0);
+
+            const auto* y_col = assert_cast<const ColumnFloat64*>(col_y.get());
+            const auto& y_data = y_col->get_data();
+            for (size_t i = 0; i < input_rows_count; ++i) {
+                result_data[i] = std::atan2(y_data[i], x_val);
+            }
+        } else {
+            const auto* y_col = assert_cast<const ColumnFloat64*>(col_y.get());
+            const auto* x_col = assert_cast<const ColumnFloat64*>(col_x.get());
+            const auto& y_data = y_col->get_data();
+            const auto& x_data = x_col->get_data();
+            for (size_t i = 0; i < input_rows_count; ++i) {
+                result_data[i] = std::atan2(y_data[i], x_data[i]);
+            }
+        }
+
+        block.replace_by_position(result, std::move(result_column));
+        return Status::OK();
+    }
 };
-using FunctionAtan = FunctionMathUnary<UnaryFunctionPlain<AtanName, 
std::atan>>;
 
 struct AtanhName {
     static constexpr auto name = "atanh";
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/NumericArithmetic.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/NumericArithmetic.java
index 56c8a1f6c40..c68d2456ac8 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/NumericArithmetic.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/NumericArithmetic.java
@@ -647,6 +647,14 @@ public class NumericArithmetic {
         return new DoubleLiteral(Math.atan(first.getValue()));
     }
 
+    /**
+     * atan with two arguments: computes atan2(first, second)
+     */
+    @ExecFunction(name = "atan")
+    public static Expression atan(DoubleLiteral first, DoubleLiteral second) {
+        return new DoubleLiteral(Math.atan2(first.getValue(), 
second.getValue()));
+    }
+
     /**
      * asinh
      */
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Atan.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Atan.java
index 2750fbd81c1..e131129ab26 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Atan.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Atan.java
@@ -37,7 +37,8 @@ public class Atan extends ScalarFunction
         implements UnaryExpression, ExplicitlyCastableSignature, 
PropagateNullable {
 
     public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
-            
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE)
+            
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE),
+            
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, 
DoubleType.INSTANCE)
     );
 
     /**
@@ -47,6 +48,13 @@ public class Atan extends ScalarFunction
         super("atan", arg);
     }
 
+    /**
+     * constructor with 2 arguments.
+     */
+    public Atan(Expression arg0, Expression arg1) {
+        super("atan", arg0, arg1);
+    }
+
     /** constructor for withChildren and reuse signature */
     private Atan(ScalarFunctionParams functionParams) {
         super(functionParams);
@@ -57,7 +65,7 @@ public class Atan extends ScalarFunction
      */
     @Override
     public Atan withChildren(List<Expression> children) {
-        Preconditions.checkArgument(children.size() == 1);
+        Preconditions.checkArgument(children.size() == 1 || children.size() == 
2);
         return new Atan(getFunctionParams(children));
     }
 
diff --git a/regression-test/data/function_p0/test_math_function.out 
b/regression-test/data/function_p0/test_math_function.out
index 3490fa98f1f..3c3cebd5799 100644
--- a/regression-test/data/function_p0/test_math_function.out
+++ b/regression-test/data/function_p0/test_math_function.out
@@ -26,6 +26,133 @@ NaN 1.570796326794897       -1.570796326794897      0.0     
-0.0    0.7853981633974483      -0.78539816
 -- !select_atan2 --
 NaN    NaN     0.0     -0.0    0.0     -0.0    3.141592653589793       
-3.141592653589793      3.141592653589793       -3.141592653589793      
1.570796326794897       1.570796326794897       -1.570796326794897      
-1.570796326794897      1.570796326794897       -1.570796326794897      
2.356194490192345       -2.356194490192345      0.7853981633974483      
-0.7853981633974483     3.141592653589793       -3.141592653589793      0.0     
-0.0
 
+-- !select_atan_with_two_args --
+NaN    NaN     0.0     -0.0    0.0     -0.0    3.141592653589793       
-3.141592653589793      3.141592653589793       -3.141592653589793      
1.570796326794897       1.570796326794897       -1.570796326794897      
-1.570796326794897      1.570796326794897       -1.570796326794897      
2.356194490192345       -2.356194490192345      0.7853981633974483      
-0.7853981633974483     3.141592653589793       -3.141592653589793      0.0     
-0.0
+
+-- !select_atan_single_param_from_table --
+0
+0.7853981633974483
+-0.7853981633974483
+0.4636476090008061
+-0.4636476090008061
+1.047197551304379
+NaN
+1.570796326794897
+-1.570796326794897
+0
+0.7853981633974483
+-0.7853981633974483
+0
+0.7853981633974483
+-0.7853981633974483
+1.570796326794897
+-1.570796326794897
+1.570796326794897
+-1.570796326794897
+0.7853981633974483
+-0.7853981633974483
+0.7853981633974483
+-0.7853981633974483
+0
+0
+0
+0
+1.570796326794897
+-1.570796326794897
+
+-- !select_atan_two_params_from_table --
+0
+0.7853981633974483
+-0.7853981633974483
+0.4636476090008061
+-0.4636476090008061
+1.047197551088817
+NaN
+1.570796326794897
+-1.570796326794897
+0
+1.570796326794897
+-1.570796326794897
+3.141592653589793
+2.356194490192345
+-2.356194490192345
+0.7853981633974483
+-0.7853981633974483
+2.356194490192345
+-2.356194490192345
+0
+-0
+3.141592653589793
+-3.141592653589793
+0
+3.141592653589793
+0
+0
+1.570796326794897
+-1.570796326794897
+
+-- !select_atan_with_first_const --
+0.7853981633974483
+0.7853981633974483
+0.7853981633974483
+0.4636476090008061
+0.4636476090008061
+0.5235987754905181
+0.7853981633974483
+0.7853981633974483
+0.7853981633974483
+1.570796326794897
+1.570796326794897
+1.570796326794897
+2.356194490192345
+2.356194490192345
+2.356194490192345
+0
+0
+3.141592653589793
+3.141592653589793
+0
+0
+3.141592653589793
+3.141592653589793
+0.7853981633974483
+2.356194490192345
+1.570796326794897
+1.570796326794897
+0.7853981633974483
+0.7853981633974483
+
+-- !select_atan_with_second_const --
+0
+0.7853981633974483
+-0.7853981633974483
+0.7853981633974483
+-0.7853981633974483
+1.249045772398254
+NaN
+1.570796326794897
+-1.570796326794897
+0
+0.7853981633974483
+-0.7853981633974483
+0
+0.7853981633974483
+-0.7853981633974483
+1.570796326794897
+-1.570796326794897
+1.570796326794897
+-1.570796326794897
+0.7853981633974483
+-0.7853981633974483
+0.7853981633974483
+-0.7853981633974483
+0
+0
+0
+0
+1.570796326794897
+-1.570796326794897
+
 -- !select_cbrt --
 NaN    Infinity        -Infinity
 
diff --git a/regression-test/suites/function_p0/test_math_function.groovy 
b/regression-test/suites/function_p0/test_math_function.groovy
index 2d0e26eaaa6..466820c2152 100644
--- a/regression-test/suites/function_p0/test_math_function.groovy
+++ b/regression-test/suites/function_p0/test_math_function.groovy
@@ -49,6 +49,57 @@ suite("test_math_function") {
 
     qt_select_atan""" select atan(cast('nan' as double)), atan(cast('inf' as 
double)), atan(cast('-inf' as double)), atan(cast('0.0' as double)), 
atan(cast('-0.0' as double)), atan(cast('1.0' as double)), atan(cast('-1.0' as 
double)), atan(cast('1e308' as double)), atan(cast('-1e308' as double)) """
 
+    def mathFuncTestTable = "math_function_test_table";
+
+    sql """ DROP TABLE IF EXISTS ${mathFuncTestTable}; """
+
+    sql """
+        CREATE TABLE IF NOT EXISTS ${mathFuncTestTable} (
+            id INT,
+            single_val DOUBLE,
+            y_val DOUBLE,
+            x_val DOUBLE
+        )
+        DUPLICATE KEY(id)
+        DISTRIBUTED BY HASH(id) BUCKETS 1
+        PROPERTIES (
+            "replication_allocation" = "tag.location.default: 1"
+        )
+    """
+
+    sql """
+        INSERT INTO ${mathFuncTestTable} VALUES
+        (1, 0.0, 0.0, 1.0),
+        (2, 1.0, 1.0, 1.0),
+        (3, -1.0, -1.0, 1.0),
+        (4, 0.5, 1.0, 2.0),
+        (5, -0.5, -1.0, 2.0),
+        (6, 1.732050808, 3.0, 1.732050808),
+        (7, cast('nan' as double), cast('nan' as double), 1.0),
+        (8, cast('inf' as double), cast('inf' as double), 1.0),
+        (9, cast('-inf' as double), cast('-inf' as double), 1.0),
+        (10, 0.0, 0.0, 0.0),
+        (11, 1.0, 1.0, 0.0),
+        (12, -1.0, -1.0, 0.0),
+        (13, 0.0, 0.0, -1.0),
+        (14, 1.0, 1.0, -1.0),
+        (15, -1.0, -1.0, -1.0),
+        (16, cast('inf' as double), cast('inf' as double), cast('inf' as 
double)),
+        (17, cast('-inf' as double), cast('-inf' as double), cast('inf' as 
double)),
+        (18, cast('inf' as double), cast('inf' as double), cast('-inf' as 
double)),
+        (19, cast('-inf' as double), cast('-inf' as double), cast('-inf' as 
double)),
+        (20, 1.0, 1.0, cast('inf' as double)),
+        (21, -1.0, -1.0, cast('inf' as double)),
+        (22, 1.0, 1.0, cast('-inf' as double)),
+        (23, -1.0, -1.0, cast('-inf' as double)),
+        (24, -0.0, -0.0, 1.0),
+        (25, -0.0, -0.0, -1.0),
+        (26, -0.0, -0.0, 0.0),
+        (27, -0.0, -0.0, -0.0),
+        (28, 1e308, 1e308, 1.0),
+        (29, -1e308, -1e308, 1.0);
+    """
+
     testFoldConst""" 
         select
             atan2(cast('nan' as double), cast('1.0' as double)),
@@ -121,6 +172,94 @@ suite("test_math_function") {
             atan2(cast('-1.0' as double), cast('inf' as double))
     """
 
+    qt_select_atan_with_two_args""" 
+        select
+            atan(cast('nan' as double), cast('1.0' as double)),
+            atan(cast('1.0' as double), cast('nan' as double)),
+
+            atan(cast('0.0' as double),  cast('1.0' as double)),
+            atan(cast('-0.0' as double), cast('1.0' as double)),
+            atan(cast('0.0' as double),  cast('0.0' as double)),
+            atan(cast('-0.0' as double), cast('0.0' as double)),
+
+            atan(cast('0.0' as double),  cast('-1.0' as double)),
+            atan(cast('-0.0' as double), cast('-1.0' as double)),
+            atan(cast('0.0' as double),  cast('-0.0' as double)),
+            atan(cast('-0.0' as double), cast('-0.0' as double)),
+
+            atan(cast('1.0' as double),  cast('0.0' as double)),
+            atan(cast('1.0' as double),  cast('-0.0' as double)),
+            atan(cast('-1.0' as double), cast('0.0' as double)),
+            atan(cast('-1.0' as double), cast('-0.0' as double)),
+
+            atan(cast('inf' as double),  cast('1.0' as double)),
+            atan(cast('-inf' as double), cast('1.0' as double)),
+
+            atan(cast('inf' as double),  cast('-inf' as double)),
+            atan(cast('-inf' as double), cast('-inf' as double)),
+
+            atan(cast('inf' as double),  cast('inf' as double)),
+            atan(cast('-inf' as double), cast('inf' as double)),
+
+            atan(cast('1.0' as double),  cast('-inf' as double)),
+            atan(cast('-1.0' as double), cast('-inf' as double)),
+
+            atan(cast('1.0' as double),  cast('inf' as double)),
+            atan(cast('-1.0' as double), cast('inf' as double))
+    """
+
+    qt_select_atan_single_param_from_table """
+        SELECT atan(single_val) FROM ${mathFuncTestTable} ORDER BY id;
+    """
+
+    qt_select_atan_two_params_from_table """
+        SELECT atan(y_val, x_val) FROM ${mathFuncTestTable} ORDER BY id;
+    """
+
+    qt_select_atan_with_first_const """
+        SELECT atan(cast('1.0' as double), x_val) FROM ${mathFuncTestTable} 
ORDER BY id;
+    """
+
+    qt_select_atan_with_second_const """
+        SELECT atan(y_val, cast('1.0' as double)) FROM ${mathFuncTestTable} 
ORDER BY id;
+    """
+
+    testFoldConst""" 
+        select
+            atan(cast('nan' as double), cast('1.0' as double)),
+            atan(cast('1.0' as double), cast('nan' as double)),
+
+            atan(cast('0.0' as double),  cast('1.0' as double)),
+            atan(cast('-0.0' as double), cast('1.0' as double)),
+            atan(cast('0.0' as double),  cast('0.0' as double)),
+            atan(cast('-0.0' as double), cast('0.0' as double)),
+
+            atan(cast('0.0' as double),  cast('-1.0' as double)),
+            atan(cast('-0.0' as double), cast('-1.0' as double)),
+            atan(cast('0.0' as double),  cast('-0.0' as double)),
+            atan(cast('-0.0' as double), cast('-0.0' as double)),
+
+            atan(cast('1.0' as double),  cast('0.0' as double)),
+            atan(cast('1.0' as double),  cast('-0.0' as double)),
+            atan(cast('-1.0' as double), cast('0.0' as double)),
+            atan(cast('-1.0' as double), cast('-0.0' as double)),
+
+            atan(cast('inf' as double),  cast('1.0' as double)),
+            atan(cast('-inf' as double), cast('1.0' as double)),
+
+            atan(cast('inf' as double),  cast('-inf' as double)),
+            atan(cast('-inf' as double), cast('-inf' as double)),
+
+            atan(cast('inf' as double),  cast('inf' as double)),
+            atan(cast('-inf' as double), cast('inf' as double)),
+
+            atan(cast('1.0' as double),  cast('-inf' as double)),
+            atan(cast('-1.0' as double), cast('-inf' as double)),
+
+            atan(cast('1.0' as double),  cast('inf' as double)),
+            atan(cast('-1.0' as double), cast('inf' as double))
+    """
+
     testFoldConst""" select cbrt(cast('nan' as double)), cbrt(cast('inf' as 
double)), cbrt(cast('-inf' as double)) """
     qt_select_cbrt""" select cbrt(cast('nan' as double)), cbrt(cast('inf' as 
double)), cbrt(cast('-inf' as double)) """
 
@@ -154,4 +293,7 @@ suite("test_math_function") {
     testFoldConst""" select tanh(cast('nan' as double)), tanh(cast('inf' as 
double)), tanh(cast('-inf' as double)) """
     qt_select_tanh""" select tanh(cast('nan' as double)), tanh(cast('inf' as 
double)), tanh(cast('-inf' as double)) """
 
+    sql """
+        DROP TABLE IF EXISTS ${mathFuncTestTable};
+    """
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to