This is an automated email from the ASF dual-hosted git repository.
changchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 6b1d63c81 [GLUTEN-6975][CH] Fix decimal cast overflow exception
6b1d63c81 is described below
commit 6b1d63c81f20a25fee6bdc715b3308237cbb29ed
Author: Shuai li <[email protected]>
AuthorDate: Wed Sep 25 10:52:18 2024 +0800
[GLUTEN-6975][CH] Fix decimal cast overflow exception
---
.../execution/GlutenClickHouseDecimalSuite.scala | 18 +++++
.../SparkFunctionDecimalBinaryArithmetic.cpp | 82 +++++++++++++---------
.../local-engine/Parser/SerializedPlanParser.cpp | 7 ++
3 files changed, 73 insertions(+), 34 deletions(-)
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala
index f772c909c..c0a32a034 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala
@@ -335,6 +335,24 @@ class GlutenClickHouseDecimalSuite
spark.sql("drop table if exists decimals_test")
}
}
+
+ test("test castornull") {
+ // prepare
+ val createSql =
+ "create table decimals_cast_test(a decimal(18,8)) using parquet"
+ val inserts =
+ "insert into decimals_cast_test values(123456789.12345678)"
+ spark.sql(createSql)
+
+ try {
+ spark.sql(inserts)
+ val q1 = "select cast(a as decimal(9,2)) from decimals_cast_test"
+ compareResultsAgainstVanillaSpark(q1, compareResult = true, _ => {})
+ } finally {
+ spark.sql("drop table if exists decimals_cast_test")
+ }
+ }
+
// FIXME: Support AVG for Decimal Type
Seq("true", "false").foreach {
allowPrecisionLoss =>
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.cpp b/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.cpp
index 26d8e0deb..a2d5dec73 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.cpp
+++ b/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.cpp
@@ -121,23 +121,32 @@ public:
if constexpr (Mode == OpMode::Effect)
{
- return executeDecimalImpl<LeftDataType, RightDataType,
ResultDataType>(
+ return executeDecimalImpl<LeftDataType, RightDataType,
ResultDataType, NativeType<typename ResultDataType::FieldType>>(
left, right, col_left_const, col_right_const, col_left,
col_right, col_left_size, result);
}
if (calculateWith256<is_plus_minus, is_multiply, is_division,
is_modulo>(*arguments[0].type.get(), *arguments[1].type.get()))
{
- return executeDecimalImpl<LeftDataType, RightDataType,
ResultDataType, true>(
+ return executeDecimalImpl<LeftDataType, RightDataType,
ResultDataType, Int256, true>(
left, right, col_left_const, col_right_const, col_left,
col_right, col_left_size, result);
}
- return executeDecimalImpl<LeftDataType, RightDataType, ResultDataType>(
- left, right, col_left_const, col_right_const, col_left, col_right,
col_left_size, result);
+ size_t max_scale = getMaxScaled(left.getScale(), right.getScale(),
result.getScale());
+ if (is_division && max_scale - left.getScale() + max_scale >
DataTypeDecimal<typename ResultDataType::FieldType>::maxPrecision())
+ {
+ return executeDecimalImpl<LeftDataType, RightDataType,
ResultDataType, Int256, true>(
+ left, right, col_left_const, col_right_const, col_left,
col_right, col_left_size, result);
+ }
+ else
+ {
+ return executeDecimalImpl<LeftDataType, RightDataType,
ResultDataType, NativeType<typename ResultDataType::FieldType>>(
+ left, right, col_left_const, col_right_const, col_left,
col_right, col_left_size, result);
+ }
}
private:
// ResultDataType e.g. DataTypeDecimal<Decimal32>
- template <class LeftDataType, class RightDataType, class ResultDataType,
bool CalculateWith256 = false>
+ template <class LeftDataType, class RightDataType, class ResultDataType,
class ScaleDataType, bool CalculateWith256 = false>
static ColumnPtr executeDecimalImpl(
const auto & left,
const auto & right,
@@ -152,34 +161,29 @@ private:
using RightFieldType = typename RightDataType::FieldType;
using ResultFieldType = typename ResultDataType::FieldType;
- using NativeResultType = NativeType<ResultFieldType>;
using ColVecResult = ColumnVectorOrDecimal<ResultFieldType>;
- size_t max_scale;
- if constexpr (is_multiply)
- max_scale = left.getScale() + right.getScale();
- else
- max_scale = std::max(resultDataType.getScale(),
std::max(left.getScale(), right.getScale()));
+ size_t max_scale = getMaxScaled(left.getScale(), right.getScale(),
resultDataType.getScale());
- NativeResultType scale_left = [&]
+ ScaleDataType scale_left = [&]
{
if constexpr (is_multiply)
- return NativeResultType{1};
+ return ScaleDataType{1};
// cast scale same to left
auto diff_scale = max_scale - left.getScale();
if constexpr (is_division)
- return
DecimalUtils::scaleMultiplier<NativeResultType>(diff_scale + max_scale);
+ return DecimalUtils::scaleMultiplier<ScaleDataType>(diff_scale
+ max_scale);
else
- return
DecimalUtils::scaleMultiplier<NativeResultType>(diff_scale);
+ return
DecimalUtils::scaleMultiplier<ScaleDataType>(diff_scale);
}();
- const NativeResultType scale_right = [&]
+ const ScaleDataType scale_right = [&]
{
if constexpr (is_multiply)
- return NativeResultType{1};
+ return ScaleDataType{1};
else
- return
DecimalUtils::scaleMultiplier<NativeResultType>(max_scale - right.getScale());
+ return DecimalUtils::scaleMultiplier<ScaleDataType>(max_scale
- right.getScale());
}();
@@ -266,17 +270,19 @@ private:
return ColumnNullable::create(std::move(col_res),
std::move(col_null_map_to));
}
- template <OpCase op_case, bool CalculateWith256, typename
ResultContainerType, typename NativeResultType, typename ResultDataType>
+ template <OpCase op_case, bool CalculateWith256, typename
ResultContainerType, typename ResultDataType, typename ScaleDataType>
static static void NO_INLINE process(
const auto & a,
const auto & b,
ResultContainerType & result_container,
- const NativeResultType & scale_a,
- const NativeResultType & scale_b,
+ const ScaleDataType & scale_a,
+ const ScaleDataType & scale_b,
ColumnUInt8::Container & vec_null_map_to,
const ResultDataType & resultDataType,
const size_t & max_scale)
{
+ using NativeResultType = NativeType<typename
ResultDataType::FieldType>;
+
size_t size;
if constexpr (op_case == OpCase::LeftConstant)
size = b.size();
@@ -303,14 +309,14 @@ private:
}
else if constexpr (op_case == OpCase::LeftConstant)
{
- auto scaled_a = applyScaled(unwrap<op_case,
OpCase::LeftConstant>(a, 0), scale_a);
+ ScaleDataType scaled_a = applyScaled(unwrap<op_case,
OpCase::LeftConstant>(a, 0), scale_a);
for (size_t i = 0; i < size; ++i)
{
NativeResultType res;
if (calculate<CalculateWith256>(
scaled_a,
unwrap<op_case, OpCase::RightConstant>(b, i),
- static_cast<NativeResultType>(0),
+ static_cast<ScaleDataType>(0),
scale_b,
res,
resultDataType,
@@ -322,7 +328,7 @@ private:
}
else if constexpr (op_case == OpCase::RightConstant)
{
- auto scaled_b = applyScaled(unwrap<op_case,
OpCase::RightConstant>(b, 0), scale_b);
+ ScaleDataType scaled_b = applyScaled(unwrap<op_case,
OpCase::RightConstant>(b, 0), scale_b);
for (size_t i = 0; i < size; ++i)
{
@@ -331,7 +337,7 @@ private:
unwrap<op_case, OpCase::LeftConstant>(a, i),
scaled_b,
scale_a,
- static_cast<NativeResultType>(0),
+ static_cast<ScaleDataType>(0),
res,
resultDataType,
max_scale))
@@ -343,12 +349,12 @@ private:
}
// ResultNativeType = Int32/64/128/256
- template <bool CalculateWith256, typename LeftNativeType, typename
RightNativeType, typename NativeResultType, typename ResultDataType>
+ template <bool CalculateWith256, typename LeftNativeType, typename
RightNativeType, typename NativeResultType, typename ResultDataType, typename
ScaleDataType>
static NO_SANITIZE_UNDEFINED bool calculate(
const LeftNativeType l,
const RightNativeType r,
- const NativeResultType & scale_left,
- const NativeResultType & scale_right,
+ const ScaleDataType & scale_left,
+ const ScaleDataType & scale_right,
NativeResultType & res,
const ResultDataType & resultDataType,
const size_t & max_scale)
@@ -361,12 +367,12 @@ private:
return calculateImpl<NativeResultType>(l, r, scale_left,
scale_right, res, resultDataType, max_scale);
}
- template <typename CalcType, typename LeftNativeType, typename
RightNativeType, typename NativeResultType, typename ResultDataType>
+ template <typename CalcType, typename LeftNativeType, typename
RightNativeType, typename NativeResultType, typename ResultDataType, typename
ScaleDataType>
static NO_SANITIZE_UNDEFINED bool calculateImpl(
const LeftNativeType & l,
const RightNativeType & r,
- const NativeResultType & scale_left,
- const NativeResultType & scale_right,
+ const ScaleDataType & scale_left,
+ const ScaleDataType & scale_right,
NativeResultType & res,
const ResultDataType & resultDataType,
const size_t & max_scale)
@@ -410,13 +416,21 @@ private:
return elem[i].value;
}
- template <typename NativeType, typename ResultNativeType>
- static ResultNativeType applyScaled(const NativeType & l, const
ResultNativeType & scale)
+ template <typename NativeType, typename ScaleType>
+ static ScaleType applyScaled(const NativeType & l, const ScaleType & scale)
{
if (scale > 1)
return common::mulIgnoreOverflow(l, scale);
- return static_cast<ResultNativeType>(l);
+ return static_cast<ScaleType>(l);
+ }
+
+ static size_t getMaxScaled(const size_t left_scale, const size_t
right_scale, const size_t result_scale)
+ {
+ if constexpr (is_multiply)
+ return left_scale + right_scale;
+ else
+ return std::max(result_scale, std::max(left_scale, right_scale));
}
};
diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
index 75ba2a115..9f90ae364 100644
--- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
+++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
@@ -1032,6 +1032,13 @@ const ActionsDAG::Node * SerializedPlanParser::parseExpression(ActionsDAG & acti
String function_name = "sparkCastFloatTo" +
non_nullable_output_type->getName();
function_node = toFunctionNode(actions_dag, function_name,
args);
}
+ else if ((isDecimal(non_nullable_input_type) &&
substrait_type.has_decimal()))
+ {
+ args.emplace_back(addColumn(actions_dag,
std::make_shared<DataTypeInt32>(), substrait_type.decimal().precision()));
+ args.emplace_back(addColumn(actions_dag,
std::make_shared<DataTypeInt32>(), substrait_type.decimal().scale()));
+
+ function_node = toFunctionNode(actions_dag,
"checkDecimalOverflowSparkOrNull", args);
+ }
else
{
if (isString(non_nullable_input_type) &&
isInt(non_nullable_output_type))
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]