This is an automated email from the ASF dual-hosted git repository.
zhangstar333 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new a36e088e781 [enhancement](function truncate) truncate can use column
as scale argument (#32746)
a36e088e781 is described below
commit a36e088e781f2fd47f4c3672b314f82a23d6e16f
Author: zhiqiang <[email protected]>
AuthorDate: Tue Apr 2 14:56:26 2024 +0800
[enhancement](function truncate) truncate can use column as scale argument
(#32746)
Co-authored-by: github-actions[bot]
<41898282+github-actions[bot]@users.noreply.github.com>
---
be/src/vec/functions/function_truncate.h | 245 ++++++++++++++
be/src/vec/functions/math.cpp | 23 +-
be/src/vec/functions/round.h | 224 ++++++++++++-
.../function/function_truncate_decimal_test.cpp | 370 +++++++++++++++++++++
.../apache/doris/analysis/FunctionCallExpr.java | 32 +-
.../functions/ComputePrecisionForRound.java | 40 ++-
.../math_functions/test_function_truncate.out | 101 ++++++
.../math_functions/test_function_truncate.groovy | 132 ++++++++
8 files changed, 1136 insertions(+), 31 deletions(-)
diff --git a/be/src/vec/functions/function_truncate.h
b/be/src/vec/functions/function_truncate.h
new file mode 100644
index 00000000000..e29bc99c041
--- /dev/null
+++ b/be/src/vec/functions/function_truncate.h
@@ -0,0 +1,186 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstddef>
+#include <functional>
+#include <type_traits>
+#include <utility>
+
+#include "common/exception.h"
+#include "common/status.h"
+#include "olap/olap_common.h"
+#include "round.h"
+#include "vec/columns/column.h"
+#include "vec/columns/column_const.h"
+#include "vec/columns/column_decimal.h"
+#include "vec/columns/column_vector.h"
+#include "vec/common/assert_cast.h"
+#include "vec/core/call_on_type_index.h"
+#include "vec/core/field.h"
+#include "vec/core/types.h"
+#include "vec/data_types/data_type.h"
+#include "vec/data_types/data_type_decimal.h"
+#include "vec/data_types/data_type_number.h"
+
+namespace doris::vectorized {
+
+// truncate(Float64): the scale defaults to 0, i.e. the fractional part is dropped.
+struct TruncateFloatOneArgImpl {
+    static constexpr auto name = "truncate";
+    static DataTypes get_variadic_argument_types() {
+        return {std::make_shared<DataTypeFloat64>()};
+    }
+};
+
+// truncate(Float64, Int32 scale).
+struct TruncateFloatTwoArgImpl {
+    static constexpr auto name = "truncate";
+    static DataTypes get_variadic_argument_types() {
+        return {std::make_shared<DataTypeFloat64>(), std::make_shared<DataTypeInt32>()};
+    }
+};
+
+// truncate(Decimal): the scale defaults to 0.
+struct TruncateDecimalOneArgImpl {
+    static constexpr auto name = "truncate";
+    static DataTypes get_variadic_argument_types() {
+        // All Decimal types are named Decimal, and real scale will be passed as type argument for execute function
+        // So we can just register Decimal32 here
+        return {std::make_shared<DataTypeDecimal<Decimal32>>(9, 0)};
+    }
+};
+
+// truncate(Decimal, Int32 scale). Decimal32 stands in for every Decimal width, as above.
+struct TruncateDecimalTwoArgImpl {
+    static constexpr auto name = "truncate";
+    static DataTypes get_variadic_argument_types() {
+        return {std::make_shared<DataTypeDecimal<Decimal32>>(9, 0),
+                std::make_shared<DataTypeInt32>()};
+    }
+};
+
+// Vectorized truncate(x[, d]) for numeric and Decimal inputs.
+//
+// Unlike the other rounding functions, the scale `d` may be a non-constant column, so all
+// const/non-const argument combinations are handled explicitly:
+//   - truncate(Column), truncate(Column, ColumnConst)
+//   - truncate(ColumnConst), truncate(ColumnConst, ColumnConst)  -> ColumnConst result
+//   - truncate(ColumnConst, Column)
+//   - truncate(Column, Column)
+template <typename Impl>
+class FunctionTruncate : public FunctionRounding<Impl, RoundingMode::Trunc, TieBreakingMode::Auto> {
+public:
+    static FunctionPtr create() { return std::make_shared<FunctionTruncate>(); }
+
+    ColumnNumbers get_arguments_that_are_always_constant() const override { return {}; }
+    // SELECT number, truncate(123.345, 1) FROM number("numbers"="10")
+    // should NOT behave like two column arguments, so we can not use const column default implementation
+    bool use_default_implementation_for_constants() const override { return false; }
+
+    Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
+                        size_t result, size_t input_rows_count) const override {
+        const ColumnWithTypeAndName& column_general = block.get_by_position(arguments[0]);
+        const bool has_scale_arg = arguments.size() == 2;
+        const bool general_is_const = is_column_const(*column_general.column);
+        const bool scale_is_const =
+                !has_scale_arg || is_column_const(*block.get_by_position(arguments[1]).column);
+        ColumnPtr res;
+
+        if (scale_is_const) {
+            // truncate(Column), truncate(Column, ColumnConst), truncate(ColumnConst) and
+            // truncate(ColumnConst, ColumnConst): one scale for all rows (0 when omitted).
+            Int16 scale_arg = 0;
+            if (has_scale_arg) {
+                RETURN_IF_ERROR(FunctionTruncate<Impl>::get_scale_arg(
+                        block.get_by_position(arguments[1]), &scale_arg));
+            }
+
+            // apply_vec_const reads a ColumnVector/ColumnDecimal directly, so a const input
+            // (including the one-argument form truncate(ColumnConst)) is unwrapped first and
+            // its single value is computed once.
+            const IColumn* data_column =
+                    general_is_const
+                            ? &assert_cast<const ColumnConst&>(*column_general.column)
+                                       .get_data_column()
+                            : column_general.column.get();
+
+            RETURN_IF_ERROR(dispatch_by_input_type(column_general, res, [&](auto dispatcher) {
+                return decltype(dispatcher)::apply_vec_const(data_column, scale_arg);
+            }));
+
+            if (general_is_const) {
+                // Important, make sure the result column has the same size as the input column
+                res = ColumnConst::create(std::move(res), input_rows_count);
+            }
+        } else if (general_is_const) {
+            // truncate(ColumnConst, Column): one input value, a scale per row.
+            const auto* const_col_general =
+                    assert_cast<const ColumnConst*>(column_general.column.get());
+            const IColumn* col_scale = block.get_by_position(arguments[1]).column.get();
+
+            RETURN_IF_ERROR(dispatch_by_input_type(column_general, res, [&](auto dispatcher) {
+                return decltype(dispatcher)::apply_const_vec(const_col_general, col_scale);
+            }));
+        } else {
+            // truncate(Column, Column): an input value and a scale per row.
+            const IColumn* col_scale = block.get_by_position(arguments[1]).column.get();
+
+            RETURN_IF_ERROR(dispatch_by_input_type(column_general, res, [&](auto dispatcher) {
+                return decltype(dispatcher)::apply_vec_vec(column_general.column.get(),
+                                                           col_scale);
+            }));
+        }
+
+        block.replace_by_position(result, std::move(res));
+        return Status::OK();
+    }
+
+private:
+    // Resolves the Dispatcher for the field type of `column_general` and stores
+    // `fn(Dispatcher<FieldType, Trunc, Auto>{})` into `res`. Only numeric and decimal
+    // input types are accepted; anything else yields Status::InvalidArgument.
+    //
+    // The fesetround(FE_TONEAREST) guard used for nearbyint-based rounding is deliberately
+    // absent: it only matters for RoundingMode::Round, and this class is always
+    // instantiated with RoundingMode::Trunc.
+    template <typename Fn>
+    static Status dispatch_by_input_type(const ColumnWithTypeAndName& column_general,
+                                         ColumnPtr& res, Fn&& fn) {
+        auto call = [&](const auto& types) -> bool {
+            using Types = std::decay_t<decltype(types)>;
+            using DataType = typename Types::LeftType;
+
+            if constexpr (IsDataTypeNumber<DataType> || IsDataTypeDecimal<DataType>) {
+                using FieldType = typename DataType::FieldType;
+                res = fn(Dispatcher<FieldType, RoundingMode::Trunc, TieBreakingMode::Auto> {});
+                return true;
+            }
+
+            return false;
+        };
+
+        if (!call_on_index_and_data_type<void>(column_general.type->get_type_id(), call)) {
+            return Status::InvalidArgument("Invalid argument type {} for function {}",
+                                           column_general.type->get_name(), "truncate");
+        }
+        return Status::OK();
+    }
+};
+
+} // namespace doris::vectorized
diff --git a/be/src/vec/functions/math.cpp b/be/src/vec/functions/math.cpp
index dc815cf74e5..c0dfe761576 100644
--- a/be/src/vec/functions/math.cpp
+++ b/be/src/vec/functions/math.cpp
@@ -46,6 +46,7 @@
#include "vec/functions/function_math_unary.h"
#include "vec/functions/function_string.h"
#include "vec/functions/function_totype.h"
+#include "vec/functions/function_truncate.h"
#include "vec/functions/function_unary_arithmetic.h"
#include "vec/functions/round.h"
#include "vec/functions/simple_function_factory.h"
@@ -392,16 +393,14 @@ struct DecimalRoundOneImpl {
// TODO: Now math may cause one thread compile time too long, because the
function in math
// so mush. Split it to speed up compile time in the future
void register_function_math(SimpleFunctionFactory& factory) {
-#define REGISTER_ROUND_FUNCTIONS(IMPL)
\
- factory.register_function<
\
- FunctionRounding<IMPL<RoundName>, RoundingMode::Round,
TieBreakingMode::Auto>>(); \
- factory.register_function<
\
- FunctionRounding<IMPL<FloorName>, RoundingMode::Floor,
TieBreakingMode::Auto>>(); \
- factory.register_function<
\
- FunctionRounding<IMPL<CeilName>, RoundingMode::Ceil,
TieBreakingMode::Auto>>(); \
- factory.register_function<
\
- FunctionRounding<IMPL<TruncateName>, RoundingMode::Trunc,
TieBreakingMode::Auto>>(); \
- factory.register_function<FunctionRounding<IMPL<RoundBankersName>,
RoundingMode::Round, \
+#define REGISTER_ROUND_FUNCTIONS(IMPL)
\
+ factory.register_function<
\
+ FunctionRounding<IMPL<RoundName>, RoundingMode::Round,
TieBreakingMode::Auto>>(); \
+ factory.register_function<
\
+ FunctionRounding<IMPL<FloorName>, RoundingMode::Floor,
TieBreakingMode::Auto>>(); \
+ factory.register_function<
\
+ FunctionRounding<IMPL<CeilName>, RoundingMode::Ceil,
TieBreakingMode::Auto>>(); \
+ factory.register_function<FunctionRounding<IMPL<RoundBankersName>,
RoundingMode::Round, \
TieBreakingMode::Bankers>>();
REGISTER_ROUND_FUNCTIONS(DecimalRoundOneImpl)
@@ -445,5 +444,9 @@ void register_function_math(SimpleFunctionFactory& factory)
{
factory.register_function<FunctionRadians>();
factory.register_function<FunctionDegrees>();
factory.register_function<FunctionBin>();
+ factory.register_function<FunctionTruncate<TruncateFloatOneArgImpl>>();
+ factory.register_function<FunctionTruncate<TruncateFloatTwoArgImpl>>();
+ factory.register_function<FunctionTruncate<TruncateDecimalOneArgImpl>>();
+ factory.register_function<FunctionTruncate<TruncateDecimalTwoArgImpl>>();
}
} // namespace doris::vectorized
diff --git a/be/src/vec/functions/round.h b/be/src/vec/functions/round.h
index 7e48b8e9306..a9d1e7a019c 100644
--- a/be/src/vec/functions/round.h
+++ b/be/src/vec/functions/round.h
@@ -20,8 +20,15 @@
#pragma once
+#include <cstddef>
+#include <cstdint>
+
+#include "common/exception.h"
+#include "common/status.h"
#include "vec/columns/column_const.h"
#include "vec/columns/columns_number.h"
+#include "vec/common/assert_cast.h"
+#include "vec/core/types.h"
#include "vec/functions/function.h"
#if defined(__SSE4_1__) || defined(__aarch64__)
#include "util/sse_util.hpp"
@@ -176,6 +183,23 @@ public:
memcpy(out.data(), in.data(), in.size() * sizeof(T));
}
}
+
+ // NOTE: This function is only tested for truncate. Applies Op to one raw decimal value `in` (stored with scale `in_scale`) for the target scale `out_scale`.
+ // DO NOT USE THIS METHOD FOR OTHER ROUNDING BASED FUNCTION UNTIL YOU KNOW
EXACTLY WHAT YOU ARE DOING !!!
+ static NO_INLINE void apply(const NativeType& in, UInt32 in_scale,
NativeType& out,
+ Int16 out_scale) {
+ Int16 scale_arg = in_scale - out_scale; // NOTE(review): UInt32 - Int16 is evaluated in unsigned arithmetic and then narrowed to Int16; if in_scale - out_scale exceeds 32767 (e.g. out_scale = -32768, which callers' Int16 range checks allow) this wraps negative and the input is copied unchanged below - TODO confirm/guard
+ if (scale_arg > 0) { // the input has more fractional digits than the target scale keeps
+ size_t scale = int_exp10(scale_arg);
+ if (out_scale < 0) { // negative target scale: 10^(-out_scale) is also passed to Op::compute
+ Op::compute(&in, scale, &out, int_exp10(-out_scale));
+ } else {
+ Op::compute(&in, scale, &out, 1);
+ }
+ } else { // target scale >= input scale: nothing to drop, copy the value as-is
+ memcpy(&out, &in, sizeof(NativeType));
+ }
+ }
};
template <TieBreakingMode tie_breaking_mode>
@@ -314,6 +338,11 @@ public:
memcpy(p_out, &tmp_dst, tail_size_bytes);
}
}
+
+ static NO_INLINE void apply(const T& in, size_t scale, T& out) { // single-value form (used by Dispatcher::apply_vec_vec / apply_const_vec): Op::prepare(scale), then Op::compute on one element
+ auto mm_scale = Op::prepare(scale);
+ Op::compute(&in, mm_scale, &out);
+ }
};
template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode,
@@ -386,6 +415,10 @@ public:
__builtin_unreachable();
}
}
+
+ static NO_INLINE void apply(const T& in, size_t scale, T& out) { // single-value form used by Dispatcher::apply_vec_vec / apply_const_vec
+ Op::compute(&in, scale, &out, 1);
+ }
};
/** Select the appropriate processing algorithm depending on the scale.
@@ -400,7 +433,7 @@ struct Dispatcher {
FloatRoundingImpl<T, rounding_mode, scale_mode,
tie_breaking_mode>,
IntegerRoundingImpl<T, rounding_mode, scale_mode,
tie_breaking_mode>>>;
- static ColumnPtr apply(const IColumn* col_general, Int16 scale_arg) {
+ static ColumnPtr apply_vec_const(const IColumn* col_general, Int16
scale_arg) {
if constexpr (IsNumber<T>) {
const auto* const col =
check_and_get_column<ColumnVector<T>>(col_general);
auto col_res = ColumnVector<T>::create();
@@ -446,6 +479,179 @@ struct Dispatcher {
return nullptr;
}
}
+
+ // NOTE: This function is only tested for truncate. Truncates row i of col_general to the scale in row i of col_scale (a ColumnInt32 whose size is used as the row count).
+ // DO NOT USE THIS METHOD FOR OTHER ROUNDING BASED FUNCTION UNTIL YOU KNOW
EXACTLY WHAT YOU ARE DOING !!!
+ static ColumnPtr apply_vec_vec(const IColumn* col_general, const IColumn*
col_scale) {
+ if constexpr (rounding_mode != RoundingMode::Trunc) { // only truncate supports a per-row scale
+ throw doris::Exception(ErrorCode::INVALID_ARGUMENT,
+ "Using column as scale is only supported
for function truncate");
+ }
+
+ const ColumnInt32& col_scale_i32 = assert_cast<const
ColumnInt32&>(*col_scale);
+ const size_t input_row_count = col_scale_i32.size();
+ for (size_t i = 0; i < input_row_count; ++i) { // validate every scale up front: the decimal path below narrows it to Int16
+ const Int32 scale_arg = col_scale_i32.get_data()[i];
+ if (scale_arg > std::numeric_limits<Int16>::max() ||
+ scale_arg < std::numeric_limits<Int16>::min()) {
+ throw doris::Exception(ErrorCode::OUT_OF_BOUND,
+ "Scale argument for function is out of
bound: {}",
+ scale_arg);
+ }
+ }
+
+ if constexpr (IsNumber<T>) {
+ const auto* col = assert_cast<const ColumnVector<T>*>(col_general);
+ auto col_res = ColumnVector<T>::create();
+ typename ColumnVector<T>::Container& vec_res = col_res->get_data();
+ vec_res.resize(input_row_count);
+
+ for (size_t i = 0; i < input_row_count; ++i) {
+ const Int32 scale_arg = col_scale_i32.get_data()[i];
+ if (scale_arg == 0) {
+ size_t scale = 1;
+
FunctionRoundingImpl<ScaleMode::Zero>::apply(col->get_data()[i], scale,
+ vec_res[i]);
+ } else if (scale_arg > 0) {
+ size_t scale = int_exp10(scale_arg);
+
FunctionRoundingImpl<ScaleMode::Positive>::apply(col->get_data()[i], scale,
+
vec_res[i]);
+ } else {
+ size_t scale = int_exp10(-scale_arg);
+
FunctionRoundingImpl<ScaleMode::Negative>::apply(col->get_data()[i], scale,
+
vec_res[i]);
+ }
+ }
+ return col_res;
+ } else if constexpr (IsDecimalNumber<T>) {
+ const auto* decimal_col = assert_cast<const
ColumnDecimal<T>*>(col_general);
+
+ // For truncate, ALWAYS use SAME scale with source Decimal column
+ const Int32 input_scale = decimal_col->get_scale();
+ auto col_res = ColumnDecimal<T>::create(input_row_count,
input_scale);
+
+ for (size_t i = 0; i < input_row_count; ++i) {
+ DecimalRoundingImpl<T, rounding_mode,
tie_breaking_mode>::apply(
+ decimal_col->get_element(i).value, input_scale,
+ col_res->get_element(i).value,
col_scale_i32.get_data()[i]);
+ }
+
+ for (size_t i = 0; i < input_row_count; ++i) {
+ // For truncate(ColumnDecimal, ColumnInt32), we should always
have the same scale as the source Decimal column
+ // So we rescale the value computed above to make sure the result has the correct
digit count
+ //
+ // Case 0: scale_arg <= -(integer part digits count)
+ // result is 0, so the multiplication below keeps it 0
+ // Case 1: scale_arg <= 0 && scale_arg > -(integer part digits
count)
+ // decimal part has been erased, so restore the scale by
multiplying by 10^(input_scale)
+ // Case 2: scale_arg > 0 && scale_arg < decimal part digits
count (i.e. input_scale)
+ // decimal part now has scale_arg digits, so multiply by
10^(input_scale - scale_arg)
+ // Case 3: scale_arg >= input_scale
+ // do nothing, the value was copied unchanged
+ const Int32 scale_arg = col_scale_i32.get_data()[i];
+ if (scale_arg <= 0) {
+ col_res->get_element(i).value *= int_exp10(input_scale); // NOTE(review): confirm int_exp10 can represent 10^input_scale for Decimal128/256 scales
+ } else if (scale_arg > 0 && scale_arg < input_scale) {
+ col_res->get_element(i).value *= int_exp10(input_scale -
scale_arg);
+ }
+ }
+
+ return col_res;
+ } else {
+ LOG(FATAL) << "__builtin_unreachable";
+ __builtin_unreachable();
+ return nullptr;
+ }
+ }
+
+ // NOTE: This function is only tested for truncate. Truncates the single value held by const_col_general to each per-row scale in col_scale (a ColumnInt32).
+ // DO NOT USE THIS METHOD FOR OTHER ROUNDING BASED FUNCTION UNTIL YOU KNOW
EXACTLY WHAT YOU ARE DOING !!!
+ static ColumnPtr apply_const_vec(const ColumnConst* const_col_general,
+ const IColumn* col_scale) {
+ if constexpr (rounding_mode != RoundingMode::Trunc) { // only truncate supports a per-row scale
+ throw doris::Exception(ErrorCode::INVALID_ARGUMENT,
+ "Using column as scale is only supported
for function truncate");
+ }
+
+ const ColumnInt32& col_scale_i32 = assert_cast<const
ColumnInt32&>(*col_scale);
+ const size_t input_rows_count = col_scale->size();
+
+ for (size_t i = 0; i < input_rows_count; ++i) { // validate every scale up front: both paths below narrow it to Int16
+ const Int32 scale_arg = col_scale_i32.get_data()[i];
+
+ if (scale_arg > std::numeric_limits<Int16>::max() ||
+ scale_arg < std::numeric_limits<Int16>::min()) {
+ throw doris::Exception(ErrorCode::OUT_OF_BOUND,
+ "Scale argument for function is out of
bound: {}",
+ scale_arg);
+ }
+ }
+
+ if constexpr (IsDecimalNumber<T>) {
+ const ColumnDecimal<T>& data_col_general =
+ assert_cast<const
ColumnDecimal<T>&>(const_col_general->get_data_column());
+ const T& general_val = data_col_general.get_data()[0]; // the const column's single value, reused for every row
+ Int32 input_scale = data_col_general.get_scale();
+
+ auto col_res = ColumnDecimal<T>::create(input_rows_count,
input_scale);
+
+ for (size_t i = 0; i < input_rows_count; ++i) {
+ DecimalRoundingImpl<T, rounding_mode,
tie_breaking_mode>::apply(
+ general_val, input_scale,
col_res->get_element(i).value,
+ col_scale_i32.get_data()[i]);
+ }
+
+ for (size_t i = 0; i < input_rows_count; ++i) {
+ // For truncate(ColumnConst<Decimal>, ColumnInt32), we should always
have the same scale as the source Decimal column
+ // So we rescale the value computed above to make sure the result has the correct
digit count
+ //
+ // Case 0: scale_arg <= -(integer part digits count)
+ // result is 0, so the multiplication below keeps it 0
+ // Case 1: scale_arg <= 0 && scale_arg > -(integer part digits
count)
+ // decimal part has been erased, so restore the scale by
multiplying by 10^(input_scale)
+ // Case 2: scale_arg > 0 && scale_arg < decimal part digits
count (i.e. input_scale)
+ // decimal part now has scale_arg digits, so multiply by
10^(input_scale - scale_arg)
+ // Case 3: scale_arg >= input_scale
+ // do nothing, the value was copied unchanged
+ const Int32 scale_arg = col_scale_i32.get_data()[i];
+ if (scale_arg <= 0) {
+ col_res->get_element(i).value *= int_exp10(input_scale); // NOTE(review): confirm int_exp10 can represent 10^input_scale for Decimal128/256 scales
+ } else if (scale_arg > 0 && scale_arg < input_scale) {
+ col_res->get_element(i).value *= int_exp10(input_scale -
scale_arg);
+ }
+ }
+
+ return col_res;
+ } else if constexpr (IsNumber<T>) {
+ const ColumnVector<T>& data_col_general =
+ assert_cast<const
ColumnVector<T>&>(const_col_general->get_data_column());
+ const T& general_val = data_col_general.get_data()[0]; // the const column's single value, reused for every row
+ auto col_res = ColumnVector<T>::create(input_rows_count);
+ typename ColumnVector<T>::Container& vec_res = col_res->get_data();
+
+ for (size_t i = 0; i < input_rows_count; ++i) {
+ const Int16 scale_arg = col_scale_i32.get_data()[i]; // narrowing is safe: range-checked in the loop above
+ if (scale_arg == 0) {
+ size_t scale = 1;
+ FunctionRoundingImpl<ScaleMode::Zero>::apply(general_val,
scale, vec_res[i]);
+ } else if (scale_arg > 0) {
+ size_t scale = int_exp10(col_scale_i32.get_data()[i]);
+
FunctionRoundingImpl<ScaleMode::Positive>::apply(general_val, scale,
+
vec_res[i]);
+ } else {
+ size_t scale = int_exp10(-col_scale_i32.get_data()[i]);
+
FunctionRoundingImpl<ScaleMode::Negative>::apply(general_val, scale,
+
vec_res[i]);
+ }
+ }
+
+ return col_res;
+ } else {
+ throw doris::Exception(ErrorCode::INVALID_ARGUMENT,
+ "Unsupported column {} for function
truncate",
+ const_col_general->get_name());
+ }
+ }
};
template <typename Impl, RoundingMode rounding_mode, TieBreakingMode
tie_breaking_mode>
@@ -476,17 +682,17 @@ public:
static Status get_scale_arg(const ColumnWithTypeAndName& arguments, Int16*
scale) {
const IColumn& scale_column = *arguments.column;
- Int32 scale64 = static_cast<const ColumnInt32&>(
- static_cast<const
ColumnConst*>(&scale_column)->get_data_column())
- .get_element(0);
+ Int32 scale_arg = assert_cast<const ColumnInt32&>(
+ assert_cast<const
ColumnConst*>(&scale_column)->get_data_column())
+ .get_element(0);
- if (scale64 > std::numeric_limits<Int16>::max() ||
- scale64 < std::numeric_limits<Int16>::min()) {
+ if (scale_arg > std::numeric_limits<Int16>::max() ||
+ scale_arg < std::numeric_limits<Int16>::min()) {
return Status::InvalidArgument("Scale argument for function {} is
out of bound: {}",
- name, scale64);
+ name, scale_arg);
}
- *scale = scale64;
+ *scale = scale_arg; // lossless: range-checked against Int16 just above
return Status::OK();
}
@@ -507,7 +713,7 @@ public:
if constexpr (IsDataTypeNumber<DataType> ||
IsDataTypeDecimal<DataType>) {
using FieldType = typename DataType::FieldType;
- res = Dispatcher<FieldType, rounding_mode,
tie_breaking_mode>::apply(
+ res = Dispatcher<FieldType, rounding_mode,
tie_breaking_mode>::apply_vec_const(
column.column.get(), scale_arg);
return true;
}
diff --git a/be/test/vec/function/function_truncate_decimal_test.cpp
b/be/test/vec/function/function_truncate_decimal_test.cpp
new file mode 100644
index 00000000000..36fcaa14e67
--- /dev/null
+++ b/be/test/vec/function/function_truncate_decimal_test.cpp
@@ -0,0 +1,370 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest-message.h>
+#include <gtest/gtest.h>
+
+#include <climits>
+#include <cmath>
+#include <cstddef>
+#include <cstdint>
+#include <iomanip>
+#include <limits>
+#include <map>
+#include <memory>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "function_test_util.h"
+#include "vec/columns/column.h"
+#include "vec/columns/column_const.h"
+#include "vec/columns/column_decimal.h"
+#include "vec/columns/columns_number.h"
+#include "vec/common/assert_cast.h"
+#include "vec/core/column_numbers.h"
+#include "vec/core/types.h"
+#include "vec/data_types/data_type_decimal.h"
+#include "vec/data_types/data_type_number.h"
+#include "vec/functions/function_truncate.h"
+
+namespace doris::vectorized {
+// {precision, scale} -> {input, scale_arg, expectation}; input and expectation are unscaled decimal integers (e.g. 12 with scale 1 means 1.2)
+using TestDataSet = std::map<std::pair<int, int>,
std::vector<std::tuple<Int128, int, Int128>>>;
+
+const static TestDataSet truncate_decimal32_cases = {
+ {{1, 0},
+ {
+ {1, -10, 0}, {1, -9, 0}, {1, -8, 0}, {1, -7, 0}, {1, -6, 0},
{1, -5, 0},
+ {1, -4, 0}, {1, -3, 0}, {1, -2, 0}, {1, -1, 0}, {1, 0, 1},
{1, 1, 1},
+ {1, 2, 1}, {1, 3, 1}, {1, 4, 1}, {1, 5, 1}, {1, 6, 1},
{1, 7, 1},
+ {1, 8, 1}, {1, 9, 1}, {1, 10, 1},
+ }},
+ {{1, 1},
+ {
+ {1, -10, 0}, {1, -9, 0}, {1, -8, 0}, {1, -7, 0}, {1, -6, 0},
{1, -5, 0},
+ {1, -4, 0}, {1, -3, 0}, {1, -2, 0}, {1, -1, 0}, {1, 0, 0},
{1, 1, 1},
+ {1, 2, 1}, {1, 3, 1}, {1, 4, 1}, {1, 5, 1}, {1, 6, 1},
{1, 7, 1},
+ {1, 8, 1}, {1, 9, 1}, {1, 10, 1},
+ }},
+ {{2, 0},
+ {
+ {12, -4, 0},
+ {12, -3, 0},
+ {12, -2, 0},
+ {12, -1, 10},
+ {12, 0, 12},
+ {12, 1, 12},
+ {12, 2, 12},
+ {12, 3, 12},
+ {12, 4, 12},
+ }},
+ {{2, 1},
+ {
+ {12, -4, 0},
+ {12, -3, 0},
+ {12, -2, 0},
+ {12, -1, 0},
+ {12, 0, 10},
+ {12, 1, 12},
+ {12, 2, 12},
+ {12, 3, 12},
+ {12, 4, 12},
+ }},
+ {{2, 2},
+ {
+ {12, -4, 0},
+ {12, -3, 0},
+ {12, -2, 0},
+ {12, -1, 0},
+ {12, 0, 0},
+ {12, 1, 10},
+ {12, 2, 12},
+ {12, 3, 12},
+ {12, 4, 12},
+ }},
+ {{9, 0},
+ {
+ {123456789, -10, 0}, {123456789, -9, 0},
{123456789, -8, 100000000},
+ {123456789, -7, 120000000}, {123456789, -6, 123000000},
{123456789, -5, 123400000},
+ {123456789, -4, 123450000}, {123456789, -3, 123456000},
{123456789, -2, 123456700},
+ {123456789, -1, 123456780}, {123456789, 0, 123456789},
{123456789, 1, 123456789},
+ {123456789, 2, 123456789}, {123456789, 3, 123456789},
{123456789, 4, 123456789},
+ {123456789, 5, 123456789}, {123456789, 6, 123456789},
{123456789, 7, 123456789},
+ {123456789, 8, 123456789}, {123456789, 9, 123456789},
{123456789, 10, 123456789},
+ }},
+ {{9, 1},
+ {
+ {123456789, -10, 0}, {123456789, -9, 0},
{123456789, -8, 0},
+ {123456789, -7, 100000000}, {123456789, -6, 120000000},
{123456789, -5, 123000000},
+ {123456789, -4, 123400000}, {123456789, -3, 123450000},
{123456789, -2, 123456000},
+ {123456789, -1, 123456700}, {123456789, 0, 123456780},
{123456789, 1, 123456789},
+ {123456789, 2, 123456789}, {123456789, 3, 123456789},
{123456789, 4, 123456789},
+ {123456789, 5, 123456789}, {123456789, 6, 123456789},
{123456789, 7, 123456789},
+ {123456789, 8, 123456789}, {123456789, 9, 123456789},
{123456789, 10, 123456789},
+ }},
+ {{9, 2},
+ {
+ {123456789, -10, 0}, {123456789, -9, 0},
{123456789, -8, 0},
+ {123456789, -7, 0}, {123456789, -6, 100000000},
{123456789, -5, 120000000},
+ {123456789, -4, 123000000}, {123456789, -3, 123400000},
{123456789, -2, 123450000},
+ {123456789, -1, 123456000}, {123456789, 0, 123456700},
{123456789, 1, 123456780},
+ {123456789, 2, 123456789}, {123456789, 3, 123456789},
{123456789, 4, 123456789},
+ {123456789, 5, 123456789}, {123456789, 6, 123456789},
{123456789, 7, 123456789},
+ {123456789, 8, 123456789}, {123456789, 9, 123456789},
{123456789, 10, 123456789},
+ }},
+ {{9, 3},
+ {
+ {123456789, -10, 0}, {123456789, -9, 0},
{123456789, -8, 0},
+ {123456789, -7, 0}, {123456789, -6, 0},
{123456789, -5, 100000000},
+ {123456789, -4, 120000000}, {123456789, -3, 123000000},
{123456789, -2, 123400000},
+ {123456789, -1, 123450000}, {123456789, 0, 123456000},
{123456789, 1, 123456700},
+ {123456789, 2, 123456780}, {123456789, 3, 123456789},
{123456789, 4, 123456789},
+ {123456789, 5, 123456789}, {123456789, 6, 123456789},
{123456789, 7, 123456789},
+ {123456789, 8, 123456789}, {123456789, 9, 123456789},
{123456789, 10, 123456789},
+ }},
+ {{9, 4},
+ {
+ {123456789, -10, 0}, {123456789, -9, 0},
{123456789, -8, 0},
+ {123456789, -7, 0}, {123456789, -6, 0},
{123456789, -5, 0},
+ {123456789, -4, 100000000}, {123456789, -3, 120000000},
{123456789, -2, 123000000},
+ {123456789, -1, 123400000}, {123456789, 0, 123450000},
{123456789, 1, 123456000},
+ {123456789, 2, 123456700}, {123456789, 3, 123456780},
{123456789, 4, 123456789},
+ {123456789, 5, 123456789}, {123456789, 6, 123456789},
{123456789, 7, 123456789},
+ {123456789, 8, 123456789}, {123456789, 9, 123456789},
{123456789, 10, 123456789},
+ }},
+ {{9, 5},
+ {
+ {123456789, -10, 0}, {123456789, -9, 0},
{123456789, -8, 0},
+ {123456789, -7, 0}, {123456789, -6, 0},
{123456789, -5, 0},
+ {123456789, -4, 0}, {123456789, -3, 100000000},
{123456789, -2, 120000000},
+ {123456789, -1, 123000000}, {123456789, 0, 123400000},
{123456789, 1, 123450000},
+ {123456789, 2, 123456000}, {123456789, 3, 123456700},
{123456789, 4, 123456780},
+ {123456789, 5, 123456789}, {123456789, 6, 123456789},
{123456789, 7, 123456789},
+ {123456789, 8, 123456789}, {123456789, 9, 123456789},
{123456789, 10, 123456789},
+ }},
+ {{9, 6},
+ {
+ {123456789, -10, 0}, {123456789, -9, 0},
{123456789, -8, 0},
+ {123456789, -7, 0}, {123456789, -6, 0},
{123456789, -5, 0},
+ {123456789, -4, 0}, {123456789, -3, 0},
{123456789, -2, 100000000},
+ {123456789, -1, 120000000}, {123456789, 0, 123000000},
{123456789, 1, 123400000},
+ {123456789, 2, 123450000}, {123456789, 3, 123456000},
{123456789, 4, 123456700},
+ {123456789, 5, 123456780}, {123456789, 6, 123456789},
{123456789, 7, 123456789},
+ {123456789, 8, 123456789}, {123456789, 9, 123456789},
{123456789, 10, 123456789},
+ }},
+ {{9, 7},
+ {
+ {123456789, -10, 0}, {123456789, -9, 0},
{123456789, -8, 0},
+ {123456789, -7, 0}, {123456789, -6, 0},
{123456789, -5, 0},
+ {123456789, -4, 0}, {123456789, -3, 0},
{123456789, -2, 0},
+ {123456789, -1, 100000000}, {123456789, 0, 120000000},
{123456789, 1, 123000000},
+ {123456789, 2, 123400000}, {123456789, 3, 123450000},
{123456789, 4, 123456000},
+ {123456789, 5, 123456700}, {123456789, 6, 123456780},
{123456789, 7, 123456789},
+ {123456789, 8, 123456789}, {123456789, 9, 123456789},
{123456789, 10, 123456789},
+ }},
+ {{9, 8},
+ {
+ {123456789, -10, 0}, {123456789, -9, 0},
{123456789, -8, 0},
+ {123456789, -7, 0}, {123456789, -6, 0},
{123456789, -5, 0},
+ {123456789, -4, 0}, {123456789, -3, 0},
{123456789, -2, 0},
+ {123456789, -1, 0}, {123456789, 0, 100000000},
{123456789, 1, 120000000},
+ {123456789, 2, 123000000}, {123456789, 3, 123400000},
{123456789, 4, 123450000},
+ {123456789, 5, 123456000}, {123456789, 6, 123456700},
{123456789, 7, 123456780},
+ {123456789, 8, 123456789}, {123456789, 9, 123456789},
{123456789, 10, 123456789},
+ }},
+ {{9, 9},
+ {
+ {123456789, -10, 0}, {123456789, -9, 0},
{123456789, -8, 0},
+ {123456789, -7, 0}, {123456789, -6, 0},
{123456789, -5, 0},
+ {123456789, -4, 0}, {123456789, -3, 0},
{123456789, -2, 0},
+ {123456789, -1, 0}, {123456789, 0, 0},
{123456789, 1, 100000000},
+ {123456789, 2, 120000000}, {123456789, 3, 123000000},
{123456789, 4, 123400000},
+ {123456789, 5, 123450000}, {123456789, 6, 123456000},
{123456789, 7, 123456700},
+ {123456789, 8, 123456780}, {123456789, 9, 123456789},
{123456789, 10, 123456789},
+ }}};
+
+// Truncate cases for Decimal64 inputs. Each entry is {{precision, scale}, rows}; each row is
+// {input, scale_arg, expected}. `input` and `expected` are the unscaled integer
+// representations stored into a ColumnDecimal of the given scale (e.g. with scale 1,
+// 1234567891 stands for 123456789.1), and `scale_arg` is the second argument of truncate.
+const static TestDataSet truncate_decimal64_cases = {
+ {{10, 0},
+ {{1234567891, -11, 0}, {1234567891, -10, 0},
{1234567891, -9, 1000000000},
+ {1234567891, -8, 1200000000}, {1234567891, -7, 1230000000},
{1234567891, -6, 1234000000},
+ {1234567891, -5, 1234500000}, {1234567891, -4, 1234560000},
{1234567891, -3, 1234567000},
+ {1234567891, -2, 1234567800}, {1234567891, -1, 1234567890},
{1234567891, 0, 1234567891},
+ {1234567891, 1, 1234567891}, {1234567891, 2, 1234567891},
{1234567891, 3, 1234567891},
+ {1234567891, 4, 1234567891}, {1234567891, 5, 1234567891},
{1234567891, 6, 1234567891},
+ {1234567891, 7, 1234567891}, {1234567891, 8, 1234567891},
{1234567891, 9, 1234567891},
+ {1234567891, 10, 1234567891}, {1234567891, 11, 1234567891}}},
+ {{10, 1},
+ {{1234567891, -11, 0}, {1234567891, -10, 0},
{1234567891, -9, 0},
+ {1234567891, -8, 1000000000}, {1234567891, -7, 1200000000},
{1234567891, -6, 1230000000},
+ {1234567891, -5, 1234000000}, {1234567891, -4, 1234500000},
{1234567891, -3, 1234560000},
+ {1234567891, -2, 1234567000}, {1234567891, -1, 1234567800},
{1234567891, 0, 1234567890},
+ {1234567891, 1, 1234567891}, {1234567891, 2, 1234567891},
{1234567891, 3, 1234567891},
+ {1234567891, 4, 1234567891}, {1234567891, 5, 1234567891},
{1234567891, 6, 1234567891},
+ {1234567891, 7, 1234567891}, {1234567891, 8, 1234567891},
{1234567891, 9, 1234567891},
+ {1234567891, 10, 1234567891}, {1234567891, 11, 1234567891}
+
+ }},
+ {{10, 2},
+ {{1234567891, -11, 0}, {1234567891, -10, 0},
{1234567891, -9, 0},
+ {1234567891, -8, 0}, {1234567891, -7, 1000000000},
{1234567891, -6, 1200000000},
+ {1234567891, -5, 1230000000}, {1234567891, -4, 1234000000},
{1234567891, -3, 1234500000},
+ {1234567891, -2, 1234560000}, {1234567891, -1, 1234567000},
{1234567891, 0, 1234567800},
+ {1234567891, 1, 1234567890}, {1234567891, 2, 1234567891},
{1234567891, 3, 1234567891},
+ {1234567891, 4, 1234567891}, {1234567891, 5, 1234567891},
{1234567891, 6, 1234567891},
+ {1234567891, 7, 1234567891}, {1234567891, 8, 1234567891},
{1234567891, 9, 1234567891},
+ {1234567891, 10, 1234567891}, {1234567891, 11, 1234567891}}},
+ {{10, 9},
+ {{1234567891, -11, 0}, {1234567891, -10, 0},
{1234567891, -9, 0},
+ {1234567891, -8, 0}, {1234567891, -7, 0},
{1234567891, -6, 0},
+ {1234567891, -5, 0}, {1234567891, -4, 0},
{1234567891, -3, 0},
+ {1234567891, -2, 0}, {1234567891, -1, 0},
{1234567891, 0, 1000000000},
+ {1234567891, 1, 1200000000}, {1234567891, 2, 1230000000},
{1234567891, 3, 1234000000},
+ {1234567891, 4, 1234500000}, {1234567891, 5, 1234560000},
{1234567891, 6, 1234567000},
+ {1234567891, 7, 1234567800}, {1234567891, 8, 1234567890},
{1234567891, 9, 1234567891},
+ {1234567891, 10, 1234567891}, {1234567891, 11, 1234567891}}},
+ {{18, 0},
+ {{123456789123456789, -19, 0},
+ {123456789123456789, -18, 0},
+ {123456789123456789, -17, 100000000000000000},
+ {123456789123456789, -16, 120000000000000000},
+ {123456789123456789, -15, 123000000000000000},
+ {123456789123456789, -14, 123400000000000000},
+ {123456789123456789, -13, 123450000000000000},
+ {123456789123456789, -12, 123456000000000000},
+ {123456789123456789, -11, 123456700000000000},
+ {123456789123456789, -10, 123456780000000000},
+ {123456789123456789, -9, 123456789000000000},
+ {123456789123456789, -8, 123456789100000000},
+ {123456789123456789, -7, 123456789120000000},
+ {123456789123456789, -6, 123456789123000000},
+ {123456789123456789, -5, 123456789123400000},
+ {123456789123456789, -4, 123456789123450000},
+ {123456789123456789, -3, 123456789123456000},
+ {123456789123456789, -2, 123456789123456700},
+ {123456789123456789, -1, 123456789123456780},
+ {123456789123456789, 0, 123456789123456789},
+ {123456789123456789, 1, 123456789123456789},
+ {123456789123456789, 2, 123456789123456789},
+ {123456789123456789, 3, 123456789123456789},
+ {123456789123456789, 4, 123456789123456789},
+ {123456789123456789, 5, 123456789123456789},
+ {123456789123456789, 6, 123456789123456789},
+ {123456789123456789, 7, 123456789123456789},
+ {123456789123456789, 8, 123456789123456789},
+ {123456789123456789, 18, 123456789123456789}}},
+ {{18, 18},
+ {{123456789123456789, -1, 0},
+ {123456789123456789, 0, 0},
+ {123456789123456789, 1, 100000000000000000},
+ {123456789123456789, 2, 120000000000000000},
+ {123456789123456789, 3, 123000000000000000},
+ {123456789123456789, 4, 123400000000000000},
+ {123456789123456789, 5, 123450000000000000},
+ {123456789123456789, 6, 123456000000000000},
+ {123456789123456789, 7, 123456700000000000},
+ {123456789123456789, 8, 123456780000000000},
+ {123456789123456789, 9, 123456789000000000},
+ {123456789123456789, 10, 123456789100000000},
+ {123456789123456789, 11, 123456789120000000},
+ {123456789123456789, 12, 123456789123000000},
+ {123456789123456789, 13, 123456789123400000},
+ {123456789123456789, 14, 123456789123450000},
+ {123456789123456789, 15, 123456789123456000},
+ {123456789123456789, 16, 123456789123456700},
+ {123456789123456789, 17, 123456789123456780},
+ {123456789123456789, 18, 123456789123456789},
+ {123456789123456789, 19, 123456789123456789},
+ {123456789123456789, 20, 123456789123456789},
+ {123456789123456789, 21, 123456789123456789},
+ {123456789123456789, 22, 123456789123456789},
+ {123456789123456789, 23, 123456789123456789},
+ {123456789123456789, 24, 123456789123456789},
+ {123456789123456789, 25, 123456789123456789},
+ {123456789123456789, 26, 123456789123456789}}}};
+
+// Runs FuncType (a two-argument decimal truncate) over every data set in
+// `truncate_test_cases` and compares each output row with the expected value.
+//
+// Each data set is {{precision, scale}, rows}; each row is {input, scale_arg, expected},
+// where `input` and `expected` are unscaled integers stored into a ColumnDecimal of `scale`.
+//
+// decimal_col_is_const: when true, the decimal argument is passed as a ColumnConst built
+// from the first row's input, so every row of a data set must share that input value.
+template <typename FuncType, typename DecimalType>
+static void checker(const TestDataSet& truncate_test_cases, bool decimal_col_is_const) {
+    static_assert(IsDecimalNumber<DecimalType>);
+    auto func = std::dynamic_pointer_cast<FuncType>(FuncType::create());
+    FunctionContext* context = nullptr;
+
+    for (const auto& test_case : truncate_test_cases) {
+        Block block;
+        // Columns 0 (decimal input) and 1 (scale) are the arguments; the result is written to
+        // column 2, which must not itself be listed as an argument.
+        const size_t res_idx = 2;
+        const ColumnNumbers arguments = {0, 1};
+        const int precision = test_case.first.first;
+        const int scale = test_case.first.second;
+        const size_t input_rows_count = test_case.second.size();
+        auto col_general = ColumnDecimal<DecimalType>::create(input_rows_count, scale);
+        auto col_scale = ColumnInt32::create();
+        auto col_res_expected = ColumnDecimal<DecimalType>::create(input_rows_count, scale);
+        size_t rid = 0;
+
+        for (const auto& test_data : test_case.second) {
+            const auto input = std::get<0>(test_data);
+            const auto scale_arg = std::get<1>(test_data);
+            const auto expectation = std::get<2>(test_data);
+            col_general->get_element(rid) = DecimalType(input);
+            col_scale->insert(scale_arg);
+            col_res_expected->get_element(rid) = DecimalType(expectation);
+            rid++;
+        }
+
+        if (decimal_col_is_const) {
+            // A const column should report the same row count as the rest of the block.
+            block.insert({ColumnConst::create(col_general->clone_resized(1), input_rows_count),
+                          std::make_shared<DataTypeDecimal<DecimalType>>(precision, scale),
+                          "col_general_const"});
+        } else {
+            block.insert({col_general->clone(),
+                          std::make_shared<DataTypeDecimal<DecimalType>>(precision, scale),
+                          "col_general"});
+        }
+
+        block.insert({col_scale->clone(), std::make_shared<DataTypeInt32>(), "col_scale"});
+        block.insert({nullptr, std::make_shared<DataTypeDecimal<DecimalType>>(precision, scale),
+                      "col_res"});
+
+        auto status = func->execute_impl(context, block, arguments, res_idx, input_rows_count);
+        // The result slot was inserted as nullptr above; if execution failed it may still be
+        // null, so stop here instead of dereferencing it.
+        ASSERT_TRUE(status.ok());
+        // Bind by reference: the result column only needs to be read, not copied.
+        const auto& col_res = assert_cast<const ColumnDecimal<DecimalType>&>(
+                *(block.get_by_position(res_idx).column));
+
+        for (size_t i = 0; i < input_rows_count; ++i) {
+            const auto res = col_res.get_element(i);
+            const auto res_expected = col_res_expected->get_element(i);
+            EXPECT_EQ(res, res_expected)
+                    << "precision " << precision << " input_scale " << scale << " input "
+                    << col_general->get_element(i) << " scale_arg " << col_scale->get_element(i)
+                    << " res " << res << " res_expected " << res_expected;
+        }
+    }
+}
+// Runs the Decimal32 and Decimal64 truncate cases with the decimal argument passed as a
+// regular (non-const) column.
+TEST(TruncateFunctionTest, normal_decimal) {
+    using TruncateFunc = FunctionTruncate<TruncateDecimalTwoArgImpl>;
+    constexpr bool decimal_col_is_const = false;
+    checker<TruncateFunc, Decimal32>(truncate_decimal32_cases, decimal_col_is_const);
+    checker<TruncateFunc, Decimal64>(truncate_decimal64_cases, decimal_col_is_const);
+}
+
+// Runs the Decimal32 and Decimal64 truncate cases with the decimal argument passed as a
+// const column.
+TEST(TruncateFunctionTest, normal_decimal_const) {
+    using TruncateFunc = FunctionTruncate<TruncateDecimalTwoArgImpl>;
+    constexpr bool decimal_col_is_const = true;
+    checker<TruncateFunc, Decimal32>(truncate_decimal32_cases, decimal_col_is_const);
+    checker<TruncateFunc, Decimal64>(truncate_decimal64_cases, decimal_col_is_const);
+}
+
+} // namespace doris::vectorized
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
index b5184c33fcd..9bc857bacef 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
@@ -122,7 +122,7 @@ public class FunctionCallExpr extends Expr {
Preconditions.checkArgument(children.get(1) instanceof
IntLiteral
|| (children.get(1) instanceof CastExpr
&& children.get(1).getChild(0) instanceof
IntLiteral),
- "2nd argument of function round/floor/ceil/truncate
must be literal");
+ "2nd argument of function round/floor/ceil must be
literal");
if (children.get(1) instanceof CastExpr &&
children.get(1).getChild(0) instanceof IntLiteral) {
children.get(1).getChild(0).setType(children.get(1).getType());
children.set(1, children.get(1).getChild(0));
@@ -136,6 +136,34 @@ public class FunctionCallExpr extends Expr {
return returnType;
}
};
+
+        // Precision-inference rule for truncate(). Unlike round/floor/ceil, truncate accepts a
+        // non-literal (column) scale argument; in that case the result keeps the input's scale.
+        java.util.function.BiFunction<ArrayList<Expr>, Type, Type> truncateRule = (children, returnType) -> {
+            Preconditions.checkArgument(children != null && children.size() > 0);
+            Type inputType = children.get(0).getType();
+            if (children.size() == 1) {
+                // truncate(x) on a decimal yields scale 0; other inputs keep the declared type.
+                return inputType.isDecimalV3()
+                        ? ScalarType.createDecimalV3Type(inputType.getPrecision(), 0)
+                        : returnType;
+            }
+            if (children.size() != 2) {
+                return returnType;
+            }
+
+            Expr scaleExpr = children.get(1);
+            boolean isCastOfIntLiteral = scaleExpr instanceof CastExpr
+                    && scaleExpr.getChild(0) instanceof IntLiteral;
+            if (!(scaleExpr instanceof IntLiteral) && !isCastOfIntLiteral) {
+                // Non-literal scale argument (e.g. a column): keep the input decimal's scale.
+                return ScalarType.createDecimalV3Type(inputType.getPrecision(),
+                        ((ScalarType) inputType).decimalScale());
+            }
+
+            // Literal scale argument: unwrap a cast around it (the literal takes the cast's
+            // type), otherwise type the bare literal as INT.
+            if (isCastOfIntLiteral) {
+                Expr literal = scaleExpr.getChild(0);
+                literal.setType(scaleExpr.getType());
+                children.set(1, literal);
+            } else {
+                scaleExpr.setType(Type.INT);
+            }
+            int scaleArg = (int) ((IntLiteral) children.get(1)).getValue();
+            // Clamp the requested scale into [0, input scale].
+            int resultScale = Math.min(Math.max(scaleArg, 0), ((ScalarType) inputType).decimalScale());
+            return ScalarType.createDecimalV3Type(inputType.getPrecision(), resultScale);
+        };
+
java.util.function.BiFunction<ArrayList<Expr>, Type, Type>
arrayDateTimeV2OrDecimalV3Rule
= (children, returnType) -> {
Preconditions.checkArgument(children != null &&
children.size() > 0);
@@ -239,7 +267,7 @@ public class FunctionCallExpr extends Expr {
PRECISION_INFER_RULE.put("dround", roundRule);
PRECISION_INFER_RULE.put("dceil", roundRule);
PRECISION_INFER_RULE.put("dfloor", roundRule);
- PRECISION_INFER_RULE.put("truncate", roundRule);
+ PRECISION_INFER_RULE.put("truncate", truncateRule);
}
public static final ImmutableSet<String> TIME_FUNCTIONS_WITH_PRECISION =
new ImmutableSortedSet.Builder(
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java
index 4b57772ed23..6b6308c516c 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java
@@ -20,6 +20,7 @@ package org.apache.doris.nereids.trees.expressions.functions;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.Truncate;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.coercion.Int32OrLessType;
@@ -37,19 +38,38 @@ public interface ComputePrecisionForRound extends
ComputePrecision {
} else if (arity() == 2 && signature.getArgType(0) instanceof
DecimalV3Type) {
DecimalV3Type decimalV3Type =
DecimalV3Type.forType(getArgumentType(0));
Expression floatLength = getArgument(1);
- Preconditions.checkArgument(floatLength.getDataType() instanceof
Int32OrLessType
- && (floatLength.isLiteral() || (
- floatLength instanceof Cast &&
floatLength.child(0).isLiteral()
- && floatLength.child(0).getDataType()
instanceof Int32OrLessType)),
- "2nd argument of function round/floor/ceil/truncate must
be literal");
-
int scale;
- if (floatLength instanceof Cast) {
- scale = ((IntegerLikeLiteral)
floatLength.child(0)).getIntValue();
+
+ if (this instanceof Truncate) {
+ if (floatLength.isLiteral() || (
+ floatLength instanceof Cast &&
floatLength.child(0).isLiteral()
+ && floatLength.child(0).getDataType()
instanceof Int32OrLessType)) {
+ // Scale argument is a literal or cast from other literal
+ if (floatLength instanceof Cast) {
+ scale = ((IntegerLikeLiteral)
floatLength.child(0)).getIntValue();
+ } else {
+ scale = ((IntegerLikeLiteral)
floatLength).getIntValue();
+ }
+ scale = Math.min(Math.max(scale, 0),
decimalV3Type.getScale());
+ } else {
+ // Truncate could use Column as its scale argument.
+ // Result scale will always same with input Decimal in
this situation.
+ scale = decimalV3Type.getScale();
+ }
} else {
- scale = ((IntegerLikeLiteral) floatLength).getIntValue();
+ Preconditions.checkArgument(floatLength.getDataType()
instanceof Int32OrLessType
+ && (floatLength.isLiteral() || (
+ floatLength instanceof Cast &&
floatLength.child(0).isLiteral()
+ && floatLength.child(0).getDataType()
instanceof Int32OrLessType)),
+ "2nd argument of function round/floor/ceil must be
literal");
+ if (floatLength instanceof Cast) {
+ scale = ((IntegerLikeLiteral)
floatLength.child(0)).getIntValue();
+ } else {
+ scale = ((IntegerLikeLiteral) floatLength).getIntValue();
+ }
+ scale = Math.min(Math.max(scale, 0), decimalV3Type.getScale());
}
- scale = Math.min(Math.max(scale, 0), decimalV3Type.getScale());
+
return signature.withArgumentType(0, decimalV3Type)
.withReturnType(DecimalV3Type.createDecimalV3Type(decimalV3Type.getPrecision(),
scale));
} else {
diff --git
a/regression-test/data/query_p0/sql_functions/math_functions/test_function_truncate.out
b/regression-test/data/query_p0/sql_functions/math_functions/test_function_truncate.out
new file mode 100644
index 00000000000..24f675ffbe2
--- /dev/null
+++
b/regression-test/data/query_p0/sql_functions/math_functions/test_function_truncate.out
@@ -0,0 +1,101 @@
+-- This file is automatically generated. You should know what you did if you
want to edit this
+-- !sql --
+0 123.3
+1 123.3
+2 123.3
+3 123.3
+4 123.3
+5 123.3
+6 123.3
+7 123.3
+8 123.3
+9 123.3
+
+-- !sql --
+0 120
+1 120
+2 120
+3 120
+4 120
+5 120
+6 120
+7 120
+8 120
+9 120
+
+-- !sql --
+0 123
+1 123
+2 123
+3 123
+4 123
+5 123
+6 123
+7 123
+8 123
+9 123
+
+-- !sql --
+0E-8
+
+-- !sql --
+0 0.0
+1 0.0
+2 0.0
+3 0.0
+4 0.0
+
+-- !vec_const0 --
+1 12345.0 1.23456789E8
+2 12345.0 1.23456789E8
+3 12345.0 1.23456789E8
+4 0.0 0.0
+
+-- !vec_const0 --
+1 12345.1 1.234567891E8
+2 12345.1 1.234567891E8
+3 12345.1 1.234567891E8
+4 0.0 0.0
+
+-- !vec_const0 --
+1 12340.0 1.2345678E8
+2 12340.0 1.2345678E8
+3 12340.0 1.2345678E8
+4 0.0 0.0
+
+-- !vec_const1 --
+1 123456789 123456789 12345678.1 12345678
0.123456789 0
+2 123456789 123456789 12345678.1 12345678
0.123456789 0
+3 123456789 123456789 12345678.1 12345678
0.123456789 0
+4 0 0 0.0 0 0E-9 0
+
+-- !vec_const2 --
+1 123456789 123456789 1.123456789 1 0.1234567890 0
+2 123456789 123456789 1.123456789 1 0.1234567890 0
+3 123456789 123456789 1.123456789 1 0.1234567890 0
+4 0 0 0E-9 0 0E-10 0
+
+-- !const_vec1 --
+123456789.123456789 1 123456789.100000000
+123456789.123456789 1 123456789.100000000
+123456789.123456789 1 123456789.100000000
+123456789.123456789 1 123456789.100000000
+
+-- !const_vec2 --
+123456789.123456789 -1 123456780.000000000
+123456789.123456789 -1 123456780.000000000
+123456789.123456789 -1 123456780.000000000
+123456789.123456789 -1 123456780.000000000
+
+-- !vec_vec0 --
+1 1 12345.1 1.234567891E8
+2 1 12345.1 1.234567891E8
+3 1 12345.1 1.234567891E8
+4 1 0.0 0.0
+
+-- !truncate_dec128 --
+1 1234567891234567891 1234567891234567891 1234567891.123456789
1234567891 0.1234567891234567891 0
+
+-- !truncate_dec128 --
+1 1234567891234567891 1234567891234567891 1234567891.123456789
1234567891.100000000 0.1234567891234567891 0.1000000000000000000
+
diff --git
a/regression-test/suites/query_p0/sql_functions/math_functions/test_function_truncate.groovy
b/regression-test/suites/query_p0/sql_functions/math_functions/test_function_truncate.groovy
new file mode 100644
index 00000000000..767140e7a6f
--- /dev/null
+++
b/regression-test/suites/query_p0/sql_functions/math_functions/test_function_truncate.groovy
@@ -0,0 +1,132 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("test_function_truncate") {
+ // const_const: constant decimal input with a constant scale argument.
+ qt_sql """
+ SELECT number, truncate(123.345 , 1) FROM numbers("number"="10");
+ """
+ qt_sql """
+ SELECT number, truncate(123.123, -1) FROM numbers("number"="10");
+ """
+ qt_sql """
+ SELECT number, truncate(123.123, 0) FROM numbers("number"="10");
+ """
+
+ // const_const, result scale should be 8: the scale argument 10 is clamped to the input scale of Decimal(9,8)
+ qt_sql """
+ SELECT truncate(cast(0 as Decimal(9,8)), 10);
+ """
+
+ // const_const, result scale should be 1
+ qt_sql """
+ SELECT number, truncate(cast(0 as Decimal(9,4)), 1) FROM
numbers("number"="5")
+ """
+
+ sql """DROP TABLE IF EXISTS test_function_truncate;"""
+ sql """DROP TABLE IF EXISTS test_function_truncate_dec128;"""
+ sql """
+ CREATE TABLE test_function_truncate (
+ rid int, flo float, dou double,
+ dec90 decimal(9, 0), dec91 decimal(9, 1), dec99 decimal(9, 9),
+ dec100 decimal(10,0), dec109 decimal(10,9), dec1010 decimal(10,10),
+ number int DEFAULT 1)
+ DISTRIBUTED BY HASH(rid)
+ PROPERTIES("replication_num" = "1" );
+ """
+
+ sql """
+ INSERT INTO test_function_truncate
+ VALUES
+ (1, 12345.123, 123456789.123456789,
+ 123456789, 12345678.1, 0.123456789,
+ 123456789.1, 1.123456789, 0.123456789, 1);
+ """
+ sql """
+ INSERT INTO test_function_truncate
+ VALUES
+ (2, 12345.123, 123456789.123456789,
+ 123456789, 12345678.1, 0.123456789,
+ 123456789.1, 1.123456789, 0.123456789, 1);
+ """
+ sql """
+ INSERT INTO test_function_truncate
+ VALUES
+ (3, 12345.123, 123456789.123456789,
+ 123456789, 12345678.1, 0.123456789,
+ 123456789.1, 1.123456789, 0.123456789, 1);
+ """
+ sql """
+ INSERT INTO test_function_truncate
+ VALUES
+ (4, 0, 0, 0, 0.0, 0, 0, 0, 0, 1);
+ """
+ // vec_const: column input with a constant scale argument.
+ qt_vec_const0 """
+ SELECT rid, truncate(flo, 0), truncate(dou, 0) FROM
test_function_truncate order by rid;
+ """
+ qt_vec_const0 """
+ SELECT rid, truncate(flo, 1), truncate(dou, 1) FROM
test_function_truncate order by rid;
+ """
+ qt_vec_const0 """
+ SELECT rid, truncate(flo, -1), truncate(dou, -1) FROM
test_function_truncate order by rid;
+ """
+ qt_vec_const1 """
+ SELECT rid, dec90, truncate(dec90, 0), dec91, truncate(dec91, 0),
dec99, truncate(dec99, 0) FROM test_function_truncate order by rid
+ """
+ qt_vec_const2 """
+ SELECT rid, dec100, truncate(dec100, 0), dec109, truncate(dec109, 0),
dec1010, truncate(dec1010, 0) FROM test_function_truncate order by rid
+ """
+
+
+
+ // const_vec: constant decimal input with the scale argument taken from the `number` column;
+ // the result keeps the scale of the decimal input.
+ qt_const_vec1 """
+ SELECT 123456789.123456789, number, truncate(123456789.123456789,
number) from test_function_truncate;
+ """
+ qt_const_vec2 """
+ SELECT 123456789.123456789, -number, truncate(123456789.123456789,
-number) from test_function_truncate;
+ """
+ // vec_vec: both the input and the scale argument are columns.
+ qt_vec_vec0 """
+ SELECT rid,number, truncate(flo, number), truncate(dou, number) FROM
test_function_truncate order by rid;
+ """
+
+ sql """
+ CREATE TABLE test_function_truncate_dec128 (
+ rid int, dec190 decimal(19,0), dec199 decimal(19,9), dec1919
decimal(19,19),
+ dec380 decimal(38,0), dec3819 decimal(38,19), dec3838
decimal(38,38),
+ number int DEFAULT 1
+ )
+ DISTRIBUTED BY HASH(rid)
+ PROPERTIES("replication_num" = "1" );
+ """
+ sql """
+ INSERT INTO test_function_truncate_dec128
+ VALUES
+ (1, 1234567891234567891.0, 1234567891.123456789, 0.1234567891234567891,
+ 12345678912345678912345678912345678912.0,
+ 1234567891234567891.1234567891234567891,
+
0.12345678912345678912345678912345678912345678912345678912345678912345678912,
1);
+ """
+ // Decimal128-width columns with a constant scale argument of 0.
+ qt_truncate_dec128 """
+ SELECT rid, dec190, truncate(dec190, 0), dec199, truncate(dec199, 0),
dec1919, truncate(dec1919, 0)
+ FROM test_function_truncate_dec128 order by rid
+ """
+
+ // Column scale argument: the result keeps the scale of the decimal input.
+ qt_truncate_dec128 """
+ SELECT rid, dec190, truncate(dec190, number), dec199, truncate(dec199,
number), dec1919, truncate(dec1919, number)
+ FROM test_function_truncate_dec128 order by rid
+ """
+
+}
\ No newline at end of file
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]