This is an automated email from the ASF dual-hosted git repository.
liuneng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new d79250acc [GLUTEN-3452][CH]Bug fix decimal divide (#4951)
d79250acc is described below
commit d79250acc1fe9b450ef37358b1e61ae527bec844
Author: KevinyhZou <[email protected]>
AuthorDate: Tue Mar 19 17:57:35 2024 +0800
[GLUTEN-3452][CH]Bug fix decimal divide (#4951)
What changes were proposed in this pull request?
Route decimal division in the ClickHouse backend through a new Spark-compatible
function (sparkDivideDecimal) so that division by a zero or NULL divisor yields
NULL, matching vanilla Spark when decimal precision loss is disallowed.
(Fixes: #3452)
How was this patch tested?
Tested by unit test (GlutenClickHouseHiveTableSuite: "GLUTEN-3452: Bug fix
decimal divide"). No UI changes involved.
---
.../execution/GlutenClickHouseHiveTableSuite.scala | 15 ++
.../Functions/SparkFunctionDecimalDivide.cpp | 28 ++++
.../Functions/SparkFunctionDecimalDivide.h | 176 +++++++++++++++++++++
.../Parser/scalar_function_parser/divide.cpp | 5 +-
4 files changed, 223 insertions(+), 1 deletion(-)
diff --git
a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseHiveTableSuite.scala
b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseHiveTableSuite.scala
index b40a0fe0d..e482746ef 100644
---
a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseHiveTableSuite.scala
+++
b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseHiveTableSuite.scala
@@ -1237,4 +1237,19 @@ class GlutenClickHouseHiveTableSuite
}
}
}
+
+ test("GLUTEN-3452: Bug fix decimal divide") {
+ withSQLConf((SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key, "false")) {
+ val table_create_sql =
+ """
+ | create table test_tbl_3452(d1 decimal(12,2), d2 decimal(15,3))
stored as parquet;
+ |""".stripMargin
+ val data_insert_sql = "insert into test_tbl_3452 values(13.0, 0),(11,
NULL), (12.3, 200)"
+ val select_sql = "select d1/d2, d1/0, d1/cast(0 as decimal) from
test_tbl_3452"
+ spark.sql(table_create_sql);
+ spark.sql(data_insert_sql)
+ compareResultsAgainstVanillaSpark(select_sql, true, { _ => })
+ spark.sql("drop table test_tbl_3452")
+ }
+ }
}
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionDecimalDivide.cpp
b/cpp-ch/local-engine/Functions/SparkFunctionDecimalDivide.cpp
new file mode 100644
index 000000000..0de3b757e
--- /dev/null
+++ b/cpp-ch/local-engine/Functions/SparkFunctionDecimalDivide.cpp
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <Functions/SparkFunctionDecimalDivide.h>
+
+namespace local_engine
+{
+
+/// Registers the Spark-compatible decimal divide ("sparkDivideDecimal",
+/// see DivideDecimalsImpl::name) with the function factory via the
+/// REGISTER_FUNCTION hook, so the plan parser can resolve it by name.
+REGISTER_FUNCTION(SparkFunctionDecimalDivide)
+{
+ factory.registerFunction<SparkFunctionDecimalDivide<DivideDecimalsImpl>>();
+}
+
+}
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionDecimalDivide.h
b/cpp-ch/local-engine/Functions/SparkFunctionDecimalDivide.h
new file mode 100644
index 000000000..1b93e77d9
--- /dev/null
+++ b/cpp-ch/local-engine/Functions/SparkFunctionDecimalDivide.h
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <DataTypes/DataTypeNullable.h>
+#include <Functions/FunctionsDecimalArithmetics.h>
+#include <Functions/FunctionFactory.h>
+#include <Columns/ColumnNullable.h>
+#include <Columns/ColumnVector.h>
+#include <Columns/ColumnDecimal.h>
+
+using namespace DB;
+
namespace DB
{
namespace ErrorCodes
{
    /// Error codes referenced in this header. DECIMAL_OVERFLOW is thrown by
    /// DivideDecimalsImpl::execute but was missing from the original
    /// declaration list, which would fail to compile (undeclared identifier).
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int DECIMAL_OVERFLOW;
}
}
+
+namespace local_engine
+{
+struct DivideDecimalsImpl
+{
+ static constexpr auto name = "sparkDivideDecimal";
+ template <typename FirstType, typename SecondType>
+ static inline Decimal256
+ execute(FirstType a, SecondType b, UInt16 scale_a, UInt16 scale_b, UInt16
result_scale)
+ {
+ if (a.value == 0 || b.value == 0)
+ return Decimal256(0);
+
+ Int256 sign_a = a.value < 0 ? -1 : 1;
+ Int256 sign_b = b.value < 0 ? -1 : 1;
+
+ std::vector<UInt8> a_digits = DecimalOpHelpers::toDigits(a.value *
sign_a);
+
+ while (scale_a < scale_b + result_scale)
+ {
+ a_digits.push_back(0);
+ ++scale_a;
+ }
+
+ while (scale_a > scale_b + result_scale && !a_digits.empty())
+ {
+ a_digits.pop_back();
+ --scale_a;
+ }
+
+ if (a_digits.empty())
+ return Decimal256(0);
+
+ std::vector<UInt8> divided = DecimalOpHelpers::divide(a_digits,
b.value * sign_b);
+
+ if (divided.size() > DecimalUtils::max_precision<Decimal256>)
+ throw DB::Exception(ErrorCodes::DECIMAL_OVERFLOW, "Numeric
overflow: result bigger that Decimal256");
+ return Decimal256(sign_a * sign_b *
DecimalOpHelpers::fromDigits(divided));
+ }
+};
+
+template <typename Transform>
+class SparkFunctionDecimalDivide : public
FunctionsDecimalArithmetics<Transform>
+{
+public:
+ static constexpr auto name = Transform::name;
+ static DB::FunctionPtr create(DB::ContextPtr) { return
std::make_shared<SparkFunctionDecimalDivide>(); }
+ SparkFunctionDecimalDivide() = default;
+ ~SparkFunctionDecimalDivide() override = default;
+ String getName() const override { return name; }
+
+ DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments)
const override
+ {
+ return
makeNullable(FunctionsDecimalArithmetics<Transform>::getReturnTypeImpl(arguments));
+ }
+
+ ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const
DataTypePtr & result_type, size_t input_rows) const override
+ {
+ if (arguments.size() != 2)
+ throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Function {} must have 2 arugments.", name);
+
+ ColumnPtr res_col = nullptr;
+ MutableColumnPtr null_map_col = ColumnUInt8::create(input_rows, 0);
+ auto getNonNullableColumn = [&](const ColumnPtr & col) -> const
ColumnPtr
+ {
+ if (col->isNullable())
+ {
+ auto * nullable_col = checkAndGetColumn<const
ColumnNullable>(col.get());
+ return nullable_col->getNestedColumnPtr();
+ }
+ else
+ return col;
+ };
+
+ ColumnWithTypeAndName new_arg0
{getNonNullableColumn(arguments[0].column), removeNullable(arguments[0].type),
arguments[0].name};
+ ColumnWithTypeAndName new_arg1
{getNonNullableColumn(arguments[1].column), removeNullable(arguments[1].type),
arguments[0].name};
+ ColumnsWithTypeAndName new_args {new_arg0, new_arg1};
+ bool arg_type_valid = true;
+
+ if (isDecimal(new_arg1.type))
+ {
+ using Types = TypeList<DataTypeDecimal32, DataTypeDecimal64,
DataTypeDecimal128, DataTypeDecimal256>;
+ arg_type_valid = castTypeToEither(Types{}, new_arg1.type.get(),
[&](const auto & right_)
+ {
+ using R = typename std::decay_t<decltype(right_)>::FieldType;
+ const ColumnDecimal<R> * const_col_right =
checkAndGetColumnConstData<ColumnDecimal<R>>(new_arg1.column.get());
+ if (const_col_right && const_col_right->getElement(0).value ==
0)
+ {
+ null_map_col = ColumnUInt8::create(input_rows, 1);
+ res_col = ColumnDecimal<Decimal256>::create(input_rows, 0);
+ }
+ else
+ res_col =
FunctionsDecimalArithmetics<Transform>::executeImpl(new_args,
removeNullable(result_type), input_rows);
+
+ if (!const_col_right)
+ {
+ const ColumnDecimal<R> * col_right = assert_cast<const
ColumnDecimal<R> *>(new_arg1.column.get());
+ PaddedPODArray<UInt8> & null_map =
assert_cast<ColumnVector<UInt8>*>(null_map_col.get())->getData();
+ for (size_t i = 0; i < col_right->size(); ++i)
+ null_map[i] = (col_right->getElement(i).value == 0 ||
arguments[1].column->isNullAt(i));
+ }
+ return true;
+ });
+ }
+ else if (isNumber(new_arg1.type))
+ {
+ using Types = TypeList<DataTypeFloat32,
DataTypeFloat64,DataTypeUInt8, DataTypeUInt16, DataTypeUInt32,
+ DataTypeUInt64, DataTypeUInt128, DataTypeUInt256,
DataTypeInt8, DataTypeInt16, DataTypeInt32,
+ DataTypeInt64, DataTypeInt128, DataTypeInt256>;
+ arg_type_valid = castTypeToEither(Types{}, new_arg1.type.get(),
[&](const auto & right_)
+ {
+ using R = typename std::decay_t<decltype(right_)>::FieldType;
+ const ColumnVector<R> * const_col_right =
checkAndGetColumnConstData<ColumnVector<R>>(new_arg1.column.get());
+ if (const_col_right && const_col_right->getElement(0) == 0)
+ {
+ null_map_col = ColumnUInt8::create(input_rows, 1);
+ res_col = ColumnDecimal<Decimal256>::create(input_rows, 0);
+ }
+ else
+ res_col =
FunctionsDecimalArithmetics<Transform>::executeImpl(new_args,
removeNullable(result_type), input_rows);
+
+ if (!const_col_right)
+ {
+ const ColumnVector<R> * col_right = assert_cast<const
ColumnVector<R> *>(new_arg1.column.get());
+ PaddedPODArray<UInt8> & null_map =
assert_cast<ColumnVector<UInt8>*>(null_map_col.get())->getData();
+ for (size_t i = 0; i < col_right->size(); ++i)
+ null_map[i] = (col_right->getElement(i) == 0 ||
arguments[1].column->isNullAt(i));
+ }
+ return true;
+ });
+ }
+ else
+ throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function
{}'s arguments type must be numbeic", name);
+
+ if (!arg_type_valid)
+ throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function
{}'s arguments type is not valid.", name);
+
+ return ColumnNullable::create(res_col, std::move(null_map_col));
+ }
+
+};
+}
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/divide.cpp
b/cpp-ch/local-engine/Parser/scalar_function_parser/divide.cpp
index cddddf307..71b695ac3 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/divide.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/divide.cpp
@@ -56,7 +56,10 @@ public:
const auto * left_arg = new_args[0];
const auto * right_arg = new_args[1];
-
+
+ if (isDecimal(removeNullable(left_arg->result_type)) ||
isDecimal(removeNullable(right_arg->result_type)))
+ return toFunctionNode(actions_dag, "sparkDivideDecimal",
{left_arg, right_arg});
+
const auto * divide_node = toFunctionNode(actions_dag, "divide",
{left_arg, right_arg});
DataTypePtr result_type = divide_node->result_type;
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]