This is an automated email from the ASF dual-hosted git repository.
taiyangli pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 27647b6cec [GLUTEN-8704][CH] try to accelerate some spark* functions by
optimizing tight loops (#8708)
27647b6cec is described below
commit 27647b6cec1152ebfcd507838ef52194ed55cb2a
Author: 李扬 <[email protected]>
AuthorDate: Tue Feb 18 14:56:35 2025 +0800
[GLUTEN-8704][CH] try to accelerate some spark* functions by optimizing tight
loops (#8708)
* improve some loops
* commit again
* fix style
* fix uts
* add benchmark
* add all benchmarks
* vector
* commit again
* improve decimalcheckoverflow
* commit again
* optimize checkoverflow from int/float
* commit again
* fix bugs
* fix failed ut
* commit again
---
.../Functions/SparkFunctionCastFloatToInt.cpp | 32 +--
.../Functions/SparkFunctionCastFloatToInt.h | 91 ++++---
.../SparkFunctionCheckDecimalOverflow.cpp | 283 +++++++++++----------
.../SparkFunctionDecimalBinaryArithmetic.h | 1 -
.../local-engine/Functions/SparkFunctionDivide.h | 190 ++++++++------
cpp-ch/local-engine/Parser/ExpressionParser.cpp | 9 +-
cpp-ch/local-engine/tests/CMakeLists.txt | 2 +-
..._function.cpp => benchmark_spark_functions.cpp} | 229 ++++++++++++++++-
.../utils/clickhouse/ClickHouseTestSettings.scala | 1 +
9 files changed, 550 insertions(+), 288 deletions(-)
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToInt.cpp
b/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToInt.cpp
index c378f9fbf7..7cdc85cb17 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToInt.cpp
+++ b/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToInt.cpp
@@ -26,31 +26,15 @@ using namespace DB;
namespace local_engine
{
-struct NameToUInt8 { static constexpr auto name = "sparkCastFloatToUInt8"; };
-struct NameToUInt16 { static constexpr auto name = "sparkCastFloatToUInt16"; };
-struct NameToUInt32 { static constexpr auto name = "sparkCastFloatToUInt32"; };
-struct NameToUInt64 { static constexpr auto name = "sparkCastFloatToUInt64"; };
-struct NameToUInt128 { static constexpr auto name = "sparkCastFloatToUInt128";
};
-struct NameToUInt256 { static constexpr auto name = "sparkCastFloatToUInt256";
};
struct NameToInt8 { static constexpr auto name = "sparkCastFloatToInt8"; };
struct NameToInt16 { static constexpr auto name = "sparkCastFloatToInt16"; };
struct NameToInt32 { static constexpr auto name = "sparkCastFloatToInt32"; };
struct NameToInt64 { static constexpr auto name = "sparkCastFloatToInt64"; };
-struct NameToInt128 { static constexpr auto name = "sparkCastFloatToInt128"; };
-struct NameToInt256 { static constexpr auto name = "sparkCastFloatToInt256"; };
-using SparkFunctionCastFloatToInt8 =
local_engine::SparkFunctionCastFloatToInt<Int8, NameToInt8, INT8_MAX, INT8_MIN>;
-using SparkFunctionCastFloatToInt16 =
local_engine::SparkFunctionCastFloatToInt<Int16, NameToInt16, INT16_MAX,
INT16_MIN>;
-using SparkFunctionCastFloatToInt32 =
local_engine::SparkFunctionCastFloatToInt<Int32, NameToInt32, INT32_MAX,
INT32_MIN>;
-using SparkFunctionCastFloatToInt64 =
local_engine::SparkFunctionCastFloatToInt<Int64, NameToInt64, INT64_MAX,
INT64_MIN>;
-using SparkFunctionCastFloatToInt128 =
local_engine::SparkFunctionCastFloatToInt<Int128, NameToInt128,
std::numeric_limits<Int128>::max(), std::numeric_limits<Int128>::min()>;
-using SparkFunctionCastFloatToInt256 =
local_engine::SparkFunctionCastFloatToInt<Int256, NameToInt256,
std::numeric_limits<Int256>::max(), std::numeric_limits<Int256>::min()>;
-using SparkFunctionCastFloatToUInt8 =
local_engine::SparkFunctionCastFloatToInt<UInt8, NameToUInt8, UINT8_MAX, 0>;
-using SparkFunctionCastFloatToUInt16 =
local_engine::SparkFunctionCastFloatToInt<UInt16, NameToUInt16, UINT16_MAX, 0>;
-using SparkFunctionCastFloatToUInt32 =
local_engine::SparkFunctionCastFloatToInt<UInt32, NameToUInt32, UINT32_MAX, 0>;
-using SparkFunctionCastFloatToUInt64 =
local_engine::SparkFunctionCastFloatToInt<UInt64, NameToUInt64, UINT64_MAX, 0>;
-using SparkFunctionCastFloatToUInt128 =
local_engine::SparkFunctionCastFloatToInt<UInt128, NameToUInt128,
std::numeric_limits<UInt128>::max(), 0>;
-using SparkFunctionCastFloatToUInt256 =
local_engine::SparkFunctionCastFloatToInt<UInt256, NameToUInt256,
std::numeric_limits<UInt256>::max(), 0>;
+using SparkFunctionCastFloatToInt8 =
local_engine::SparkFunctionCastFloatToInt<Int8, NameToInt8>;
+using SparkFunctionCastFloatToInt16 =
local_engine::SparkFunctionCastFloatToInt<Int16, NameToInt16>;
+using SparkFunctionCastFloatToInt32 =
local_engine::SparkFunctionCastFloatToInt<Int32, NameToInt32>;
+using SparkFunctionCastFloatToInt64 =
local_engine::SparkFunctionCastFloatToInt<Int64, NameToInt64>;
REGISTER_FUNCTION(SparkFunctionCastToInt)
{
@@ -58,13 +42,5 @@ REGISTER_FUNCTION(SparkFunctionCastToInt)
factory.registerFunction<SparkFunctionCastFloatToInt16>();
factory.registerFunction<SparkFunctionCastFloatToInt32>();
factory.registerFunction<SparkFunctionCastFloatToInt64>();
- factory.registerFunction<SparkFunctionCastFloatToInt128>();
- factory.registerFunction<SparkFunctionCastFloatToInt256>();
- factory.registerFunction<SparkFunctionCastFloatToUInt8>();
- factory.registerFunction<SparkFunctionCastFloatToUInt16>();
- factory.registerFunction<SparkFunctionCastFloatToUInt32>();
- factory.registerFunction<SparkFunctionCastFloatToUInt64>();
- factory.registerFunction<SparkFunctionCastFloatToUInt128>();
- factory.registerFunction<SparkFunctionCastFloatToUInt256>();
}
}
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToInt.h
b/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToInt.h
index abe66d536d..00614798ce 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToInt.h
+++ b/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToInt.h
@@ -40,8 +40,7 @@ namespace ErrorCodes
namespace local_engine
{
-/// TODO(taiyang-li): remove int_max_value and int_min_value for it is
determined by T
-template <is_integer T, typename Name, T int_max_value, T int_min_value>
+template <is_integer T, typename Name>
class SparkFunctionCastFloatToInt : public DB::IFunction
{
public:
@@ -61,19 +60,16 @@ public:
if (arguments.size() != 1)
throw
DB::Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {}'s
arguments number must be 1", name);
+ if (!isFloat(removeNullable(arguments[0])))
+ throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Function {}'s 1st argument must be float type", name);
+
return makeNullable(std::make_shared<const DB::DataTypeNumber<T>>());
}
- DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments,
const DB::DataTypePtr & result_type, size_t) const override
+ DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments,
const DB::DataTypePtr & result_type, size_t input_rows_count) const override
{
- if (arguments.size() != 1)
- throw
DB::Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {}'s
arguments number must be 1", name);
-
- if (!isFloat(removeNullable(arguments[0].type)))
- throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Function {}'s 1st argument must be float type", name);
-
DB::ColumnPtr src_col = arguments[0].column;
- size_t size = src_col->size();
+ size_t size = input_rows_count;
auto res_col = DB::ColumnVector<T>::create(size, 0);
auto null_map_col = DB::ColumnUInt8::create(size, 0);
@@ -94,23 +90,60 @@ public:
return DB::ColumnNullable::create(std::move(res_col),
std::move(null_map_col));
}
+ MULTITARGET_FUNCTION_AVX2_SSE42(
+ MULTITARGET_FUNCTION_HEADER(template <typename F> static void
NO_SANITIZE_UNDEFINED NO_INLINE),
+ vectorImpl,
+ MULTITARGET_FUNCTION_BODY(
+ (F int_min,
+ F int_max,
+ const DB::PaddedPODArray<F> & src_data,
+ DB::PaddedPODArray<T> & data,
+ DB::PaddedPODArray<UInt8> & null_map_data,
+ size_t rows) /// NOLINT
+ {
+ for (size_t i = 0; i < rows; ++i)
+ {
+ null_map_data[i] = !isFinite(src_data[i]);
+ data[i] = static_cast<T>(std::fmax(int_min,
std::fmin(int_max, src_data[i])));
+ }
+ }))
+
template <typename F>
- void executeInternal(const DB::ColumnPtr & src, DB::PaddedPODArray<T> &
data, DB::PaddedPODArray<UInt8> & null_map_data) const
+ static void NO_INLINE vector(
+ F int_min,
+ F int_max,
+ const DB::PaddedPODArray<F> & src_data,
+ DB::PaddedPODArray<T> & data,
+ DB::PaddedPODArray<UInt8> & null_map_data,
+ size_t rows)
{
- const DB::ColumnVector<F> * src_vec = assert_cast<const
DB::ColumnVector<F> *>(src.get());
- /// TODO(taiyang-li): try to vectorize below loop
- for (size_t i = 0; i < src_vec->size(); ++i)
+#if USE_MULTITARGET_CODE
+ if (isArchSupported(DB::TargetArch::AVX2))
{
- F element = src_vec->getElement(i);
- if (isNaN(element) || !isFinite(element))
- null_map_data[i] = 1;
- else if (element > int_max_value)
- data[i] = int_max_value;
- else if (element < int_min_value)
- data[i] = int_min_value;
- else
- data[i] = static_cast<T>(element);
+ vectorImplAVX2(int_min, int_max, src_data, data, null_map_data,
rows);
+ return;
}
+
+ if (isArchSupported(DB::TargetArch::SSE42))
+ {
+ vectorImplSSE42(int_min, int_max, src_data, data, null_map_data,
rows);
+ return;
+ }
+#endif
+
+ vectorImpl(int_min, int_max, src_data, data, null_map_data, rows);
+ }
+
+ template <typename F>
+ void executeInternal(const DB::ColumnPtr & src, DB::PaddedPODArray<T> &
data, DB::PaddedPODArray<UInt8> & null_map_data) const
+ {
+ const DB::ColumnVector<F> * src_vec = assert_cast<const
DB::ColumnVector<F> *>(src.get());
+
+ size_t rows = src_vec->size();
+ const auto & src_data = src_vec->getData();
+ const auto int_min = static_cast<F>(std::numeric_limits<T>::min());
+ const auto int_max = static_cast<F>(std::numeric_limits<T>::max());
+ vector(int_min, int_max, src_data, data, null_map_data, rows);
}
#if USE_EMBEDDED_COMPILER
@@ -125,6 +158,7 @@ public:
return true;
}
+
llvm::Value *
compileImpl(llvm::IRBuilderBase & builder, const DB::ValuesWithType &
arguments, const DB::DataTypePtr & result_type) const override
{
@@ -140,16 +174,7 @@ public:
b.CreateFCmpOEQ(src_value,
llvm::ConstantFP::getInfinity(float_type, true)));
bool is_signed = std::is_signed_v<T>;
- llvm::Value * max_value = llvm::ConstantInt::get(int_type,
static_cast<UInt64>(int_max_value), is_signed);
- llvm::Value * min_value = llvm::ConstantInt::get(int_type,
static_cast<UInt64>(int_min_value), is_signed);
- llvm::Value * clamped_value = b.CreateSelect(
- b.CreateFCmpOGT(src_value, llvm::ConstantFP::get(float_type,
static_cast<Float64>(int_max_value))),
- max_value,
- b.CreateSelect(
- b.CreateFCmpOLT(src_value, llvm::ConstantFP::get(float_type,
static_cast<Float64>(int_min_value))),
- min_value,
- is_signed_v<T> ? b.CreateFPToSI(src_value, int_type) :
b.CreateFPToUI(src_value, int_type)));
- llvm::Value * result_value = b.CreateSelect(b.CreateOr(is_nan,
is_inf), llvm::Constant::getNullValue(int_type), clamped_value);
+ llvm::Value * result_value = is_signed_v<T> ?
b.CreateFPToSI(src_value, int_type) : b.CreateFPToUI(src_value, int_type);
llvm::Value * result_is_null = b.CreateOr(is_nan, is_inf);
auto * nullable_structure_type = toNativeType(b, result_type);
diff --git
a/cpp-ch/local-engine/Functions/SparkFunctionCheckDecimalOverflow.cpp
b/cpp-ch/local-engine/Functions/SparkFunctionCheckDecimalOverflow.cpp
index 95853b187a..3b870b22d7 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionCheckDecimalOverflow.cpp
+++ b/cpp-ch/local-engine/Functions/SparkFunctionCheckDecimalOverflow.cpp
@@ -16,7 +16,6 @@
*/
#include "SparkFunctionCheckDecimalOverflow.h"
-#include <typeinfo>
#include <Columns/ColumnDecimal.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnsNumber.h>
@@ -27,6 +26,8 @@
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/IFunction.h>
+#include "Columns/ColumnsCommon.h"
+#include <iostream>
namespace DB
{
@@ -45,6 +46,8 @@ namespace local_engine
{
using namespace DB;
+namespace
+{
struct CheckDecimalOverflowSpark
{
static constexpr auto name = "checkDecimalOverflowSpark";
@@ -54,14 +57,19 @@ struct CheckDecimalOverflowSparkOrNull
static constexpr auto name = "checkDecimalOverflowSparkOrNull";
};
-enum class CheckExceptionMode
+enum class CheckExceptionMode: uint8_t
{
Throw, /// Throw exception if value cannot be parsed.
Null /// Return ColumnNullable with NULLs when value cannot be parsed.
};
-namespace
+enum class ScaleDirection: int8_t
{
+ Up = 1,
+ Down = -1,
+ None = 0
+};
+
/// Returns received decimal value if and Decimal value has less digits then
it's Precision allow, 0 otherwise.
/// Precision could be set as second argument or omitted. If omitted function
uses Decimal precision of the first argument.
template <typename Name, CheckExceptionMode mode>
@@ -102,193 +110,210 @@ public:
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const
DataTypePtr &, size_t input_rows_count) const override
{
- const auto & src_column = arguments[0];
- UInt32 precision = extractArgument(arguments[1]);
- UInt32 scale = extractArgument(arguments[2]);
+ UInt32 to_precision = extractArgument(arguments[1]);
+ UInt32 to_scale = extractArgument(arguments[2]);
- ColumnPtr result_column;
+ const auto & src_col = arguments[0];
+ ColumnPtr dst_col;
auto call = [&](const auto & types) -> bool
{
using Types = std::decay_t<decltype(types)>;
using FromDataType = typename Types::LeftType;
using ToDataType = typename Types::RightType;
+
if constexpr (IsDataTypeDecimal<FromDataType> ||
IsDataTypeNumber<FromDataType>)
{
using FromFieldType = typename FromDataType::FieldType;
- if (const ColumnVectorOrDecimal<FromFieldType> * col_vec =
checkAndGetColumn<ColumnVectorOrDecimal<FromFieldType>>(src_column.column.get()))
+
+ /// Fast path
+ if constexpr (IsDataTypeDecimal<FromDataType>)
+ {
+ auto from_precision = getDecimalPrecision(*src_col.type);
+ auto from_scale = getDecimalScale(*src_col.type);
+ if (from_precision == to_precision && from_scale ==
to_scale)
+ {
+ dst_col = src_col.column;
+ return true;
+ }
+ }
+
+ if (const ColumnVectorOrDecimal<FromFieldType> * col_vec =
checkAndGetColumn<ColumnVectorOrDecimal<FromFieldType>>(src_col.column.get()))
{
- executeInternal<FromDataType, ToDataType>(*col_vec,
result_column, input_rows_count, precision, scale);
+ executeInternal<FromDataType, ToDataType>(*col_vec,
dst_col, input_rows_count, to_precision, to_scale);
return true;
}
}
+
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal
column while execute function {}", getName());
};
- if (precision <= DecimalUtils::max_precision<Decimal32>)
-
callOnIndexAndDataType<DataTypeDecimal<Decimal32>>(src_column.type->getTypeId(),
call);
- else if (precision <= DecimalUtils::max_precision<Decimal64>)
-
callOnIndexAndDataType<DataTypeDecimal<Decimal64>>(src_column.type->getTypeId(),
call);
- else if (precision <= DecimalUtils::max_precision<Decimal128>)
-
callOnIndexAndDataType<DataTypeDecimal<Decimal128>>(src_column.type->getTypeId(),
call);
+ if (to_precision <= DecimalUtils::max_precision<Decimal32>)
+
callOnIndexAndDataType<DataTypeDecimal<Decimal32>>(src_col.type->getTypeId(),
call);
+ else if (to_precision <= DecimalUtils::max_precision<Decimal64>)
+
callOnIndexAndDataType<DataTypeDecimal<Decimal64>>(src_col.type->getTypeId(),
call);
+ else if (to_precision <= DecimalUtils::max_precision<Decimal128>)
+
callOnIndexAndDataType<DataTypeDecimal<Decimal128>>(src_col.type->getTypeId(),
call);
else
-
callOnIndexAndDataType<DataTypeDecimal<Decimal256>>(src_column.type->getTypeId(),
call);
-
+
callOnIndexAndDataType<DataTypeDecimal<Decimal256>>(src_col.type->getTypeId(),
call);
- if (!result_column)
- throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Wrong call for {}
with {}", getName(), src_column.type->getName());
+ if (!dst_col)
+ throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Wrong call for {}
with {}", getName(), src_col.type->getName());
- return result_column;
+ return dst_col;
}
private:
template <typename FromDataType, typename ToDataType>
requires(IsDataTypeDecimal<ToDataType> && (IsDataTypeDecimal<FromDataType>
|| IsDataTypeNumber<FromDataType>))
- static void executeInternal(
- const ColumnVectorOrDecimal<typename FromDataType::FieldType> &
col_source, ColumnPtr & result_column, size_t input_rows_count, UInt32
precision, UInt32 scale_to)
+ static void
+ executeInternal(const FromDataType::ColumnType & src_col, ColumnPtr &
dst_col, size_t rows, UInt32 to_precision, UInt32 to_scale)
{
using ToFieldType = typename ToDataType::FieldType;
+ using ToNativeType = typename ToFieldType::NativeType;
using ToColumnType = typename ToDataType::ColumnType;
- using T = typename FromDataType::FieldType;
-
- ColumnUInt8::MutablePtr col_null_map_to;
- ColumnUInt8::Container * vec_null_map_to [[maybe_unused]] = nullptr;
- UInt32 scale_from = 0;
- using ToFieldNativeType = typename ToFieldType::NativeType;
- ToFieldNativeType decimal_int_part_max = 0;
- ToFieldNativeType decimal_int_part_min = 0;
- if constexpr (IsDataTypeDecimal<FromDataType>)
- scale_from = col_source.getScale();
- else
- {
- decimal_int_part_max =
DecimalUtils::scaleMultiplier<ToFieldNativeType>(precision - scale_to) - 1;
- decimal_int_part_min = 1 -
DecimalUtils::scaleMultiplier<ToFieldNativeType>(precision - scale_to);
- }
- if constexpr (exception_mode == CheckExceptionMode::Null)
- {
- col_null_map_to = ColumnUInt8::create(input_rows_count, false);
- vec_null_map_to = &col_null_map_to->getData();
- }
+ using FromFieldType = typename FromDataType::FieldType;
- typename ToColumnType::MutablePtr col_to =
ToColumnType::create(input_rows_count, scale_to);
- auto & vec_to = col_to->getData();
- vec_to.resize_exact(input_rows_count);
+ using MaxFieldType = std::conditional_t<
+ is_decimal<FromFieldType>,
+ std::conditional_t<(sizeof(FromFieldType) > sizeof(ToFieldType)),
FromFieldType, ToFieldType>,
+ ToFieldType>;
+ using MaxNativeType = typename MaxFieldType::NativeType;
- auto & datas = col_source.getData();
- for (size_t i = 0; i < input_rows_count; ++i)
+ /// Calculate const parameters for decimal conversion outside the loop
to avoid unnecessary calculations.
+ ScaleDirection scale_direction;
+ UInt32 from_scale = 0;
+ MaxNativeType scale_multiplier = 0;
+ MaxNativeType pow10_to_precision =
DecimalUtils::scaleMultiplier<MaxNativeType>(to_precision);
+ if constexpr (IsDataTypeDecimal<FromDataType>)
{
- ToFieldType result;
- bool success = convertToDecimalImpl<FromDataType,
ToDataType>(datas[i], precision, scale_from, scale_to, decimal_int_part_max,
decimal_int_part_min, result);
- if constexpr (exception_mode == CheckExceptionMode::Null)
+ from_scale = src_col.getScale();
+ if (to_scale > from_scale)
{
- vec_to[i] = static_cast<ToFieldType>(result);
- (*vec_null_map_to)[i] = !success;
+ scale_direction = ScaleDirection::Up;
+ scale_multiplier =
DecimalUtils::scaleMultiplier<MaxNativeType>(to_scale - from_scale);
+ }
+ else if (to_scale < from_scale)
+ {
+ scale_direction = ScaleDirection::Down;
+ scale_multiplier =
DecimalUtils::scaleMultiplier<MaxNativeType>(from_scale - to_scale);
}
else
{
- if (success)
- vec_to[i] = static_cast<ToFieldType>(result);
- else
- throw Exception(ErrorCodes::DECIMAL_OVERFLOW, "Decimal
value is overflow.");
+ scale_direction = ScaleDirection::None;
+ scale_multiplier = 1;
}
}
+ else
+ {
+ scale_multiplier =
DecimalUtils::scaleMultiplier<MaxNativeType>(to_scale);
+ }
- if constexpr (exception_mode == CheckExceptionMode::Null)
- result_column = ColumnNullable::create(std::move(col_to),
std::move(col_null_map_to));
+ auto & src_data = src_col.getData();
+
+ auto res_data_col = ToColumnType::create(rows, to_scale);
+ auto & res_data = res_data_col->getData();
+ auto res_nullmap_col = ColumnUInt8::create(rows, 0);
+ auto & res_nullmap_data = res_nullmap_col->getData();
+
+ if constexpr (IsDataTypeDecimal<FromDataType>)
+ {
+ if (scale_direction == ScaleDirection::Up)
+ for (size_t i = 0; i < rows; ++i)
+ res_nullmap_data[i] =
!convertDecimalToDecimalImpl<ScaleDirection::Up, FromDataType, ToDataType,
MaxNativeType>(
+ src_data[i], scale_multiplier, pow10_to_precision,
res_data[i]);
+ else if (scale_direction == ScaleDirection::Down)
+ for (size_t i = 0; i < rows; ++i)
+ res_nullmap_data[i] =
!convertDecimalToDecimalImpl<ScaleDirection::Down, FromDataType, ToDataType,
MaxNativeType>(
+ src_data[i], scale_multiplier, pow10_to_precision,
res_data[i]);
+ else
+ for (size_t i = 0; i < rows; ++i)
+ res_nullmap_data[i] =
!convertDecimalToDecimalImpl<ScaleDirection::None, FromDataType, ToDataType,
MaxNativeType>(
+ src_data[i], scale_multiplier, pow10_to_precision,
res_data[i]);
+ }
else
- result_column = std::move(col_to);
- }
+ {
+ for (size_t i = 0; i < rows; ++i)
+ res_nullmap_data[i]
+ = !convertNumberToDecimalImpl<FromDataType,
ToDataType>(src_data[i], scale_multiplier, pow10_to_precision, res_data[i]);
+ }
- template <typename FromDataType, typename ToDataType>
- requires(IsDataTypeDecimal<ToDataType>)
- static bool convertToDecimalImpl(
- const FromDataType::FieldType & value,
- UInt32 precision_to,
- UInt32 scale_from,
- UInt32 scale_to,
- typename ToDataType::FieldType::NativeType decimal_int_part_max,
- typename ToDataType::FieldType::NativeType decimal_int_part_min,
- typename ToDataType::FieldType & result)
- {
- using FromFieldType = typename FromDataType::FieldType;
- if constexpr (std::is_same_v<FromFieldType, Decimal32>)
- return convertDecimalsImpl<DataTypeDecimal<Decimal32>,
ToDataType>(value, precision_to, scale_from, scale_to, result);
- else if constexpr (std::is_same_v<FromFieldType, Decimal64>)
- return convertDecimalsImpl<DataTypeDecimal<Decimal64>,
ToDataType>(value, precision_to, scale_from, scale_to, result);
- else if constexpr (std::is_same_v<FromFieldType, Decimal128>)
- return convertDecimalsImpl<DataTypeDecimal<Decimal128>,
ToDataType>(value, precision_to, scale_from, scale_to, result);
- else if constexpr (std::is_same_v<FromFieldType, Decimal256>)
- return convertDecimalsImpl<DataTypeDecimal<Decimal256>,
ToDataType>(value, precision_to, scale_from, scale_to, result);
- else if constexpr (IsDataTypeNumber<FromDataType> &&
!std::is_same_v<FromFieldType, BFloat16>)
- return convertNumberToDecimalImpl<DataTypeNumber<FromFieldType>,
ToDataType>(value, scale_to, decimal_int_part_max, decimal_int_part_min,
result);
+ if constexpr (exception_mode == CheckExceptionMode::Throw)
+ {
+ if (!memoryIsZero(res_nullmap_data.data(), 0, rows))
+ throw Exception(ErrorCodes::DECIMAL_OVERFLOW, "Decimal value
is overflow.");
+
+ dst_col = std::move(res_data_col);
+ }
else
- throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Convert from {} type
to decimal type is not implemented.", typeid(value).name());
+ dst_col = ColumnNullable::create(std::move(res_data_col),
std::move(res_nullmap_col));
}
template <typename FromDataType, typename ToDataType>
requires(IsDataTypeNumber<FromDataType> && IsDataTypeDecimal<ToDataType>)
- static inline bool convertNumberToDecimalImpl(
- const typename FromDataType::FieldType & value,
- UInt32 scale,
- typename ToDataType::FieldType::NativeType decimal_int_part_max,
- typename ToDataType::FieldType::NativeType decimal_int_part_min,
- typename ToDataType::FieldType & result)
+ static ALWAYS_INLINE bool convertNumberToDecimalImpl(
+ const typename FromDataType::FieldType & from,
+ const typename ToDataType::FieldType::NativeType & scale_multiplier,
+ const typename ToDataType::FieldType::NativeType & pow10_to_precision,
+ typename ToDataType::FieldType & to)
{
using FromFieldType = typename FromDataType::FieldType;
- using ToFieldNativeType = typename ToDataType::FieldType::NativeType;
- ToFieldNativeType int_part = 0;
- if constexpr (std::is_same_v<FromFieldType, Float32> ||
std::is_same_v<FromFieldType, Float64>)
- int_part = static_cast<ToFieldNativeType>(value);
+ using ToNativeType = typename ToDataType::FieldType::NativeType;
+
+ bool ok = false;
+ if constexpr (std::is_floating_point_v<FromFieldType>)
+ {
+ /// float to decimal
+ auto converted = from *
static_cast<FromFieldType>(scale_multiplier);
+ auto float_pow10_to_precision =
static_cast<FromFieldType>(pow10_to_precision);
+ ok = isFinite(from) && converted < float_pow10_to_precision &&
converted > -float_pow10_to_precision;
+ to = ok ? static_cast<ToNativeType>(converted) :
static_cast<ToNativeType>(0);
+ }
else
- int_part = value;
+ {
+ /// signed integer to decimal
+ using MaxNativeType = std::conditional_t<(sizeof(FromFieldType) >
sizeof(ToNativeType)), FromFieldType, ToNativeType>;
- return int_part >= decimal_int_part_min && int_part <=
decimal_int_part_max && tryConvertToDecimal<FromDataType, ToDataType>(value,
scale, result);
+ MaxNativeType converted = 0;
+ ok = !common::mulOverflow(static_cast<MaxNativeType>(from),
static_cast<MaxNativeType>(scale_multiplier), converted) && converted <
pow10_to_precision
+ && converted > -pow10_to_precision;
+ to = ok ? static_cast<ToNativeType>(converted) :
static_cast<ToNativeType>(0);
+ }
+ return ok;
}
- template <typename FromDataType, typename ToDataType>
+ template <ScaleDirection scale_direction, typename FromDataType, typename
ToDataType, typename MaxNativeType>
requires(IsDataTypeDecimal<FromDataType> && IsDataTypeDecimal<ToDataType>)
- static bool convertDecimalsImpl(
- const typename FromDataType::FieldType & value,
- UInt32 precision_to,
- UInt32 scale_from,
- UInt32 scale_to,
- typename ToDataType::FieldType & result)
+ static ALWAYS_INLINE bool convertDecimalToDecimalImpl(
+ const typename FromDataType::FieldType & from,
+ const MaxNativeType & scale_multiplier,
+ const MaxNativeType & pow10_to_precision,
+ typename ToDataType::FieldType & to)
{
using FromFieldType = typename FromDataType::FieldType;
using ToFieldType = typename ToDataType::FieldType;
- using MaxFieldType = std::conditional_t<(sizeof(FromFieldType) >
sizeof(ToFieldType)), FromFieldType, ToFieldType>;
- using MaxNativeType = typename MaxFieldType::NativeType;
+ using ToNativeType = typename ToFieldType::NativeType;
-
- auto false_value = []() -> bool
+ MaxNativeType converted;
+ bool ok = false;
+ if constexpr (scale_direction == ScaleDirection::Up)
{
- if constexpr (exception_mode == CheckExceptionMode::Null)
- return false;
- else
- throw Exception(ErrorCodes::DECIMAL_OVERFLOW, "Decimal value
is overflow.");
- };
-
- MaxNativeType converted_value;
- if (scale_to > scale_from)
+ ok = !common::mulOverflow(static_cast<MaxNativeType>(from.value),
scale_multiplier, converted)
+ && converted < pow10_to_precision && converted >
-pow10_to_precision;
+ }
+ else if constexpr (scale_direction == ScaleDirection::None)
{
- converted_value =
DecimalUtils::scaleMultiplier<MaxNativeType>(scale_to - scale_from);
- if (common::mulOverflow(static_cast<MaxNativeType>(value.value),
converted_value, converted_value))
- return false_value();
+ converted = from.value;
+ ok = converted < pow10_to_precision && converted >
-pow10_to_precision;
}
- else if (scale_to == scale_from)
- converted_value = value.value;
else
- converted_value = value.value /
DecimalUtils::scaleMultiplier<MaxNativeType>(scale_from - scale_to);
-
- // if constexpr (sizeof(FromFieldType) > sizeof(ToFieldType))
- // {
- MaxNativeType pow10 = intExp10OfSize<MaxNativeType>(precision_to);
- if (converted_value <= -pow10 || converted_value >= pow10)
- return false_value();
- // }
+ {
+ converted = from.value / scale_multiplier;
+ ok = converted < pow10_to_precision && converted >
-pow10_to_precision;
+ }
- result = static_cast<typename
ToFieldType::NativeType>(converted_value);
- return true;
+ to = ok ? static_cast<ToNativeType>(converted) :
static_cast<ToNativeType>(0);
+ return ok;
}
};
diff --git
a/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.h
b/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.h
index beef342e08..eb2c4df4f4 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.h
+++ b/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.h
@@ -422,7 +422,6 @@ private:
}
};
-/// TODO(taiyang-li): implement JIT for binary deicmal arithmetic functions
template <class Operation, typename Name, OpMode mode = OpMode::Default>
class SparkFunctionDecimalBinaryArithmetic final : public IFunction
{
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionDivide.h
b/cpp-ch/local-engine/Functions/SparkFunctionDivide.h
index ccd9eeb61b..f924d559a5 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionDivide.h
+++ b/cpp-ch/local-engine/Functions/SparkFunctionDivide.h
@@ -81,91 +81,117 @@ public:
DB::ColumnPtr
executeImpl(const DB::ColumnsWithTypeAndName & arguments, const
DB::DataTypePtr & result_type, size_t input_rows_count) const override
{
- if (arguments.size() != 2)
- throw
DB::Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {}'s
arguments number must be 2", name);
+ using L = Float64;
+ using R = Float64;
+ using T = Float64;
+
+ const DB::ColumnVector<L> * col_left = nullptr;
+ const DB::ColumnVector<R> * col_right = nullptr;
+ const DB::ColumnVector<L> * const_col_left =
checkAndGetColumnConstData<DB::ColumnVector<L>>(arguments[0].column.get());
+ const DB::ColumnVector<R> * const_col_right =
checkAndGetColumnConstData<DB::ColumnVector<R>>(arguments[1].column.get());
+
+ L left_const_val = 0;
+ if (const_col_left)
+ left_const_val = const_col_left->getElement(0);
+ else
+ col_left = assert_cast<const DB::ColumnVector<L>
*>(arguments[0].column.get());
+
+ R right_const_val = 0;
+ if (const_col_right)
+ {
+ right_const_val = const_col_right->getElement(0);
+ if (right_const_val == 0)
+ {
+ auto data_col = DB::ColumnVector<T>::create(1, 0);
+ auto null_map_col = DB::ColumnVector<UInt8>::create(1, 1);
+ return
DB::ColumnConst::create(DB::ColumnNullable::create(std::move(data_col),
std::move(null_map_col)), input_rows_count);
+ }
+ }
+ else
+ col_right = assert_cast<const DB::ColumnVector<R>
*>(arguments[1].column.get());
+
+ auto res_col = DB::ColumnVector<T>::create(input_rows_count, 0);
+ auto res_null_map = DB::ColumnVector<UInt8>::create(input_rows_count,
0);
+ DB::PaddedPODArray<T> & res_data = res_col->getData();
+ DB::PaddedPODArray<UInt8> & res_null_map_data =
res_null_map->getData();
+ vector(col_left, col_right, left_const_val, right_const_val, res_data,
res_null_map_data, input_rows_count);
+ return DB::ColumnNullable::create(std::move(res_col),
std::move(res_null_map));
+ }
- if (!isNativeNumber(arguments[0].type) ||
!isNativeNumber(arguments[1].type))
- throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Function {}'s arguments type must be native number", name);
-
- using Types = TypeList<
- DB::DataTypeFloat32,
- DB::DataTypeFloat64,
- DB::DataTypeUInt8,
- DB::DataTypeUInt16,
- DB::DataTypeUInt32,
- DB::DataTypeUInt64,
- DB::DataTypeInt8,
- DB::DataTypeInt16,
- DB::DataTypeInt32,
- DB::DataTypeInt64>;
-
- DB::ColumnPtr result = nullptr;
- bool valid = castTypeToEither(
- Types{},
- arguments[0].type.get(),
- [&](const auto & left_)
+ MULTITARGET_FUNCTION_AVX2_SSE42(
+ MULTITARGET_FUNCTION_HEADER(static void NO_SANITIZE_UNDEFINED
NO_INLINE),
+ vectorImpl,
+ MULTITARGET_FUNCTION_BODY(
+ (const DB::ColumnVector<Float64> * col_left,
+ const DB::ColumnVector<Float64> * col_right,
+ Float64 left_const_val,
+ Float64 right_const_val,
+ DB::PaddedPODArray<Float64> & res_data,
+ DB::PaddedPODArray<UInt8> & res_null_map_data,
+ size_t input_rows_count) /// NOLINT
{
- return castTypeToEither(
- Types{},
- arguments[1].type.get(),
- [&](const auto & right_)
+ if (col_left && col_right)
+ {
+ const auto & ldata = col_left->getData();
+ const auto & rdata = col_right->getData();
+
+ for (size_t i = 0; i < input_rows_count; ++i)
+ {
+ auto l = ldata[i];
+ auto r = rdata[i];
+ res_data[i] = SparkDivideFloatingImpl<Float64,
Float64>::apply(l, r ? r : 1);
+ res_null_map_data[i] = !rdata[i];
+ }
+ }
+ else if (col_left)
+ {
+ Float64 r = right_const_val;
+ for (size_t i = 0; i < input_rows_count; ++i)
{
- using L = typename
std::decay_t<decltype(left_)>::FieldType;
- using R = typename
std::decay_t<decltype(right_)>::FieldType;
- using T = typename
DB::NumberTraits::ResultOfFloatingPointDivision<L, R>::Type;
-
- const DB::ColumnVector<L> * col_left = nullptr;
- const DB::ColumnVector<R> * col_right = nullptr;
- const DB::ColumnVector<L> * const_col_left =
checkAndGetColumnConstData<DB::ColumnVector<L>>(arguments[0].column.get());
- const DB::ColumnVector<R> * const_col_right
- =
checkAndGetColumnConstData<DB::ColumnVector<R>>(arguments[1].column.get());
-
- L left_const_val = 0;
- if (const_col_left)
- left_const_val = const_col_left->getElement(0);
- else
- col_left = assert_cast<const DB::ColumnVector<L>
*>(arguments[0].column.get());
-
- R right_const_val = 0;
- if (const_col_right)
- {
- right_const_val = const_col_right->getElement(0);
- if (right_const_val == 0)
- {
- /// TODO(taiyang-li): return const column
instead
- auto data_col =
DB::ColumnVector<T>::create(arguments[0].column->size(), 0);
- auto null_map_col =
DB::ColumnVector<UInt8>::create(arguments[0].column->size(), 1);
- result =
DB::ColumnNullable::create(std::move(data_col), std::move(null_map_col));
- return true;
- }
- }
- else
- col_right = assert_cast<const DB::ColumnVector<R>
*>(arguments[1].column.get());
-
- auto res_values =
DB::ColumnVector<T>::create(input_rows_count, 0);
- auto res_null_map =
DB::ColumnVector<UInt8>::create(input_rows_count, 0);
- DB::PaddedPODArray<T> & res_data =
res_values->getData();
- DB::PaddedPODArray<UInt8> & res_null_map_data =
res_null_map->getData();
- for (size_t i = 0; i < input_rows_count; ++i)
- {
- L l = col_left ? col_left->getElement(i) :
left_const_val;
- R r = col_right ? col_right->getElement(i) :
right_const_val;
-
- /// TODO(taiyang-li): try to vectorize it
- if (r == 0)
- res_null_map_data[i] = 1;
- else
- res_data[i] = SparkDivideFloatingImpl<L,
R>::apply(l, r);
- }
-
- result =
DB::ColumnNullable::create(std::move(res_values), std::move(res_null_map));
- return true;
- });
- });
-
- if (!valid)
- throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Function {}'s arguments type is not valid", name);
- return result;
+ Float64 l = col_left->getData()[i];
+
+ /// r must not be zero because r = 0 is already
processed in fast path
+ /// No need to assign null_map_data[i] = 0, because it
is already 0
+ // res_null_map_data[i] = 0;
+ res_data[i] = SparkDivideFloatingImpl<Float64,
Float64>::apply(l, r);
+ }
+ }
+ else if (col_right)
+ {
+ Float64 l = left_const_val;
+ for (size_t i = 0; i < input_rows_count; ++i)
+ {
+ Float64 r = col_right->getData()[i];
+ res_null_map_data[i] = !r;
+ res_data[i] = SparkDivideFloatingImpl<Float64,
Float64>::apply(l, r ? r : 1);
+ }
+ }
+ }))
+
+ static void NO_INLINE vector(
+ const DB::ColumnVector<Float64> * col_left,
+ const DB::ColumnVector<Float64> * col_right,
+ Float64 left_const_val,
+ Float64 right_const_val,
+ DB::PaddedPODArray<Float64> & res_data,
+ DB::PaddedPODArray<UInt8> & res_null_map_data,
+ size_t input_rows_count)
+ {
+#if USE_MULTITARGET_CODE
+ if (isArchSupported(DB::TargetArch::AVX2))
+ {
+ vectorImplAVX2(col_left, col_right, left_const_val,
right_const_val, res_data, res_null_map_data, input_rows_count);
+ return;
+ }
+
+ if (isArchSupported(DB::TargetArch::SSE42))
+ {
+ vectorImplSSE42(col_left, col_right, left_const_val,
right_const_val, res_data, res_null_map_data, input_rows_count);
+ return;
+ }
+#endif
+
+ vectorImpl(col_left, col_right, left_const_val, right_const_val,
res_data, res_null_map_data, input_rows_count);
}
#if USE_EMBEDDED_COMPILER
diff --git a/cpp-ch/local-engine/Parser/ExpressionParser.cpp
b/cpp-ch/local-engine/Parser/ExpressionParser.cpp
index 5c6cfd0316..99bc5b3ff4 100644
--- a/cpp-ch/local-engine/Parser/ExpressionParser.cpp
+++ b/cpp-ch/local-engine/Parser/ExpressionParser.cpp
@@ -337,11 +337,12 @@ ExpressionParser::NodeRawConstPtr
ExpressionParser::parseExpression(ActionsDAG &
}
else if ((isDecimal(denull_input_type) ||
isNativeNumber(denull_input_type)) && substrait_type.has_decimal())
{
- int decimal_precision = substrait_type.decimal().precision();
- if (decimal_precision)
+ int precision = substrait_type.decimal().precision();
+ int scale = substrait_type.decimal().scale();
+ if (precision)
{
- args.emplace_back(addConstColumn(actions_dag,
std::make_shared<DataTypeInt32>(), decimal_precision));
- args.emplace_back(addConstColumn(actions_dag,
std::make_shared<DataTypeInt32>(), substrait_type.decimal().scale()));
+ args.emplace_back(addConstColumn(actions_dag,
std::make_shared<DataTypeInt32>(), precision));
+ args.emplace_back(addConstColumn(actions_dag,
std::make_shared<DataTypeInt32>(), scale));
result_node = toFunctionNode(actions_dag,
"checkDecimalOverflowSparkOrNull", args);
}
}
diff --git a/cpp-ch/local-engine/tests/CMakeLists.txt
b/cpp-ch/local-engine/tests/CMakeLists.txt
index 4617714b5e..56bf07aa06 100644
--- a/cpp-ch/local-engine/tests/CMakeLists.txt
+++ b/cpp-ch/local-engine/tests/CMakeLists.txt
@@ -97,7 +97,7 @@ if(ENABLE_BENCHMARKS)
benchmark_parquet_read.cpp
benchmark_spark_row.cpp
benchmark_unix_timestamp_function.cpp
- benchmark_spark_floor_function.cpp
+ benchmark_spark_functions.cpp
benchmark_cast_float_function.cpp
benchmark_to_datetime_function.cpp
benchmark_spark_divide_function.cpp
diff --git a/cpp-ch/local-engine/tests/benchmark_spark_floor_function.cpp
b/cpp-ch/local-engine/tests/benchmark_spark_functions.cpp
similarity index 82%
rename from cpp-ch/local-engine/tests/benchmark_spark_floor_function.cpp
rename to cpp-ch/local-engine/tests/benchmark_spark_functions.cpp
index 95da98d322..ce6f1b42ca 100644
--- a/cpp-ch/local-engine/tests/benchmark_spark_floor_function.cpp
+++ b/cpp-ch/local-engine/tests/benchmark_spark_functions.cpp
@@ -18,10 +18,7 @@
#if defined(__x86_64__)
#include <cstddef>
-#if USE_MULTITARGET_CODE
-#include <immintrin.h>
-#endif
-
+#include <Columns/ColumnsCommon.h>
#include <Columns/IColumn.h>
#include <Core/Block.h>
#include <DataTypes/DataTypeArray.h>
@@ -35,7 +32,11 @@
#include <benchmark/benchmark.h>
#include <Common/QueryContext.h>
#include <Common/TargetSpecific.h>
-#include <Columns/ColumnsCommon.h>
+#include <DataTypes/DataTypeNullable.h>
+
+#if USE_MULTITARGET_CODE
+#include <immintrin.h>
+#endif
using namespace DB;
@@ -44,7 +45,7 @@ static IColumn::Offsets createOffsets(size_t rows)
IColumn::Offsets offsets(rows, 0);
for (size_t i = 0; i < rows; ++i)
offsets[i] = offsets[i-1] + (rand() % 10);
- return std::move(offsets);
+ return offsets;
}
static ColumnPtr createColumn(const DataTypePtr & type, size_t rows)
@@ -79,6 +80,11 @@ static ColumnPtr createColumn(const DataTypePtr & type,
size_t rows)
double d = i * 1.0;
column->insert(d);
}
+ else if (isDecimal(type_not_nullable))
+ {
+ Decimal128 d = Decimal128(i * i);
+ column->insert(d);
+ }
else if (isString(type_not_nullable))
{
String s = "helloworld";
@@ -158,6 +164,212 @@ static void
BM_SparkFloorFunction_For_Float64(benchmark::State & state)
}
}
+BENCHMARK(BM_CHFloorFunction_For_Int64);
+BENCHMARK(BM_CHFloorFunction_For_Float64);
+BENCHMARK(BM_SparkFloorFunction_For_Int64);
+BENCHMARK(BM_SparkFloorFunction_For_Float64);
+
+static void BM_OptSparkDivide_VectorVector(benchmark::State & state)
+{
+ using namespace DB;
+ auto & factory = FunctionFactory::instance();
+ auto function = factory.get("sparkDivide",
local_engine::QueryContext::globalContext());
+ auto type = DataTypeFactory::instance().get("Nullable(Float64)");
+ auto left = createColumn(type, 65536);
+ auto right = createColumn(type, 65536);
+ auto block = Block({ColumnWithTypeAndName(left, type, "left"),
ColumnWithTypeAndName(right, type, "right")});
+ auto executable = function->build(block.getColumnsWithTypeAndName());
+ for (auto _ : state)
+ {
+ auto result = executable->execute(block.getColumnsWithTypeAndName(),
executable->getResultType(), block.rows(), false);
+ benchmark::DoNotOptimize(result);
+ }
+}
+
+static void BM_OptSparkDivide_VectorConstant(benchmark::State & state)
+{
+ using namespace DB;
+ auto & factory = FunctionFactory::instance();
+ auto function = factory.get("sparkDivide",
local_engine::QueryContext::globalContext());
+ auto type = DataTypeFactory::instance().get("Nullable(Float64)");
+ auto left = createColumn(type, 65536);
+ auto right = createColumn(type, 1);
+ auto const_right = ColumnConst::create(std::move(right), 65536);
+ auto block = Block({ColumnWithTypeAndName(left, type, "left"),
ColumnWithTypeAndName(std::move(const_right), type, "right")});
+ auto executable = function->build(block.getColumnsWithTypeAndName());
+ for (auto _ : state)
+ {
+ auto result = executable->execute(block.getColumnsWithTypeAndName(),
executable->getResultType(), block.rows(), false);
+ benchmark::DoNotOptimize(result);
+ }
+}
+
+static void BM_OptSparkDivide_ConstantVector(benchmark::State & state)
+{
+ using namespace DB;
+ auto & factory = FunctionFactory::instance();
+ auto function = factory.get("sparkDivide",
local_engine::QueryContext::globalContext());
+ auto type = DataTypeFactory::instance().get("Nullable(Float64)");
+ auto left = createColumn(type, 1);
+ auto const_left = ColumnConst::create(std::move(left), 65536);
+ auto right = createColumn(type, 65536);
+ auto block = Block({ColumnWithTypeAndName(std::move(const_left), type,
"left"), ColumnWithTypeAndName(std::move(right), type, "right")});
+ auto executable = function->build(block.getColumnsWithTypeAndName());
+ for (auto _ : state)
+ {
+ auto result = executable->execute(block.getColumnsWithTypeAndName(),
executable->getResultType(), block.rows(), false);
+ benchmark::DoNotOptimize(result);
+ }
+}
+
+BENCHMARK(BM_OptSparkDivide_VectorVector);
+BENCHMARK(BM_OptSparkDivide_VectorConstant);
+BENCHMARK(BM_OptSparkDivide_ConstantVector);
+
+static void BM_OptSparkCastFloatToInt(benchmark::State & state)
+{
+ using namespace DB;
+ auto & factory = FunctionFactory::instance();
+ auto function = factory.get("sparkCastFloatToInt32",
local_engine::QueryContext::globalContext());
+ auto type = DataTypeFactory::instance().get("Nullable(Float64)");
+ auto input = createColumn(type, 65536);
+ auto block = Block({ColumnWithTypeAndName(std::move(input), type,
"input")});
+ auto executable = function->build(block.getColumnsWithTypeAndName());
+ for (auto _ : state)
+ {
+ auto result = executable->execute(block.getColumnsWithTypeAndName(),
executable->getResultType(), block.rows(), false);
+ benchmark::DoNotOptimize(result);
+ }
+}
+
+BENCHMARK(BM_OptSparkCastFloatToInt);
+
+/// decimal to decimal, scale up
+static void BM_OptCheckDecimalOverflowSparkFromDecimal1(benchmark::State &
state)
+{
+ using namespace DB;
+ auto & factory = FunctionFactory::instance();
+ auto function = factory.get("checkDecimalOverflowSparkOrNull",
local_engine::QueryContext::globalContext());
+ auto type = DataTypeFactory::instance().get("Nullable(Decimal128(10))");
+
+ auto input = createColumn(type, 65536);
+ auto precision = ColumnConst::create(ColumnUInt32::create(1, 38), 65536);
+ auto scale = ColumnConst::create(ColumnUInt32::create(1, 5), 65536);
+
+ auto block = Block(
+ {ColumnWithTypeAndName(std::move(input), type, "input"),
+ ColumnWithTypeAndName(std::move(precision),
std::make_shared<DataTypeUInt32>(), "precision"),
+ ColumnWithTypeAndName(std::move(scale),
std::make_shared<DataTypeUInt32>(), "scale")});
+ auto executable = function->build(block.getColumnsWithTypeAndName());
+ for (auto _ : state)
+ {
+ auto result = executable->execute(block.getColumnsWithTypeAndName(),
executable->getResultType(), block.rows(), false);
+ benchmark::DoNotOptimize(result);
+ }
+}
+
+/// decimal to decimal, scale down
+static void BM_OptCheckDecimalOverflowSparkFromDecimal2(benchmark::State &
state)
+{
+ using namespace DB;
+ auto & factory = FunctionFactory::instance();
+ auto function = factory.get("checkDecimalOverflowSparkOrNull",
local_engine::QueryContext::globalContext());
+ auto type = DataTypeFactory::instance().get("Nullable(Decimal128(10))");
+
+ auto input = createColumn(type, 65536);
+ auto precision = ColumnConst::create(ColumnUInt32::create(1, 38), 65536);
+ auto scale = ColumnConst::create(ColumnUInt32::create(1, 15), 65536);
+
+ auto block = Block(
+ {ColumnWithTypeAndName(std::move(input), type, "input"),
+ ColumnWithTypeAndName(std::move(precision),
std::make_shared<DataTypeUInt32>(), "precision"),
+ ColumnWithTypeAndName(std::move(scale),
std::make_shared<DataTypeUInt32>(), "scale")});
+ auto executable = function->build(block.getColumnsWithTypeAndName());
+ for (auto _ : state)
+ {
+ auto result = executable->execute(block.getColumnsWithTypeAndName(),
executable->getResultType(), block.rows(), false);
+ benchmark::DoNotOptimize(result);
+ }
+}
+
+/// decimal to decimal, scale doesn't change
+static void BM_OptCheckDecimalOverflowSparkFromDecimal3(benchmark::State &
state)
+{
+ using namespace DB;
+ auto & factory = FunctionFactory::instance();
+ auto function = factory.get("checkDecimalOverflowSparkOrNull",
local_engine::QueryContext::globalContext());
+ auto type = DataTypeFactory::instance().get("Nullable(Decimal(38, 10))");
+
+ auto input = createColumn(type, 65536);
+ auto precision = ColumnConst::create(ColumnUInt32::create(1, 38), 65536);
+ auto scale = ColumnConst::create(ColumnUInt32::create(1, 10), 65536);
+
+ auto block = Block(
+ {ColumnWithTypeAndName(std::move(input), type, "input"),
+ ColumnWithTypeAndName(std::move(precision),
std::make_shared<DataTypeUInt32>(), "precision"),
+ ColumnWithTypeAndName(std::move(scale),
std::make_shared<DataTypeUInt32>(), "scale")});
+ auto executable = function->build(block.getColumnsWithTypeAndName());
+ for (auto _ : state)
+ {
+ auto result = executable->execute(block.getColumnsWithTypeAndName(),
executable->getResultType(), block.rows(), false);
+ benchmark::DoNotOptimize(result);
+ }
+}
+
+/// int to decimal
+static void BM_OptCheckDecimalOverflowSparkFromInt(benchmark::State & state)
+{
+ using namespace DB;
+ auto & factory = FunctionFactory::instance();
+ auto function = factory.get("checkDecimalOverflowSparkOrNull",
local_engine::QueryContext::globalContext());
+ auto type = DataTypeFactory::instance().get("Nullable(Int64)");
+
+ auto input = createColumn(type, 65536);
+ auto precision = ColumnConst::create(ColumnUInt32::create(1, 38), 65536);
+ auto scale = ColumnConst::create(ColumnUInt32::create(1, 10), 65536);
+
+ auto block = Block(
+ {ColumnWithTypeAndName(std::move(input), type, "input"),
+ ColumnWithTypeAndName(std::move(precision),
std::make_shared<DataTypeUInt32>(), "precision"),
+ ColumnWithTypeAndName(std::move(scale),
std::make_shared<DataTypeUInt32>(), "scale")});
+ auto executable = function->build(block.getColumnsWithTypeAndName());
+ for (auto _ : state)
+ {
+ auto result = executable->execute(block.getColumnsWithTypeAndName(),
executable->getResultType(), block.rows(), false);
+ benchmark::DoNotOptimize(result);
+ }
+}
+
+/// float to decimal
+static void BM_OptCheckDecimalOverflowSparkFromFloat(benchmark::State & state)
+{
+ using namespace DB;
+ auto & factory = FunctionFactory::instance();
+ auto function = factory.get("checkDecimalOverflowSparkOrNull",
local_engine::QueryContext::globalContext());
+ auto type = DataTypeFactory::instance().get("Nullable(Float64)");
+
+ auto input = createColumn(type, 65536);
+ auto precision = ColumnConst::create(ColumnUInt32::create(1, 38), 65536);
+ auto scale = ColumnConst::create(ColumnUInt32::create(1, 10), 65536);
+
+ auto block = Block(
+ {ColumnWithTypeAndName(std::move(input), type, "input"),
+ ColumnWithTypeAndName(std::move(precision),
std::make_shared<DataTypeUInt32>(), "precision"),
+ ColumnWithTypeAndName(std::move(scale),
std::make_shared<DataTypeUInt32>(), "scale")});
+ auto executable = function->build(block.getColumnsWithTypeAndName());
+ for (auto _ : state)
+ {
+ auto result = executable->execute(block.getColumnsWithTypeAndName(),
executable->getResultType(), block.rows(), false);
+ benchmark::DoNotOptimize(result);
+ }
+}
+
+BENCHMARK(BM_OptCheckDecimalOverflowSparkFromDecimal1);
+BENCHMARK(BM_OptCheckDecimalOverflowSparkFromDecimal2);
+BENCHMARK(BM_OptCheckDecimalOverflowSparkFromDecimal3);
+BENCHMARK(BM_OptCheckDecimalOverflowSparkFromInt);
+BENCHMARK(BM_OptCheckDecimalOverflowSparkFromFloat);
+
static void nanInfToNullAutoOpt(float * data, uint8_t * null_map, size_t size)
{
for (size_t i = 0; i < size; ++i)
@@ -267,10 +479,6 @@ static void BMNanInfToNull(benchmark::State & state)
}
BENCHMARK(BMNanInfToNull);
-BENCHMARK(BM_CHFloorFunction_For_Int64);
-BENCHMARK(BM_CHFloorFunction_For_Float64);
-BENCHMARK(BM_SparkFloorFunction_For_Int64);
-BENCHMARK(BM_SparkFloorFunction_For_Float64);
/*
@@ -1278,4 +1486,5 @@ BENCHMARK_TEMPLATE(BM_myFilterToIndicesAVX512, UInt32);
BENCHMARK_TEMPLATE(BM_myFilterToIndicesAVX512, UInt64);
*/
+
#endif
diff --git
a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index ff1d688472..3c1429a34f 100644
---
a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++
b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -540,6 +540,7 @@ class ClickHouseTestSettings extends BackendTestSettings {
)
// test for sort node not present but gluten uses shuffle hash join
.exclude("SPARK-41048: Improve output partitioning and ordering with AQE
cache")
+ .exclude("SPARK-28224: Aggregate sum big decimal overflow")
// Rewrite this test since it checks the physical operator which is
changed in Gluten
.excludeCH("SPARK-27439: Explain result should match collected result
after view change")
.excludeCH("SPARK-28067: Aggregate sum should not return wrong results for
decimal overflow")
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]