This is an automated email from the ASF dual-hosted git repository.
changchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 55671cbf7c [GLUTEN-6975][CH] Rewrite decimal arithmetic (#7196)
55671cbf7c is described below
commit 55671cbf7c8bcd945b97869110872bf949589438
Author: Shuai li <[email protected]>
AuthorDate: Mon Sep 23 14:22:21 2024 +0800
[GLUTEN-6975][CH] Rewrite decimal arithmetic (#7196)
* enable tpchq1
---
.../backendsapi/clickhouse/CHListenerApi.scala | 4 +-
.../execution/GlutenClickHouseDecimalSuite.scala | 7 +-
.../AggregateFunctionSparkAvg.cpp | 36 +-
cpp-ch/local-engine/Common/GlutenDecimalUtils.h | 7 -
.../SparkFunctionDecimalBinaryArithmetic.cpp | 591 +++++++++++++++++++++
.../SparkFunctionDecimalBinaryArithmetic.h | 276 ++++++++++
cpp-ch/local-engine/Parser/SerializedPlanParser.h | 1 +
.../Parser/scalar_function_parser/arithmetic.cpp | 160 ++++--
.../gluten/expression/ExpressionConverter.scala | 16 +-
9 files changed, 1042 insertions(+), 56 deletions(-)
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
index 58acda88fb..8065d35c85 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
@@ -25,7 +25,7 @@ import org.apache.gluten.extension.ExpressionExtensionTrait
import org.apache.gluten.jni.JniLibLoader
import org.apache.gluten.vectorized.CHNativeExpressionEvaluator
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext}
import org.apache.spark.api.plugin.PluginContext
import org.apache.spark.internal.Logging
import org.apache.spark.listener.CHGlutenSQLAppStatusListener
@@ -85,7 +85,7 @@ class CHListenerApi extends ListenerApi with Logging {
conf.setCHConfig(
"timezone" -> conf.get("spark.sql.session.timeZone",
TimeZone.getDefault.getID),
"local_engine.settings.log_processors_profiles" -> "true")
-
+ conf.setCHSettings("spark_version", SPARK_VERSION)
// add memory limit for external sort
val externalSortKey =
CHConf.runtimeSettings("max_bytes_before_external_sort")
if (conf.getLong(externalSortKey, -1) < 0) {
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala
index 40f442bc29..f772c909c8 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala
@@ -69,7 +69,7 @@ class GlutenClickHouseDecimalSuite
(DecimalType.apply(18, 8), Seq()),
// 3/10: all value is null and compare with limit
// 1 Spark 3.5
- (DecimalType.apply(38, 19), if (isSparkVersionLE("3.3")) Seq(3, 10) else
Seq(1, 3, 10))
+ (DecimalType.apply(38, 19), if (isSparkVersionLE("3.3")) Seq(3, 10) else
Seq(3, 10))
)
private def createDecimalTables(dataType: DecimalType): Unit = {
@@ -309,7 +309,10 @@ class GlutenClickHouseDecimalSuite
"insert into decimals_test values(1, 100.0, 999.0)" +
", (2, 12345.123, 12345.123)" +
", (3, 0.1234567891011, 1234.1)" +
- ", (4, 123456789123456789.0, 1.123456789123456789)"
+ ", (4, 123456789123456789.0, 1.123456789123456789)" +
+ ", (5, 0, 0)" +
+ ", (6, 0, 1.23)" +
+ ", (7, 1.23, 0)"
spark.sql(createSql)
try {
diff --git
a/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp
b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp
index 0aa2331457..524a39d3a1 100644
--- a/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp
+++ b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp
@@ -24,6 +24,7 @@
#include <Common/CHUtil.h>
#include <Common/GlutenDecimalUtils.h>
+#include <Common/GlutenSettings.h>
namespace DB
{
@@ -47,7 +48,7 @@ DataTypePtr getSparkAvgReturnType(const DataTypePtr &
arg_type)
return createDecimal<DataTypeDecimal>(precision_value, scale_value);
}
-template <typename T>
+template <typename T, bool SPARK35>
requires is_decimal<T>
class AggregateFunctionSparkAvg final : public AggregateFunctionAvg<T>
{
@@ -61,7 +62,7 @@ public:
{
}
- DataTypePtr createResultType(const DataTypes & argument_types_, UInt32
num_scale_, UInt32 round_scale_)
+ DataTypePtr createResultType(const DataTypes & argument_types_, UInt32
num_scale_, UInt32 /*round_scale_*/)
{
const DataTypePtr & data_type = argument_types_[0];
const UInt32 precision_value =
std::min<size_t>(getDecimalPrecision(*data_type) + 4,
DecimalUtils::max_precision<Decimal128>);
@@ -82,7 +83,7 @@ public:
else if (which.isDecimal64())
{
assert_cast<ColumnDecimal<Decimal64> &>(to).getData().push_back(
- divideDecimalAndUInt(this->data(place), num_scale, result_scale,
round_scale));
+ divideDecimalAndUInt(this->data(place), num_scale,
result_scale, round_scale));
}
else if (which.isDecimal128())
{
@@ -116,6 +117,9 @@ private:
auto result = value / avg.denominator;
+ if constexpr (SPARK35)
+ return result;
+
if (round_scale > result_scale)
return result;
@@ -128,8 +132,21 @@ private:
UInt32 round_scale;
};
-AggregateFunctionPtr
-createAggregateFunctionSparkAvg(const std::string & name, const DataTypes &
argument_types, const Array & parameters, const Settings * settings)
+/// Instantiates AggregateFunctionSparkAvg<T, SPARK35> for the decimal type of `argument_type`, or for
+/// DateTime64 when the base avg implementation reports DateTime64Supported.
+/// SPARK35 is forwarded unchanged as the class's SPARK35 flag.
+/// Returns nullptr for any other type; the returned object is heap-allocated and owned by the caller.
+template <bool SPARK35, typename... TArgs>
+static IAggregateFunction * createWithDecimalType(const IDataType & argument_type, TArgs && ... args)
+{
+    const WhichDataType kind(argument_type);
+
+    if (kind.idx == TypeIndex::Decimal32)
+        return new AggregateFunctionSparkAvg<Decimal32, SPARK35>(args...);
+    if (kind.idx == TypeIndex::Decimal64)
+        return new AggregateFunctionSparkAvg<Decimal64, SPARK35>(args...);
+    if (kind.idx == TypeIndex::Decimal128)
+        return new AggregateFunctionSparkAvg<Decimal128, SPARK35>(args...);
+    if (kind.idx == TypeIndex::Decimal256)
+        return new AggregateFunctionSparkAvg<Decimal256, SPARK35>(args...);
+
+    if constexpr (AggregateFunctionSparkAvg<DateTime64, SPARK35>::DateTime64Supported)
+    {
+        if (kind.idx == TypeIndex::DateTime64)
+            return new AggregateFunctionSparkAvg<DateTime64, SPARK35>(args...);
+    }
+
+    return nullptr;
+}
+
+AggregateFunctionPtr createAggregateFunctionSparkAvg(
+ const std::string & name, const DataTypes & argument_types, const Array &
parameters, const Settings * settings)
{
assertNoParameters(name, parameters);
assertUnary(name, argument_types);
@@ -140,13 +157,20 @@ createAggregateFunctionSparkAvg(const std::string & name,
const DataTypes & argu
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument
for aggregate function {}", data_type->getName(), name);
+ std::string version;
+ if (tryGetString(*settings, "spark_version", version) &&
version.starts_with("3.5"))
+ {
+ res.reset(createWithDecimalType<true>(*data_type, argument_types,
getDecimalScale(*data_type), 0));
+ return res;
+ }
+
bool allowPrecisionLoss =
settings->get(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS).safeGet<bool>();
const UInt32 p1 = DB::getDecimalPrecision(*data_type);
const UInt32 s1 = DB::getDecimalScale(*data_type);
auto [p2, s2] = GlutenDecimalUtils::LONG_DECIMAL;
auto [_, round_scale] = GlutenDecimalUtils::dividePrecisionScale(p1, s1,
p2, s2, allowPrecisionLoss);
- res.reset(createWithDecimalType<AggregateFunctionSparkAvg>(*data_type,
argument_types, getDecimalScale(*data_type), round_scale));
+ res.reset(createWithDecimalType<false>(*data_type, argument_types,
getDecimalScale(*data_type), round_scale));
return res;
}
diff --git a/cpp-ch/local-engine/Common/GlutenDecimalUtils.h
b/cpp-ch/local-engine/Common/GlutenDecimalUtils.h
index 32af66ec59..cf600a5cc9 100644
--- a/cpp-ch/local-engine/Common/GlutenDecimalUtils.h
+++ b/cpp-ch/local-engine/Common/GlutenDecimalUtils.h
@@ -95,13 +95,6 @@ public:
}
}
- static std::tuple<size_t, size_t> widerDecimalType(const size_t p1, const
size_t s1, const size_t p2, const size_t s2)
- {
- // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)
- auto scale = std::max(s1, s2);
- auto range = std::max(p1 - s1, p2 - s2);
- return std::tuple(range + scale, scale);
- }
};
diff --git
a/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.cpp
b/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.cpp
new file mode 100644
index 0000000000..26d8e0deb3
--- /dev/null
+++ b/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.cpp
@@ -0,0 +1,591 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "SparkFunctionDecimalBinaryArithmetic.h"
+
+#include <Columns/ColumnDecimal.h>
+#include <Columns/ColumnNullable.h>
+#include <Columns/ColumnsNumber.h>
+#include <DataTypes/DataTypeNullable.h>
+#include <DataTypes/DataTypesDecimal.h>
+#include <Functions/FunctionFactory.h>
+#include <Functions/FunctionHelpers.h>
+#include <Functions/IFunction.h>
+#include <Functions/castTypeToEither.h>
+#include <Common/CurrentThread.h>
+#include <Common/GlutenDecimalUtils.h>
+#include <Common/ProfileEvents.h>
+#include <Common/Stopwatch.h>
+
+namespace DB
+{
+namespace ErrorCodes
+{
+extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
+extern const int ILLEGAL_TYPE_OF_ARGUMENT;
+extern const int ILLEGAL_COLUMN;
+extern const int TYPE_MISMATCH;
+extern const int LOGICAL_ERROR;
+}
+
+}
+
+namespace local_engine
+{
+using namespace DB;
+
+namespace
+{
+/// Which of the two operand columns (if any) is a ColumnConst.
+enum class OpCase : uint8_t
+{
+    Vector,        /// both operands are full columns
+    LeftConstant,  /// left operand is constant
+    RightConstant, /// right operand is constant
+};
+
+/// Default: decide up front from the argument types whether the calculation needs 256-bit integers
+/// (see calculateWith256), in addition to the runtime check in executeDecimalImpl.
+/// Effect: skip the up-front type-based check and rely only on the runtime check.
+enum class OpMode : uint8_t
+{
+    Default,
+    Effect,
+};
+
+/// Returns true when the uncapped result precision of the operation (the formulas used by Spark's decimal
+/// precision rules) is larger than Decimal128 can hold, i.e. the calculation must be done in 256 bits.
+template <bool is_plus_minus, bool is_multiply, bool is_division, bool is_modulo>
+bool calculateWith256(const IDataType & left, const IDataType & right)
+{
+    const size_t lhs_precision = getDecimalPrecision(left);
+    const size_t lhs_scale = getDecimalScale(left);
+    const size_t rhs_precision = getDecimalPrecision(right);
+    const size_t rhs_scale = getDecimalScale(right);
+
+    const size_t required_precision = [&]() -> size_t
+    {
+        if constexpr (is_plus_minus)
+            return std::max(lhs_scale, rhs_scale) + std::max(lhs_precision - lhs_scale, rhs_precision - rhs_scale) + 1;
+        else if constexpr (is_multiply)
+            return lhs_precision + rhs_precision + 1;
+        else if constexpr (is_division)
+            return lhs_precision - lhs_scale + rhs_scale + std::max(static_cast<size_t>(6), lhs_scale + rhs_precision + 1);
+        else if constexpr (is_modulo)
+            return std::min(lhs_precision - lhs_scale, rhs_precision - rhs_scale) + std::max(lhs_scale, rhs_scale);
+        else
+            throw Exception(ErrorCodes::LOGICAL_ERROR, "Not supported.");
+    }();
+
+    return required_precision > DataTypeDecimal128::maxPrecision();
+}
+
+/// Column-wise evaluation of a Spark decimal binary operation.
+/// `Operation` is one of the Decimal*Impl structs: `apply<T>(a, b, r)` stores the result in `r` and returns
+/// false when the operation fails (overflow, or division by zero for DecimalDivideImpl).
+/// `Mode` selects whether 256-bit evaluation may be forced up front from the argument types (Default) or is
+/// decided only by the runtime check in executeDecimalImpl (Effect).
+template <typename Operation, OpMode Mode>
+struct SparkDecimalBinaryOperation
+{
+private:
+    static constexpr bool is_plus_minus = SparkIsOperation<Operation>::plus || SparkIsOperation<Operation>::minus;
+    static constexpr bool is_multiply = SparkIsOperation<Operation>::multiply;
+    static constexpr bool is_division = SparkIsOperation<Operation>::division;
+    static constexpr bool is_modulo = SparkIsOperation<Operation>::modulo;
+
+    /// Integer type a row is calculated in: Int256 when requested, Int128 for division, otherwise the native
+    /// result type. Used both by calculate() and by the constant pre-scaling in process(), so vector and
+    /// constant operands are scaled in the same type and give the same results.
+    template <bool CalculateWith256, typename NativeResultType>
+    using CalcTypeFor = std::conditional_t<CalculateWith256, Int256, std::conditional_t<is_division, Int128, NativeResultType>>;
+
+public:
+    /// Evaluates the operation on arguments[0] and arguments[1] and returns a Nullable column of the result
+    /// type described by `result`. Rows whose calculation fails are NULL.
+    template <typename A, typename B, typename R>
+    static ColumnPtr executeDecimal(const ColumnsWithTypeAndName & arguments, const A & left, const B & right, const R & result)
+    {
+        using LeftDataType = std::decay_t<decltype(left)>; // e.g. DataTypeDecimal<Decimal32>
+        using RightDataType = std::decay_t<decltype(right)>; // e.g. DataTypeDecimal<Decimal32>
+        using ResultDataType = std::decay_t<decltype(result)>; // e.g. DataTypeDecimal<Decimal32>
+
+        using ColVecLeft = ColumnVectorOrDecimal<typename LeftDataType::FieldType>;
+        using ColVecRight = ColumnVectorOrDecimal<typename RightDataType::FieldType>;
+
+        const ColumnPtr left_col = arguments[0].column;
+        const ColumnPtr right_col = arguments[1].column;
+
+        const auto * const col_left_raw = left_col.get();
+        const auto * const col_right_raw = right_col.get();
+
+        const size_t col_left_size = col_left_raw->size();
+
+        const ColumnConst * const col_left_const = checkAndGetColumnConst<ColVecLeft>(col_left_raw);
+        const ColumnConst * const col_right_const = checkAndGetColumnConst<ColVecRight>(col_right_raw);
+
+        const ColVecLeft * const col_left = checkAndGetColumn<ColVecLeft>(col_left_raw);
+        const ColVecRight * const col_right = checkAndGetColumn<ColVecRight>(col_right_raw);
+
+        if constexpr (Mode == OpMode::Effect)
+        {
+            return executeDecimalImpl<LeftDataType, RightDataType, ResultDataType>(
+                left, right, col_left_const, col_right_const, col_left, col_right, col_left_size, result);
+        }
+
+        if (calculateWith256<is_plus_minus, is_multiply, is_division, is_modulo>(*arguments[0].type.get(), *arguments[1].type.get()))
+        {
+            return executeDecimalImpl<LeftDataType, RightDataType, ResultDataType, true>(
+                left, right, col_left_const, col_right_const, col_left, col_right, col_left_size, result);
+        }
+
+        return executeDecimalImpl<LeftDataType, RightDataType, ResultDataType>(
+            left, right, col_left_const, col_right_const, col_left, col_right, col_left_size, result);
+    }
+
+private:
+    /// Brings both operands to a common scale, applies the operation per row and returns a Nullable column.
+    /// - multiply: operands are not rescaled; the product has scale s1 + s2 (= max_scale).
+    /// - division: the dividend is additionally multiplied by 10^max_scale, so the quotient has max_scale digits.
+    /// - otherwise: both operands are rescaled to max_scale.
+    /// The per-row result is then reduced from max_scale to the result scale in calculateImpl().
+    /// ResultDataType e.g. DataTypeDecimal<Decimal32>
+    template <class LeftDataType, class RightDataType, class ResultDataType, bool CalculateWith256 = false>
+    static ColumnPtr executeDecimalImpl(
+        const auto & left,
+        const auto & right,
+        const ColumnConst * const col_left_const,
+        const ColumnConst * const col_right_const,
+        const auto * const col_left,
+        const auto * const col_right,
+        size_t col_left_size,
+        const ResultDataType & resultDataType)
+    {
+        using LeftFieldType = typename LeftDataType::FieldType;
+        using RightFieldType = typename RightDataType::FieldType;
+        using ResultFieldType = typename ResultDataType::FieldType;
+
+        using NativeResultType = NativeType<ResultFieldType>;
+        using ColVecResult = ColumnVectorOrDecimal<ResultFieldType>;
+
+        size_t max_scale;
+        if constexpr (is_multiply)
+            max_scale = left.getScale() + right.getScale();
+        else
+            max_scale = std::max(resultDataType.getScale(), std::max(left.getScale(), right.getScale()));
+
+        const NativeResultType scale_left = [&]
+        {
+            if constexpr (is_multiply)
+                return NativeResultType{1};
+
+            // cast scale same to left
+            auto diff_scale = max_scale - left.getScale();
+            if constexpr (is_division)
+                return DecimalUtils::scaleMultiplier<NativeResultType>(diff_scale + max_scale);
+            else
+                return DecimalUtils::scaleMultiplier<NativeResultType>(diff_scale);
+        }();
+
+        const NativeResultType scale_right = [&]
+        {
+            if constexpr (is_multiply)
+                return NativeResultType{1};
+            else
+                return DecimalUtils::scaleMultiplier<NativeResultType>(max_scale - right.getScale());
+        }();
+
+        /// Switch to 256-bit calculation when rescaling an operand could exceed its own type's precision.
+        bool calculate_with_256 = false;
+        if constexpr (CalculateWith256)
+            calculate_with_256 = true;
+        else
+        {
+            auto p1 = left.getPrecision();
+            auto p2 = right.getPrecision();
+            if (DataTypeDecimal<LeftFieldType>::maxPrecision() < p1 + max_scale - left.getScale()
+                || DataTypeDecimal<RightFieldType>::maxPrecision() < p2 + max_scale - right.getScale())
+                calculate_with_256 = true;
+        }
+
+        ColumnUInt8::MutablePtr col_null_map_to = ColumnUInt8::create(col_left_size, false);
+        ColumnUInt8::Container * vec_null_map_to = &col_null_map_to->getData();
+
+        typename ColVecResult::MutablePtr col_res = ColVecResult::create(0, resultDataType.getScale());
+        auto & vec_res = col_res->getData();
+        vec_res.resize(col_left_size);
+
+        if (col_left && col_right)
+        {
+            if (calculate_with_256)
+                process<OpCase::Vector, true>(
+                    col_left->getData(), col_right->getData(), vec_res, scale_left, scale_right, *vec_null_map_to, resultDataType, max_scale);
+            else
+                process<OpCase::Vector, false>(
+                    col_left->getData(), col_right->getData(), vec_res, scale_left, scale_right, *vec_null_map_to, resultDataType, max_scale);
+        }
+        else if (col_left_const && col_right)
+        {
+            LeftFieldType const_left = col_left_const->getValue<LeftFieldType>();
+            if (calculate_with_256)
+                process<OpCase::LeftConstant, true>(
+                    const_left, col_right->getData(), vec_res, scale_left, scale_right, *vec_null_map_to, resultDataType, max_scale);
+            else
+                process<OpCase::LeftConstant, false>(
+                    const_left, col_right->getData(), vec_res, scale_left, scale_right, *vec_null_map_to, resultDataType, max_scale);
+        }
+        else if (col_left && col_right_const)
+        {
+            RightFieldType const_right = col_right_const->getValue<RightFieldType>();
+            if (calculate_with_256)
+                process<OpCase::RightConstant, true>(
+                    col_left->getData(), const_right, vec_res, scale_left, scale_right, *vec_null_map_to, resultDataType, max_scale);
+            else
+                process<OpCase::RightConstant, false>(
+                    col_left->getData(), const_right, vec_res, scale_left, scale_right, *vec_null_map_to, resultDataType, max_scale);
+        }
+        else
+        {
+            throw Exception(ErrorCodes::LOGICAL_ERROR, "Not supported.");
+        }
+
+        return ColumnNullable::create(std::move(col_res), std::move(col_null_map_to));
+    }
+
+    /// Row loop. A row whose calculate() fails is marked NULL in vec_null_map_to (its value slot is left as is).
+    template <OpCase op_case, bool CalculateWith256, typename ResultContainerType, typename NativeResultType, typename ResultDataType>
+    static void NO_INLINE process(
+        const auto & a,
+        const auto & b,
+        ResultContainerType & result_container,
+        const NativeResultType & scale_a,
+        const NativeResultType & scale_b,
+        ColumnUInt8::Container & vec_null_map_to,
+        const ResultDataType & resultDataType,
+        const size_t & max_scale)
+    {
+        using CalcType = CalcTypeFor<CalculateWith256, NativeResultType>;
+
+        size_t size;
+        if constexpr (op_case == OpCase::LeftConstant)
+            size = b.size();
+        else
+            size = a.size();
+
+        if constexpr (op_case == OpCase::Vector)
+        {
+            for (size_t i = 0; i < size; ++i)
+            {
+                NativeResultType res;
+                if (calculate<CalculateWith256>(
+                        unwrap<op_case, OpCase::LeftConstant>(a, i),
+                        unwrap<op_case, OpCase::RightConstant>(b, i),
+                        scale_a,
+                        scale_b,
+                        res,
+                        resultDataType,
+                        max_scale))
+                    result_container[i] = res;
+                else
+                    vec_null_map_to[i] = static_cast<UInt8>(1);
+            }
+        }
+        else if constexpr (op_case == OpCase::LeftConstant)
+        {
+            /// Scale the constant once, in the same type the per-row calculation uses (as the vector path does).
+            /// It is then passed with scale 0, which applyScaled treats as "no scaling".
+            const CalcType scaled_a
+                = applyScaled(static_cast<CalcType>(unwrap<op_case, OpCase::LeftConstant>(a, 0)), static_cast<CalcType>(scale_a));
+            for (size_t i = 0; i < size; ++i)
+            {
+                NativeResultType res;
+                if (calculate<CalculateWith256>(
+                        scaled_a,
+                        unwrap<op_case, OpCase::RightConstant>(b, i),
+                        static_cast<NativeResultType>(0),
+                        scale_b,
+                        res,
+                        resultDataType,
+                        max_scale))
+                    result_container[i] = res;
+                else
+                    vec_null_map_to[i] = static_cast<UInt8>(1);
+            }
+        }
+        else if constexpr (op_case == OpCase::RightConstant)
+        {
+            /// Same as above for a constant right operand.
+            const CalcType scaled_b
+                = applyScaled(static_cast<CalcType>(unwrap<op_case, OpCase::RightConstant>(b, 0)), static_cast<CalcType>(scale_b));
+            for (size_t i = 0; i < size; ++i)
+            {
+                NativeResultType res;
+                if (calculate<CalculateWith256>(
+                        unwrap<op_case, OpCase::LeftConstant>(a, i),
+                        scaled_b,
+                        scale_a,
+                        static_cast<NativeResultType>(0),
+                        res,
+                        resultDataType,
+                        max_scale))
+                    result_container[i] = res;
+                else
+                    vec_null_map_to[i] = static_cast<UInt8>(1);
+            }
+        }
+    }
+
+    /// Dispatches one row to calculateImpl in the calculation type chosen by CalcTypeFor.
+    /// ResultNativeType = Int32/64/128/256
+    template <bool CalculateWith256, typename LeftNativeType, typename RightNativeType, typename NativeResultType, typename ResultDataType>
+    static NO_SANITIZE_UNDEFINED bool calculate(
+        const LeftNativeType l,
+        const RightNativeType r,
+        const NativeResultType & scale_left,
+        const NativeResultType & scale_right,
+        NativeResultType & res,
+        const ResultDataType & resultDataType,
+        const size_t & max_scale)
+    {
+        return calculateImpl<CalcTypeFor<CalculateWith256, NativeResultType>>(
+            l, r, scale_left, scale_right, res, resultDataType, max_scale);
+    }
+
+    /// Scales both operands, applies Operation in CalcType, reduces the result from max_scale to the result
+    /// scale and, for Int256/division, checks the result fits the result precision.
+    /// @returns false when Operation::apply fails or the result does not fit; `res` is then not written.
+    template <typename CalcType, typename LeftNativeType, typename RightNativeType, typename NativeResultType, typename ResultDataType>
+    static NO_SANITIZE_UNDEFINED bool calculateImpl(
+        const LeftNativeType & l,
+        const RightNativeType & r,
+        const NativeResultType & scale_left,
+        const NativeResultType & scale_right,
+        NativeResultType & res,
+        const ResultDataType & resultDataType,
+        const size_t & max_scale)
+    {
+        CalcType scaled_l = applyScaled(static_cast<CalcType>(l), static_cast<CalcType>(scale_left));
+        CalcType scaled_r = applyScaled(static_cast<CalcType>(r), static_cast<CalcType>(scale_right));
+
+        CalcType c_res = 0;
+        if (!Operation::template apply<CalcType>(scaled_l, scaled_r, c_res))
+            return false;
+
+        const size_t result_scale = resultDataType.getScale();
+        /// Checked before the unsigned subtraction so an inconsistent scale cannot silently wrap around.
+        chassert(max_scale >= result_scale);
+        const size_t scale_diff = max_scale - result_scale;
+        if (scale_diff)
+        {
+            /// NOTE(review): DecimalDivideImpl uses integer division (truncation toward zero); Spark rounds
+            /// HALF_UP when it reduces the scale of a decimal result — confirm this is the intended behaviour.
+            auto scaled_diff = DecimalUtils::scaleMultiplier<CalcType>(scale_diff);
+            DecimalDivideImpl::apply<CalcType>(c_res, scaled_diff, c_res);
+        }
+
+        // check overflow
+        if constexpr (std::is_same_v<CalcType, Int256> || is_division)
+        {
+            auto max_value = intExp10OfSize<CalcType>(resultDataType.getPrecision());
+            if (c_res <= -max_value || c_res >= max_value)
+                return false;
+        }
+
+        res = static_cast<NativeResultType>(c_res);
+
+        return true;
+    }
+
+    /// Returns the native value of element `i`, or of the single value when `elem` is the constant operand.
+    template <OpCase op_case, OpCase target, class E>
+    static auto unwrap(const E & elem, size_t i)
+    {
+        if constexpr (op_case == target)
+            return elem.value;
+        else
+            return elem[i].value;
+    }
+
+    /// Multiplies `l` by `scale` (ignoring overflow) when scale > 1; otherwise only converts `l`.
+    template <typename NativeType, typename ResultNativeType>
+    static ResultNativeType applyScaled(const NativeType & l, const ResultNativeType & scale)
+    {
+        if (scale > 1)
+            return common::mulIgnoreOverflow(l, scale);
+
+        return static_cast<ResultNativeType>(l);
+    }
+};
+
+
+/// IFunction wrapper: name(left, right, result_type_carrier) -> Nullable(type of the third argument).
+/// The third argument must be a constant (getArgumentsThatAreAlwaysConstant); only its decimal type is used,
+/// as the result type. Rows whose calculation fails (e.g. overflow, division by zero) are returned as NULL.
+template <class Operation, typename Name, OpMode Mode = OpMode::Default>
+class SparkFunctionDecimalBinaryArithmetic final : public IFunction
+{
+public:
+    static constexpr auto name = Name::name;
+
+    static FunctionPtr create(ContextPtr context_) { return std::make_shared<SparkFunctionDecimalBinaryArithmetic>(context_); }
+
+    explicit SparkFunctionDecimalBinaryArithmetic(ContextPtr context_) : context(context_) { }
+
+    String getName() const override { return name; }
+    size_t getNumberOfArguments() const override { return 3; }
+    bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; }
+    bool useDefaultImplementationForConstants() const override { return true; }
+    ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {2}; }
+
+    /// All three arguments must be (non-Nullable) decimals; the result is Nullable(arguments[2].type).
+    DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
+    {
+        if (arguments.size() != 3)
+            throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function '{}' expects 3 arguments", getName());
+
+        if (!isDecimal(arguments[0].type) || !isDecimal(arguments[1].type) || !isDecimal(arguments[2].type))
+            throw Exception(
+                ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
+                "Illegal type {} {} {} of argument of function {}",
+                arguments[0].type->getName(),
+                arguments[1].type->getName(),
+                arguments[2].type->getName(),
+                getName());
+
+        return std::make_shared<DataTypeNullable>(arguments[2].type);
+    }
+
+    /// Resolves the concrete decimal types of both operands and of the result, then delegates to
+    /// SparkDecimalBinaryOperation.
+    ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t) const override
+    {
+        const auto & left_argument = arguments[0];
+        const auto & right_argument = arguments[1];
+
+        const auto * left_generic = left_argument.type.get();
+        const auto * right_generic = right_argument.type.get();
+
+        ColumnPtr res;
+        const bool valid = castBothTypes(
+            left_generic,
+            right_generic,
+            removeNullable(arguments[2].type).get(),
+            [&](const auto & left, const auto & right, const auto & result) {
+                return (res = SparkDecimalBinaryOperation<Operation, Mode>::template executeDecimal(arguments, left, right, result))
+                    != nullptr;
+            });
+
+        if (!valid)
+        {
+            // This is a logical error, because the types should have been checked
+            // by getReturnTypeImpl().
+            throw Exception(
+                ErrorCodes::LOGICAL_ERROR,
+                "Arguments of '{}' have incorrect data types: '{}' of type '{}',"
+                " '{}' of type '{}'",
+                getName(),
+                left_argument.name,
+                left_argument.type->getName(),
+                right_argument.name,
+                right_argument.type->getName());
+        }
+
+        return res;
+    }
+
+private:
+    /// Calls f(left_, right_, result_) with the concrete DataTypeDecimal* of each of the three types.
+    template <typename F>
+    static bool castBothTypes(const IDataType * left, const IDataType * right, const IDataType * result, F && f)
+    {
+        return castType(
+            left,
+            [&](const auto & left_)
+            {
+                return castType(
+                    right,
+                    [&](const auto & right_) { return castType(result, [&](const auto & result_) { return f(left_, right_, result_); }); });
+            });
+    }
+
+    static bool castType(const IDataType * type, auto && f)
+    {
+        using Types = TypeList<DataTypeDecimal32, DataTypeDecimal64, DataTypeDecimal128, DataTypeDecimal256>;
+        return castTypeToEither(Types{}, type, std::forward<decltype(f)>(f));
+    }
+
+    ContextPtr context;
+};
+
+/// Registered function names. Each operation has a default variant and an "Effect" variant (OpMode::Effect).
+struct NameSparkDecimalPlus { static constexpr auto name = "sparkDecimalPlus"; };
+struct NameSparkDecimalPlusEffect { static constexpr auto name = "sparkDecimalPlusEffect"; };
+
+struct NameSparkDecimalMinus { static constexpr auto name = "sparkDecimalMinus"; };
+struct NameSparkDecimalMinusEffect { static constexpr auto name = "sparkDecimalMinusEffect"; };
+
+struct NameSparkDecimalMultiply { static constexpr auto name = "sparkDecimalMultiply"; };
+struct NameSparkDecimalMultiplyEffect { static constexpr auto name = "sparkDecimalMultiplyEffect"; };
+
+struct NameSparkDecimalDivide { static constexpr auto name = "sparkDecimalDivide"; };
+struct NameSparkDecimalDivideEffect { static constexpr auto name = "sparkDecimalDivideEffect"; };
+
+/// NOTE(review): unlike the others these names do not follow the "sparkDecimal*" pattern; verify against the
+/// parser that looks them up before renaming.
+struct NameSparkDecimalModulo { static constexpr auto name = "NameSparkDecimalModulo"; };
+struct NameSparkDecimalModuloEffect { static constexpr auto name = "NameSparkDecimalModuloEffect"; };
+
+
+/// One IFunction per (operation, mode) pair.
+using DecimalPlus = SparkFunctionDecimalBinaryArithmetic<DecimalPlusImpl, NameSparkDecimalPlus, OpMode::Default>;
+using DecimalPlusEffect = SparkFunctionDecimalBinaryArithmetic<DecimalPlusImpl, NameSparkDecimalPlusEffect, OpMode::Effect>;
+
+using DecimalMinus = SparkFunctionDecimalBinaryArithmetic<DecimalMinusImpl, NameSparkDecimalMinus, OpMode::Default>;
+using DecimalMinusEffect = SparkFunctionDecimalBinaryArithmetic<DecimalMinusImpl, NameSparkDecimalMinusEffect, OpMode::Effect>;
+
+using DecimalMultiply = SparkFunctionDecimalBinaryArithmetic<DecimalMultiplyImpl, NameSparkDecimalMultiply, OpMode::Default>;
+using DecimalMultiplyEffect = SparkFunctionDecimalBinaryArithmetic<DecimalMultiplyImpl, NameSparkDecimalMultiplyEffect, OpMode::Effect>;
+
+using DecimalDivide = SparkFunctionDecimalBinaryArithmetic<DecimalDivideImpl, NameSparkDecimalDivide, OpMode::Default>;
+using DecimalDivideEffect = SparkFunctionDecimalBinaryArithmetic<DecimalDivideImpl, NameSparkDecimalDivideEffect, OpMode::Effect>;
+
+using DecimalModulo = SparkFunctionDecimalBinaryArithmetic<DecimalModuloImpl, NameSparkDecimalModulo, OpMode::Default>;
+using DecimalModuloEffect = SparkFunctionDecimalBinaryArithmetic<DecimalModuloImpl, NameSparkDecimalModuloEffect, OpMode::Effect>;
+}
+
+/// Registers the default and "Effect" variant of every Spark decimal arithmetic function.
+REGISTER_FUNCTION(SparkDecimalFunctionArithmetic)
+{
+    factory.registerFunction<DecimalPlus>();
+    factory.registerFunction<DecimalPlusEffect>();
+
+    factory.registerFunction<DecimalMinus>();
+    factory.registerFunction<DecimalMinusEffect>();
+
+    factory.registerFunction<DecimalMultiply>();
+    factory.registerFunction<DecimalMultiplyEffect>();
+
+    factory.registerFunction<DecimalDivide>();
+    factory.registerFunction<DecimalDivideEffect>();
+
+    factory.registerFunction<DecimalModulo>();
+    factory.registerFunction<DecimalModuloEffect>();
+}
+}
diff --git
a/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.h
b/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.h
new file mode 100644
index 0000000000..05e5f4ff9f
--- /dev/null
+++ b/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.h
@@ -0,0 +1,276 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <base/arithmeticOverflow.h>
+#include <type_traits>
+
+namespace local_engine
+{
+
+/// True when both Int256 operands have zero upper limbs (items[2], items[3]), i.e. they are non-negative and
+/// fit into UInt128, so the operation can first be attempted on UInt128.
+/// Assumes the little-endian limb order of wide::integer — TODO confirm for big-endian builds.
+/// `inline` (not `static`) because this is defined in a header.
+inline bool canCastLower(const Int256 & a, const Int256 & b)
+{
+    return a.items[2] == 0 && a.items[3] == 0 && b.items[2] == 0 && b.items[3] == 0;
+}
+
+/// True when both Int128 operands have a zero upper limb (items[1]), i.e. they are non-negative and fit into
+/// UInt64, so the operation can first be attempted on UInt64. Same limb-order assumption as above.
+inline bool canCastLower(const Int128 & a, const Int128 & b)
+{
+    return a.items[1] == 0 && b.items[1] == 0;
+}
+
+struct DecimalPlusImpl
+{
+    /// Computes r = a + b in type A.
+    /// @returns true on success, false if the addition overflowed A.
+    /// For Int128/Int256, when both operands fit into the lower half (canCastLower), the sum is first tried on
+    /// the narrower unsigned type. This uses `if constexpr` instead of in-class explicit specializations
+    /// declared `static`, which are ill-formed (no storage class is allowed on an explicit specialization).
+    template <typename A>
+    static bool apply(A & a, A & b, A & r)
+    {
+        if constexpr (std::is_same_v<A, Int128>)
+        {
+            if (canCastLower(a, b))
+            {
+                UInt64 low_result;
+                if (!common::addOverflow(static_cast<UInt64>(a), static_cast<UInt64>(b), low_result))
+                {
+                    r = static_cast<Int128>(low_result);
+                    return true;
+                }
+            }
+        }
+        else if constexpr (std::is_same_v<A, Int256>)
+        {
+            if (canCastLower(a, b))
+            {
+                UInt128 low_result;
+                if (!common::addOverflow(static_cast<UInt128>(a), static_cast<UInt128>(b), low_result))
+                {
+                    r = static_cast<Int256>(low_result);
+                    return true;
+                }
+            }
+        }
+
+        return !common::addOverflow(a, b, r);
+    }
+
+#if USE_EMBEDDED_COMPILER
+    static constexpr bool compilable = true;
+
+    static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * left, llvm::Value * right, bool)
+    {
+        return left->getType()->isIntegerTy() ? b.CreateAdd(left, right) : b.CreateFAdd(left, right);
+    }
+#endif
+};
+
+struct DecimalMinusImpl
+{
+    /// Computes r = a - b in type A; used for Decimal operations.
+    /// @returns true on success, false if the subtraction overflowed A.
+    /// For Int128/Int256, when both operands fit into the lower half (canCastLower), the difference is first
+    /// tried on the narrower unsigned type; an unsigned underflow (a < b) falls back to the full-width path.
+    template <typename A>
+    static bool apply(A & a, A & b, A & r)
+    {
+        if constexpr (std::is_same_v<A, Int128>)
+        {
+            if (canCastLower(a, b))
+            {
+                UInt64 low_result;
+                if (!common::subOverflow(static_cast<UInt64>(a), static_cast<UInt64>(b), low_result))
+                {
+                    r = static_cast<Int128>(low_result);
+                    return true;
+                }
+            }
+        }
+        else if constexpr (std::is_same_v<A, Int256>)
+        {
+            if (canCastLower(a, b))
+            {
+                UInt128 low_result;
+                if (!common::subOverflow(static_cast<UInt128>(a), static_cast<UInt128>(b), low_result))
+                {
+                    r = static_cast<Int256>(low_result);
+                    return true;
+                }
+            }
+        }
+
+        return !common::subOverflow(a, b, r);
+    }
+
+#if USE_EMBEDDED_COMPILER
+    static constexpr bool compilable = true;
+
+    static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * left, llvm::Value * right, bool)
+    {
+        return left->getType()->isIntegerTy() ? b.CreateSub(left, right) : b.CreateFSub(left, right);
+    }
+#endif
+};
+
+
+struct DecimalMultiplyImpl
+{
+    /// Computes c = a * b in type A; used for Decimal operations.
+    /// @returns true on success, false if the multiplication overflowed A.
+    /// For Int128, when both operands fit into UInt64 (canCastLower), the product is first tried on UInt64.
+    /// (The previous `template <Int128>` declared an unrelated non-type template, so `apply<Int128>` never
+    /// reached this fast path.)
+    template <typename A>
+    static bool apply(A & a, A & b, A & c)
+    {
+        if constexpr (std::is_same_v<A, Int128>)
+        {
+            if (canCastLower(a, b))
+            {
+                UInt64 low_result = 0;
+                if (!common::mulOverflow(static_cast<UInt64>(a), static_cast<UInt64>(b), low_result))
+                {
+                    c = static_cast<Int128>(low_result);
+                    return true;
+                }
+            }
+        }
+
+        return !common::mulOverflow(a, b, c);
+    }
+
+#if USE_EMBEDDED_COMPILER
+    static constexpr bool compilable = true;
+
+    static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * left, llvm::Value * right, bool)
+    {
+        return left->getType()->isIntegerTy() ? b.CreateMul(left, right) : b.CreateFMul(left, right);
+    }
+#endif
+};
+
+struct DecimalDivideImpl
+{
+    /// Computes r = a / b using integer division on the operand values.
+    /// @returns true on success, false if b == 0 (r is not written in that case).
+    template <typename A>
+    static bool apply(A & a, A & b, A & r)
+    {
+        if (b == 0)
+            return false;
+
+        r = a / b;
+        return true;
+    }
+
+    /// Int128 fast path: when canCastLower(a, b) holds (defined elsewhere; presumably both operands are
+    /// non-negative and fit in UInt64 -- TODO confirm), divide in 64-bit arithmetic.
+    template <>
+    static bool apply(Int128 & a, Int128 & b, Int128 & r)
+    {
+        if (b == 0)
+            return false;
+
+        if (canCastLower(a, b))
+        {
+            r = static_cast<Int128>(static_cast<UInt64>(a) / static_cast<UInt64>(b));
+            return true;
+        }
+
+        r = a / b;
+        return true;
+    }
+
+    /// Int256 fast path: divide in UInt128 arithmetic when canCastLower(a, b) holds.
+    template <>
+    static bool apply(Int256 & a, Int256 & b, Int256 & r)
+    {
+        if (b == 0)
+            return false;
+
+        if (canCastLower(a, b))
+        {
+            UInt128 low_result = 0;
+            UInt128 low_a = static_cast<UInt128>(a);
+            UInt128 low_b = static_cast<UInt128>(b);
+            /// Propagate the generic overload's zero-divisor signal instead of silently storing 0.
+            if (!apply(low_a, low_b, low_result))
+                return false;
+
+            r = static_cast<Int256>(low_result);
+            return true;
+        }
+
+        r = a / b;
+        return true;
+    }
+
+#if USE_EMBEDDED_COMPILER
+    /// Not JIT-compiled: apply() reports a zero divisor by returning false, which a bare LLVM division
+    /// instruction cannot express (integer division by zero is undefined behavior in LLVM IR).
+    static constexpr bool compilable = false;
+
+    /// Division (previously this emitted a subtraction by mistake).
+    static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * left, llvm::Value * right, bool is_signed)
+    {
+        if (left->getType()->isIntegerTy())
+            return is_signed ? b.CreateSDiv(left, right) : b.CreateUDiv(left, right);
+        return b.CreateFDiv(left, right);
+    }
+#endif
+};
+
+
+/// Modulo operation for the Spark decimal binary-arithmetic functions.
+struct DecimalModuloImpl
+{
+    /// Computes r = a % b.
+    /// @returns true on success, false if b == 0 (r is not written in that case).
+    template <typename A>
+    static bool apply(A & a, A & b, A & r)
+    {
+        if (b == 0)
+            return false;
+
+        r = a % b;
+        return true;
+    }
+
+#if USE_EMBEDDED_COMPILER
+    /// Not JIT-compiled: apply() reports a zero divisor by returning false, which a bare LLVM remainder
+    /// instruction cannot express (integer remainder by zero is undefined behavior in LLVM IR).
+    static constexpr bool compilable = false;
+
+    /// Remainder (previously this emitted a subtraction by mistake).
+    static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * left, llvm::Value * right, bool is_signed)
+    {
+        if (left->getType()->isIntegerTy())
+            return is_signed ? b.CreateSRem(left, right) : b.CreateURem(left, right);
+        return b.CreateFRem(left, right);
+    }
+#endif
+};
+
+/// Compile-time check: `value` is true exactly when Op1 and Op2 are the same operation type.
+template <typename Op1, typename Op2>
+struct IsSameOperation
+{
+    static constexpr bool value = std::is_same<Op1, Op2>::value;
+};
+
+/// Compile-time classification of an operation tag: when Op is one of the five Decimal*Impl
+/// structs, exactly the matching flag is true; for any other Op all flags are false.
+template <typename Op>
+struct SparkIsOperation
+{
+    static constexpr bool plus = std::is_same_v<Op, DecimalPlusImpl>;
+    static constexpr bool minus = std::is_same_v<Op, DecimalMinusImpl>;
+    static constexpr bool multiply = std::is_same_v<Op, DecimalMultiplyImpl>;
+    static constexpr bool division = std::is_same_v<Op, DecimalDivideImpl>;
+    static constexpr bool modulo = std::is_same_v<Op, DecimalModuloImpl>;
+};
+}
diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h
b/cpp-ch/local-engine/Parser/SerializedPlanParser.h
index c9a48106a9..9e1ffc0dd6 100644
--- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h
+++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h
@@ -135,6 +135,7 @@ public:
IQueryPlanStep * addRollbackFilterHeaderStep(QueryPlanPtr & query_plan,
const Block & input_header);
static std::pair<DataTypePtr, Field> parseLiteral(const
substrait::Expression_Literal & literal);
+ ContextPtr getContext() const { return context; } // exposes the parser's context, e.g. so function parsers can read settings like arithmetic.decimal.mode
std::vector<QueryPlanPtr> extra_plan_holder;
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/arithmetic.cpp
b/cpp-ch/local-engine/Parser/scalar_function_parser/arithmetic.cpp
index 6aba310bf0..f73305df02 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/arithmetic.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/arithmetic.cpp
@@ -22,6 +22,7 @@
#include <Parser/TypeParser.h>
#include <Common/BlockTypeUtils.h>
#include <Common/CHUtil.h>
+#include <Common/GlutenSettings.h>
namespace DB::ErrorCodes
{
@@ -136,8 +137,11 @@ protected:
return toFunctionNode(actions_dag, "checkDecimalOverflowSparkOrNull",
overflow_args);
}
- virtual const DB::ActionsDAG::Node *
- createFunctionNode(DB::ActionsDAG & actions_dag, const String & func_name,
const DB::ActionsDAG::NodeRawConstPtrs & args) const
+ virtual const DB::ActionsDAG::Node * createFunctionNode( // default: emit func_name(args) as-is; decimal-aware subclasses override this
+ DB::ActionsDAG & actions_dag,
+ const String & func_name,
+ const DB::ActionsDAG::NodeRawConstPtrs & args,
+ DataTypePtr result_type) const // result_type is ignored here; overrides pass it to decimal kernels as a constant column
{
return toFunctionNode(actions_dag, func_name, args);
}
@@ -154,33 +158,8 @@ public:
const auto left_type = DB::removeNullable(parsed_args[0]->result_type);
const auto right_type =
DB::removeNullable(parsed_args[1]->result_type);
- const bool converted = isDecimal(left_type) && isDecimal(right_type);
-
- if (converted)
- {
- const DecimalType evalType = getDecimalType(left_type, right_type);
- parsed_args = convertBinaryArithmeticFunDecimalArgs(actions_dag,
parsed_args, evalType, substrait_func);
- }
-
- const auto * func_node = createFunctionNode(actions_dag, ch_func_name,
parsed_args);
-
- if (converted)
- {
- const auto parsed_output_type =
removeNullable(TypeParser::parseType(substrait_func.output_type()));
- assert(isDecimal(parsed_output_type));
- const Int32 parsed_precision =
getDecimalPrecision(*parsed_output_type);
- const Int32 parsed_scale = getDecimalScale(*parsed_output_type);
- func_node = checkDecimalOverflow(actions_dag, func_node,
parsed_precision, parsed_scale);
-#ifndef NDEBUG
- const auto output_type = removeNullable(func_node->result_type);
- const Int32 output_precision = getDecimalPrecision(*output_type);
- const Int32 output_scale = getDecimalScale(*output_type);
- if (output_precision != parsed_precision || output_scale !=
parsed_scale)
- throw Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Function {}
has wrong output type", getName());
-#endif
-
- return func_node;
- }
+ const auto result_type =
removeNullable(TypeParser::parseType(substrait_func.output_type()));
+ const auto * func_node = createFunctionNode(actions_dag, ch_func_name,
parsed_args, result_type);
return convertNodeTypeIfNeeded(substrait_func, func_node, actions_dag);
}
};
@@ -199,6 +178,32 @@ protected:
{
return DecimalType::evalAddSubstractDecimalType(p1, s1, p2, s2);
}
+
+    /// Builds the ActionsDAG node for `left + right`.
+    /// Decimal + decimal goes to the Spark decimal kernel ("sparkDecimalPlus", or "sparkDecimalPlusEffect"
+    /// when the arithmetic.decimal.mode setting equals "EFFECT"), which receives result_type as a third,
+    /// constant argument. Any other operand types use plain "plus".
+    const DB::ActionsDAG::Node * createFunctionNode(
+        DB::ActionsDAG & actions_dag,
+        const String & func_name,
+        const DB::ActionsDAG::NodeRawConstPtrs & new_args,
+        DataTypePtr result_type) const override
+    {
+        const auto * lhs = new_args[0];
+        const auto * rhs = new_args[1];
+
+        const bool both_decimal = isDecimal(removeNullable(lhs->result_type)) && isDecimal(removeNullable(rhs->result_type));
+        if (!both_decimal)
+            return toFunctionNode(actions_dag, "plus", {lhs, rhs});
+
+        /// Constant column of result_type, handed to the kernel as its third argument.
+        const ActionsDAG::Node * type_node = &actions_dag.addColumn(ColumnWithTypeAndName(
+            result_type->createColumnConstWithDefaultValue(1), result_type, getUniqueName(result_type->getName())));
+
+        const auto & settings = plan_parser->getContext()->getSettingsRef();
+        const bool effect_mode = settings.has("arithmetic.decimal.mode") && settingsEqual(settings, "arithmetic.decimal.mode", "EFFECT");
+        const char * kernel_name = effect_mode ? "sparkDecimalPlusEffect" : "sparkDecimalPlus";
+
+        return toFunctionNode(actions_dag, kernel_name, {lhs, rhs, type_node});
+    }
};
class FunctionParserMinus final : public FunctionParserBinaryArithmetic
@@ -215,6 +220,32 @@ protected:
{
return DecimalType::evalAddSubstractDecimalType(p1, s1, p2, s2);
}
+
+    /// Builds the ActionsDAG node for `left - right`.
+    /// Decimal - decimal goes to the Spark decimal kernel ("sparkDecimalMinus", or "sparkDecimalMinusEffect"
+    /// when the arithmetic.decimal.mode setting equals "EFFECT"), which receives result_type as a third,
+    /// constant argument. Any other operand types use plain "minus".
+    const DB::ActionsDAG::Node * createFunctionNode(
+        DB::ActionsDAG & actions_dag,
+        const String & func_name,
+        const DB::ActionsDAG::NodeRawConstPtrs & new_args,
+        DataTypePtr result_type) const override
+    {
+        const auto * lhs = new_args[0];
+        const auto * rhs = new_args[1];
+
+        const bool both_decimal = isDecimal(removeNullable(lhs->result_type)) && isDecimal(removeNullable(rhs->result_type));
+        if (!both_decimal)
+            return toFunctionNode(actions_dag, "minus", {lhs, rhs});
+
+        /// Constant column of result_type, handed to the kernel as its third argument.
+        const ActionsDAG::Node * type_node = &actions_dag.addColumn(ColumnWithTypeAndName(
+            result_type->createColumnConstWithDefaultValue(1), result_type, getUniqueName(result_type->getName())));
+
+        const auto & settings = plan_parser->getContext()->getSettingsRef();
+        const bool effect_mode = settings.has("arithmetic.decimal.mode") && settingsEqual(settings, "arithmetic.decimal.mode", "EFFECT");
+        const char * kernel_name = effect_mode ? "sparkDecimalMinusEffect" : "sparkDecimalMinus";
+
+        return toFunctionNode(actions_dag, kernel_name, {lhs, rhs, type_node});
+    }
};
class FunctionParserMultiply final : public FunctionParserBinaryArithmetic
@@ -230,6 +261,32 @@ protected:
{
return DecimalType::evalMultiplyDecimalType(p1, s1, p2, s2);
}
+
+    /// Builds the ActionsDAG node for `left * right`.
+    /// Decimal * decimal goes to the Spark decimal kernel ("sparkDecimalMultiply", or
+    /// "sparkDecimalMultiplyEffect" when the arithmetic.decimal.mode setting equals "EFFECT"), which
+    /// receives result_type as a third, constant argument. Any other operand types use plain "multiply".
+    const DB::ActionsDAG::Node * createFunctionNode(
+        DB::ActionsDAG & actions_dag,
+        const String & func_name,
+        const DB::ActionsDAG::NodeRawConstPtrs & new_args,
+        DataTypePtr result_type) const override
+    {
+        const auto * lhs = new_args[0];
+        const auto * rhs = new_args[1];
+
+        const bool both_decimal = isDecimal(removeNullable(lhs->result_type)) && isDecimal(removeNullable(rhs->result_type));
+        if (!both_decimal)
+            return toFunctionNode(actions_dag, "multiply", {lhs, rhs});
+
+        /// Constant column of result_type, handed to the kernel as its third argument.
+        const ActionsDAG::Node * type_node = &actions_dag.addColumn(ColumnWithTypeAndName(
+            result_type->createColumnConstWithDefaultValue(1), result_type, getUniqueName(result_type->getName())));
+
+        const auto & settings = plan_parser->getContext()->getSettingsRef();
+        const bool effect_mode = settings.has("arithmetic.decimal.mode") && settingsEqual(settings, "arithmetic.decimal.mode", "EFFECT");
+        const char * kernel_name = effect_mode ? "sparkDecimalMultiplyEffect" : "sparkDecimalMultiply";
+
+        return toFunctionNode(actions_dag, kernel_name, {lhs, rhs, type_node});
+    }
};
class FunctionParserModulo final : public FunctionParserBinaryArithmetic
@@ -245,6 +302,33 @@ protected:
{
return DecimalType::evalModuloDecimalType(p1, s1, p2, s2);
}
+
+    /// Builds the ActionsDAG node for `left % right`.
+    /// If either operand is a decimal, routes to the Spark decimal kernel ("sparkDecimalModulo", or
+    /// "sparkDecimalModuloEffect" when the arithmetic.decimal.mode setting equals "EFFECT"), passing
+    /// result_type as a third, constant argument; otherwise falls back to plain "modulo".
+    const DB::ActionsDAG::Node * createFunctionNode(
+        DB::ActionsDAG & actions_dag,
+        const String & func_name,
+        const DB::ActionsDAG::NodeRawConstPtrs & new_args,
+        DataTypePtr result_type) const override
+    {
+        assert(func_name == name);
+        const auto * left_arg = new_args[0];
+        const auto * right_arg = new_args[1];
+
+        if (isDecimal(removeNullable(left_arg->result_type)) || isDecimal(removeNullable(right_arg->result_type)))
+        {
+            const ActionsDAG::Node * type_node = &actions_dag.addColumn(ColumnWithTypeAndName(
+                result_type->createColumnConstWithDefaultValue(1), result_type, getUniqueName(result_type->getName())));
+
+            const auto & settings = plan_parser->getContext()->getSettingsRef();
+            /// Function names follow the sparkDecimal<Op>[Effect] convention used by the plus/minus/multiply/
+            /// divide parsers. "NameSparkDecimalModulo[Effect]" looks like a C++ name-tag identifier rather than
+            /// a registered function name. NOTE(review): confirm against the registration in
+            /// SparkFunctionDecimalBinaryArithmetic.cpp.
+            auto function_name
+                = settings.has("arithmetic.decimal.mode") && settingsEqual(settings, "arithmetic.decimal.mode", "EFFECT")
+                ? "sparkDecimalModuloEffect"
+                : "sparkDecimalModulo";
+            return toFunctionNode(actions_dag, function_name, {left_arg, right_arg, type_node});
+        }
+
+        return toFunctionNode(actions_dag, "modulo", {left_arg, right_arg});
+    }
};
class FunctionParserDivide final : public FunctionParserBinaryArithmetic
@@ -262,14 +346,28 @@ protected:
}
const DB::ActionsDAG::Node * createFunctionNode(
- DB::ActionsDAG & actions_dag, const String & func_name, const
DB::ActionsDAG::NodeRawConstPtrs & new_args) const override
+ DB::ActionsDAG & actions_dag,
+ const String & func_name,
+ const DB::ActionsDAG::NodeRawConstPtrs & new_args,
+ DataTypePtr result_type) const override // result_type is forwarded to the decimal kernel as a constant column
{
assert(func_name == name);
const auto * left_arg = new_args[0];
const auto * right_arg = new_args[1];
if (isDecimal(removeNullable(left_arg->result_type)) ||
isDecimal(removeNullable(right_arg->result_type)))
- return toFunctionNode(actions_dag, "sparkDivideDecimal",
{left_arg, right_arg});
+ { // at least one decimal operand: use sparkDecimalDivide[Effect], chosen by the arithmetic.decimal.mode setting
+ const ActionsDAG::Node * type_node =
&actions_dag.addColumn(ColumnWithTypeAndName(
+ result_type->createColumnConstWithDefaultValue(1),
result_type, getUniqueName(result_type->getName())));
+
+ const auto & settings =
plan_parser->getContext()->getSettingsRef();
+ auto function_name
+ = settings.has("arithmetic.decimal.mode") &&
settingsEqual(settings, "arithmetic.decimal.mode", "EFFECT")
+ ? "sparkDecimalDivideEffect"
+ : "sparkDecimalDivide";
+ ; // NOTE(review): stray empty statement, safe to delete
+ return toFunctionNode(actions_dag, function_name, {left_arg,
right_arg, type_node});
+ }
return toFunctionNode(actions_dag, "sparkDivide", {left_arg,
right_arg});
}
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
index 0aa6761584..d14a591277 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
@@ -446,6 +446,10 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
LiteralTransformer(m.nullOnOverflow)),
m
)
+ case PromotePrecision(_ @Cast(child, _: DecimalType, _, _))
+ if child.dataType
+ .isInstanceOf[DecimalType] &&
!BackendsApiManager.getSettings.transformCheckOverflow =>
+ replaceWithExpressionTransformer0(child, attributeSeq, expressionsMap)
case _: NormalizeNaNAndZero | _: PromotePrecision | _: TaggingExpression
=>
ChildTransformer(
substraitExprName,
@@ -466,16 +470,12 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
if !BackendsApiManager.getSettings.transformCheckOverflow &&
DecimalArithmeticUtil.isDecimalArithmetic(b) =>
DecimalArithmeticUtil.checkAllowDecimalArithmetic()
- val leftChild =
+ val arithmeticExprName = getAndCheckSubstraitName(b, expressionsMap)
+ val left =
replaceWithExpressionTransformer0(b.left, attributeSeq,
expressionsMap)
- val rightChild =
+ val right =
replaceWithExpressionTransformer0(b.right, attributeSeq,
expressionsMap)
- DecimalArithmeticExpressionTransformer(
- getAndCheckSubstraitName(b, expressionsMap),
- leftChild,
- rightChild,
- decimalType,
- b)
+ DecimalArithmeticExpressionTransformer(arithmeticExprName, left,
right, decimalType, b)
case c: CheckOverflow =>
CheckOverflowTransformer(
substraitExprName,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]