This is an automated email from the ASF dual-hosted git repository.
zclll pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new c7637239b18 [feat](func) Support two-args version of atan (#56561)
c7637239b18 is described below
commit c7637239b18b7a54c4d15243f5bb8cb70d32e155
Author: linrrarity <[email protected]>
AuthorDate: Fri Oct 10 18:02:21 2025 +0800
[feat](func) Support two-args version of atan (#56561)
### What problem does this PR solve?
Issue Number: close #xxx
Related PR: #48203
Problem Summary:
Support a two-argument version of atan, whose behavior is the same as atan2: atan(y, x) returns atan2(y, x).
```text
mysql> select atan(-1.0, cast('inf' as double));
+-----------------------------------+
| atan(-1.0, cast('inf' as double)) |
+-----------------------------------+
|                                -0 |
+-----------------------------------+
mysql> select atan(cast('nan' as double), 1.0);
+----------------------------------+
| atan(cast('nan' as double), 1.0) |
+----------------------------------+
|                              NaN |
+----------------------------------+
mysql> select atan(cast('-inf' as double), cast('inf' as double));
+-----------------------------------------------------+
| atan(cast('-inf' as double), cast('inf' as double)) |
+-----------------------------------------------------+
|                                 -0.7853981633974483 |
+-----------------------------------------------------+
```
---
be/src/vec/functions/math.cpp | 80 +++++++++++-
.../functions/executable/NumericArithmetic.java | 8 ++
.../trees/expressions/functions/scalar/Atan.java | 12 +-
.../data/function_p0/test_math_function.out | 127 ++++++++++++++++++
.../suites/function_p0/test_math_function.groovy | 142 +++++++++++++++++++++
5 files changed, 365 insertions(+), 4 deletions(-)
diff --git a/be/src/vec/functions/math.cpp b/be/src/vec/functions/math.cpp
index a8ae52fa75a..87117b80f2c 100644
--- a/be/src/vec/functions/math.cpp
+++ b/be/src/vec/functions/math.cpp
@@ -71,10 +71,86 @@ struct AsinhName {
};
using FunctionAsinh = FunctionMathUnary<UnaryFunctionPlain<AsinhName,
std::asinh>>;
-struct AtanName {
+class FunctionAtan : public IFunction {
+public:
static constexpr auto name = "atan";
+ static FunctionPtr create() { return std::make_shared<FunctionAtan>(); }
+
+ String get_name() const override { return name; }
+ bool is_variadic() const override { return true; }
+ size_t get_number_of_arguments() const override { return 0; }
+
+ DataTypePtr get_return_type_impl(const DataTypes& arguments) const
override {
+ return std::make_shared<DataTypeFloat64>();
+ }
+
+ Status execute_impl(FunctionContext* context, Block& block, const
ColumnNumbers& arguments,
+ uint32_t result, size_t input_rows_count) const
override {
+ if (arguments.size() == 1) {
+ return execute_unary(block, arguments, result, input_rows_count);
+ } else if (arguments.size() == 2) {
+ return execute_binary(block, arguments, result, input_rows_count);
+ } else {
+ return Status::InvalidArgument("atan function expects 1 or 2
arguments, but got {}",
+ arguments.size());
+ }
+ }
+
+private:
+ Status execute_unary(Block& block, const ColumnNumbers& arguments,
uint32_t result,
+ size_t input_rows_count) const {
+ auto res_col = ColumnFloat64::create(input_rows_count);
+ auto& res_data = res_col->get_data();
+
+ const auto& col_data =
+ assert_cast<const
ColumnFloat64*>(block.get_by_position(arguments[0]).column.get())
+ ->get_data();
+ for (size_t i = 0; i < input_rows_count; ++i) {
+ res_data[i] = std::atan(col_data[i]);
+ }
+
+ block.replace_by_position(result, std::move(res_col));
+ return Status::OK();
+ }
+
+ Status execute_binary(Block& block, const ColumnNumbers& arguments,
uint32_t result,
+ size_t input_rows_count) const {
+ auto [col_y, is_const_y] =
unpack_if_const(block.get_by_position(arguments[0]).column);
+ auto [col_x, is_const_x] =
unpack_if_const(block.get_by_position(arguments[1]).column);
+
+ auto result_column = ColumnFloat64::create(input_rows_count);
+ auto& result_data = result_column->get_data();
+
+ if (is_const_y) {
+ auto y_val = assert_cast<const
ColumnFloat64*>(col_y.get())->get_element(0);
+
+ const auto* x_col = assert_cast<const ColumnFloat64*>(col_x.get());
+ const auto& x_data = x_col->get_data();
+ for (size_t i = 0; i < input_rows_count; ++i) {
+ result_data[i] = std::atan2(y_val, x_data[i]);
+ }
+ } else if (is_const_x) {
+ auto x_val = assert_cast<const
ColumnFloat64*>(col_x.get())->get_element(0);
+
+ const auto* y_col = assert_cast<const ColumnFloat64*>(col_y.get());
+ const auto& y_data = y_col->get_data();
+ for (size_t i = 0; i < input_rows_count; ++i) {
+ result_data[i] = std::atan2(y_data[i], x_val);
+ }
+ } else {
+ const auto* y_col = assert_cast<const ColumnFloat64*>(col_y.get());
+ const auto* x_col = assert_cast<const ColumnFloat64*>(col_x.get());
+ const auto& y_data = y_col->get_data();
+ const auto& x_data = x_col->get_data();
+ for (size_t i = 0; i < input_rows_count; ++i) {
+ result_data[i] = std::atan2(y_data[i], x_data[i]);
+ }
+ }
+
+ block.replace_by_position(result, std::move(result_column));
+ return Status::OK();
+ }
};
-using FunctionAtan = FunctionMathUnary<UnaryFunctionPlain<AtanName,
std::atan>>;
struct AtanhName {
static constexpr auto name = "atanh";
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/NumericArithmetic.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/NumericArithmetic.java
index 56c8a1f6c40..c68d2456ac8 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/NumericArithmetic.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/NumericArithmetic.java
@@ -647,6 +647,14 @@ public class NumericArithmetic {
return new DoubleLiteral(Math.atan(first.getValue()));
}
+ /**
+ * atan
+ */
+ @ExecFunction(name = "atan")
+ public static Expression atan(DoubleLiteral first, DoubleLiteral second) {
+ return new DoubleLiteral(Math.atan2(first.getValue(),
second.getValue()));
+ }
+
/**
* asinh
*/
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Atan.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Atan.java
index 2750fbd81c1..e131129ab26 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Atan.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Atan.java
@@ -37,7 +37,8 @@ public class Atan extends ScalarFunction
implements UnaryExpression, ExplicitlyCastableSignature,
PropagateNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
-
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE)
+
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE),
+
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE,
DoubleType.INSTANCE)
);
/**
@@ -47,6 +48,13 @@ public class Atan extends ScalarFunction
super("atan", arg);
}
+ /**
+ * constructor with 2 argument.
+ */
+ public Atan(Expression arg0, Expression arg1) {
+ super("atan", arg0, arg1);
+ }
+
/** constructor for withChildren and reuse signature */
private Atan(ScalarFunctionParams functionParams) {
super(functionParams);
@@ -57,7 +65,7 @@ public class Atan extends ScalarFunction
*/
@Override
public Atan withChildren(List<Expression> children) {
- Preconditions.checkArgument(children.size() == 1);
+ Preconditions.checkArgument(children.size() == 1 || children.size() ==
2);
return new Atan(getFunctionParams(children));
}
diff --git a/regression-test/data/function_p0/test_math_function.out
b/regression-test/data/function_p0/test_math_function.out
index 3490fa98f1f..3c3cebd5799 100644
--- a/regression-test/data/function_p0/test_math_function.out
+++ b/regression-test/data/function_p0/test_math_function.out
@@ -26,6 +26,133 @@ NaN 1.570796326794897 -1.570796326794897 0.0
-0.0 0.7853981633974483 -0.78539816
-- !select_atan2 --
NaN NaN 0.0 -0.0 0.0 -0.0 3.141592653589793
-3.141592653589793 3.141592653589793 -3.141592653589793
1.570796326794897 1.570796326794897 -1.570796326794897
-1.570796326794897 1.570796326794897 -1.570796326794897
2.356194490192345 -2.356194490192345 0.7853981633974483
-0.7853981633974483 3.141592653589793 -3.141592653589793 0.0
-0.0
+-- !select_atan_with_two_args --
+NaN NaN 0.0 -0.0 0.0 -0.0 3.141592653589793
-3.141592653589793 3.141592653589793 -3.141592653589793
1.570796326794897 1.570796326794897 -1.570796326794897
-1.570796326794897 1.570796326794897 -1.570796326794897
2.356194490192345 -2.356194490192345 0.7853981633974483
-0.7853981633974483 3.141592653589793 -3.141592653589793 0.0
-0.0
+
+-- !select_atan_single_param_from_table --
+0
+0.7853981633974483
+-0.7853981633974483
+0.4636476090008061
+-0.4636476090008061
+1.047197551304379
+NaN
+1.570796326794897
+-1.570796326794897
+0
+0.7853981633974483
+-0.7853981633974483
+0
+0.7853981633974483
+-0.7853981633974483
+1.570796326794897
+-1.570796326794897
+1.570796326794897
+-1.570796326794897
+0.7853981633974483
+-0.7853981633974483
+0.7853981633974483
+-0.7853981633974483
+0
+0
+0
+0
+1.570796326794897
+-1.570796326794897
+
+-- !select_atan_two_params_from_table --
+0
+0.7853981633974483
+-0.7853981633974483
+0.4636476090008061
+-0.4636476090008061
+1.047197551088817
+NaN
+1.570796326794897
+-1.570796326794897
+0
+1.570796326794897
+-1.570796326794897
+3.141592653589793
+2.356194490192345
+-2.356194490192345
+0.7853981633974483
+-0.7853981633974483
+2.356194490192345
+-2.356194490192345
+0
+-0
+3.141592653589793
+-3.141592653589793
+0
+3.141592653589793
+0
+0
+1.570796326794897
+-1.570796326794897
+
+-- !select_atan_with_first_const --
+0.7853981633974483
+0.7853981633974483
+0.7853981633974483
+0.4636476090008061
+0.4636476090008061
+0.5235987754905181
+0.7853981633974483
+0.7853981633974483
+0.7853981633974483
+1.570796326794897
+1.570796326794897
+1.570796326794897
+2.356194490192345
+2.356194490192345
+2.356194490192345
+0
+0
+3.141592653589793
+3.141592653589793
+0
+0
+3.141592653589793
+3.141592653589793
+0.7853981633974483
+2.356194490192345
+1.570796326794897
+1.570796326794897
+0.7853981633974483
+0.7853981633974483
+
+-- !select_atan_with_second_const --
+0
+0.7853981633974483
+-0.7853981633974483
+0.7853981633974483
+-0.7853981633974483
+1.249045772398254
+NaN
+1.570796326794897
+-1.570796326794897
+0
+0.7853981633974483
+-0.7853981633974483
+0
+0.7853981633974483
+-0.7853981633974483
+1.570796326794897
+-1.570796326794897
+1.570796326794897
+-1.570796326794897
+0.7853981633974483
+-0.7853981633974483
+0.7853981633974483
+-0.7853981633974483
+0
+0
+0
+0
+1.570796326794897
+-1.570796326794897
+
-- !select_cbrt --
NaN Infinity -Infinity
diff --git a/regression-test/suites/function_p0/test_math_function.groovy
b/regression-test/suites/function_p0/test_math_function.groovy
index 2d0e26eaaa6..466820c2152 100644
--- a/regression-test/suites/function_p0/test_math_function.groovy
+++ b/regression-test/suites/function_p0/test_math_function.groovy
@@ -49,6 +49,57 @@ suite("test_math_function") {
qt_select_atan""" select atan(cast('nan' as double)), atan(cast('inf' as
double)), atan(cast('-inf' as double)), atan(cast('0.0' as double)),
atan(cast('-0.0' as double)), atan(cast('1.0' as double)), atan(cast('-1.0' as
double)), atan(cast('1e308' as double)), atan(cast('-1e308' as double)) """
+ def mathFuncTestTable = "math_function_test_table";
+
+ sql """ DROP TABLE IF EXISTS ${mathFuncTestTable}; """
+
+ sql """
+ CREATE TABLE IF NOT EXISTS ${mathFuncTestTable} (
+ id INT,
+ single_val DOUBLE,
+ y_val DOUBLE,
+ x_val DOUBLE
+ )
+ DUPLICATE KEY(id)
+ DISTRIBUTED BY HASH(id) BUCKETS 1
+ PROPERTIES (
+ "replication_allocation" = "tag.location.default: 1"
+ )
+ """
+
+ sql """
+ INSERT INTO ${mathFuncTestTable} VALUES
+ (1, 0.0, 0.0, 1.0),
+ (2, 1.0, 1.0, 1.0),
+ (3, -1.0, -1.0, 1.0),
+ (4, 0.5, 1.0, 2.0),
+ (5, -0.5, -1.0, 2.0),
+ (6, 1.732050808, 3.0, 1.732050808),
+ (7, cast('nan' as double), cast('nan' as double), 1.0),
+ (8, cast('inf' as double), cast('inf' as double), 1.0),
+ (9, cast('-inf' as double), cast('-inf' as double), 1.0),
+ (10, 0.0, 0.0, 0.0),
+ (11, 1.0, 1.0, 0.0),
+ (12, -1.0, -1.0, 0.0),
+ (13, 0.0, 0.0, -1.0),
+ (14, 1.0, 1.0, -1.0),
+ (15, -1.0, -1.0, -1.0),
+ (16, cast('inf' as double), cast('inf' as double), cast('inf' as
double)),
+ (17, cast('-inf' as double), cast('-inf' as double), cast('inf' as
double)),
+ (18, cast('inf' as double), cast('inf' as double), cast('-inf' as
double)),
+ (19, cast('-inf' as double), cast('-inf' as double), cast('-inf' as
double)),
+ (20, 1.0, 1.0, cast('inf' as double)),
+ (21, -1.0, -1.0, cast('inf' as double)),
+ (22, 1.0, 1.0, cast('-inf' as double)),
+ (23, -1.0, -1.0, cast('-inf' as double)),
+ (24, -0.0, -0.0, 1.0),
+ (25, -0.0, -0.0, -1.0),
+ (26, -0.0, -0.0, 0.0),
+ (27, -0.0, -0.0, -0.0),
+ (28, 1e308, 1e308, 1.0),
+ (29, -1e308, -1e308, 1.0);
+ """
+
testFoldConst"""
select
atan2(cast('nan' as double), cast('1.0' as double)),
@@ -121,6 +172,94 @@ suite("test_math_function") {
atan2(cast('-1.0' as double), cast('inf' as double))
"""
+ qt_select_atan_with_two_args"""
+ select
+ atan(cast('nan' as double), cast('1.0' as double)),
+ atan(cast('1.0' as double), cast('nan' as double)),
+
+ atan(cast('0.0' as double), cast('1.0' as double)),
+ atan(cast('-0.0' as double), cast('1.0' as double)),
+ atan(cast('0.0' as double), cast('0.0' as double)),
+ atan(cast('-0.0' as double), cast('0.0' as double)),
+
+ atan(cast('0.0' as double), cast('-1.0' as double)),
+ atan(cast('-0.0' as double), cast('-1.0' as double)),
+ atan(cast('0.0' as double), cast('-0.0' as double)),
+ atan(cast('-0.0' as double), cast('-0.0' as double)),
+
+ atan(cast('1.0' as double), cast('0.0' as double)),
+ atan(cast('1.0' as double), cast('-0.0' as double)),
+ atan(cast('-1.0' as double), cast('0.0' as double)),
+ atan(cast('-1.0' as double), cast('-0.0' as double)),
+
+ atan(cast('inf' as double), cast('1.0' as double)),
+ atan(cast('-inf' as double), cast('1.0' as double)),
+
+ atan(cast('inf' as double), cast('-inf' as double)),
+ atan(cast('-inf' as double), cast('-inf' as double)),
+
+ atan(cast('inf' as double), cast('inf' as double)),
+ atan(cast('-inf' as double), cast('inf' as double)),
+
+ atan(cast('1.0' as double), cast('-inf' as double)),
+ atan(cast('-1.0' as double), cast('-inf' as double)),
+
+ atan(cast('1.0' as double), cast('inf' as double)),
+ atan(cast('-1.0' as double), cast('inf' as double))
+ """
+
+ qt_select_atan_single_param_from_table """
+ SELECT atan(single_val) FROM ${mathFuncTestTable} ORDER BY id;
+ """
+
+ qt_select_atan_two_params_from_table """
+ SELECT atan(y_val, x_val) FROM ${mathFuncTestTable} ORDER BY id;
+ """
+
+ qt_select_atan_with_first_const """
+ SELECT atan(cast('1.0' as double), x_val) FROM ${mathFuncTestTable}
ORDER BY id;
+ """
+
+ qt_select_atan_with_second_const """
+ SELECT atan(y_val, cast('1.0' as double)) FROM ${mathFuncTestTable}
ORDER BY id;
+ """
+
+ testFoldConst"""
+ select
+ atan(cast('nan' as double), cast('1.0' as double)),
+ atan(cast('1.0' as double), cast('nan' as double)),
+
+ atan(cast('0.0' as double), cast('1.0' as double)),
+ atan(cast('-0.0' as double), cast('1.0' as double)),
+ atan(cast('0.0' as double), cast('0.0' as double)),
+ atan(cast('-0.0' as double), cast('0.0' as double)),
+
+ atan(cast('0.0' as double), cast('-1.0' as double)),
+ atan(cast('-0.0' as double), cast('-1.0' as double)),
+ atan(cast('0.0' as double), cast('-0.0' as double)),
+ atan(cast('-0.0' as double), cast('-0.0' as double)),
+
+ atan(cast('1.0' as double), cast('0.0' as double)),
+ atan(cast('1.0' as double), cast('-0.0' as double)),
+ atan(cast('-1.0' as double), cast('0.0' as double)),
+ atan(cast('-1.0' as double), cast('-0.0' as double)),
+
+ atan(cast('inf' as double), cast('1.0' as double)),
+ atan(cast('-inf' as double), cast('1.0' as double)),
+
+ atan(cast('inf' as double), cast('-inf' as double)),
+ atan(cast('-inf' as double), cast('-inf' as double)),
+
+ atan(cast('inf' as double), cast('inf' as double)),
+ atan(cast('-inf' as double), cast('inf' as double)),
+
+ atan(cast('1.0' as double), cast('-inf' as double)),
+ atan(cast('-1.0' as double), cast('-inf' as double)),
+
+ atan(cast('1.0' as double), cast('inf' as double)),
+ atan(cast('-1.0' as double), cast('inf' as double))
+ """
+
testFoldConst""" select cbrt(cast('nan' as double)), cbrt(cast('inf' as
double)), cbrt(cast('-inf' as double)) """
qt_select_cbrt""" select cbrt(cast('nan' as double)), cbrt(cast('inf' as
double)), cbrt(cast('-inf' as double)) """
@@ -154,4 +293,7 @@ suite("test_math_function") {
testFoldConst""" select tanh(cast('nan' as double)), tanh(cast('inf' as
double)), tanh(cast('-inf' as double)) """
qt_select_tanh""" select tanh(cast('nan' as double)), tanh(cast('inf' as
double)), tanh(cast('-inf' as double)) """
+ sql """
+ DROP TABLE IF EXISTS ${mathFuncTestTable};
+ """
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]