This is an automated email from the ASF dual-hosted git repository.
changchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 4c52976e4 [GLUTEN-6176][CH] Support aggreate avg return decimal (#6177)
4c52976e4 is described below
commit 4c52976e4fce98e861da210f13a85a74d45f386e
Author: Shuai li <[email protected]>
AuthorDate: Tue Jun 25 10:28:39 2024 +0800
[GLUTEN-6176][CH] Support aggreate avg return decimal (#6177)
* Support aggreate avg return decimal
* update version
* fix rebase
* add ut
---
.../execution/GlutenClickHouseDecimalSuite.scala | 5 +-
.../AggregateFunctionSparkAvg.cpp | 158 +++++++++++++++++++++
cpp-ch/local-engine/Common/CHUtil.cpp | 9 +-
cpp-ch/local-engine/Common/CHUtil.h | 5 +-
cpp-ch/local-engine/Common/GlutenDecimalUtils.h | 108 ++++++++++++++
cpp-ch/local-engine/Parser/RelParser.cpp | 23 ++-
.../scala/org/apache/gluten/GlutenConfig.scala | 8 +-
7 files changed, 303 insertions(+), 13 deletions(-)
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala
index 088487101..7320b7c05 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala
@@ -67,9 +67,9 @@ class GlutenClickHouseDecimalSuite
private val decimalTPCHTables: Seq[(DecimalType, Seq[Int])] = Seq.apply(
(DecimalType.apply(9, 4), Seq()),
// 1: ch decimal avg is float
- (DecimalType.apply(18, 8), Seq(1)),
+ (DecimalType.apply(18, 8), Seq()),
// 1: ch decimal avg is float, 3/10: all value is null and compare with
limit
- (DecimalType.apply(38, 19), Seq(1, 3, 10))
+ (DecimalType.apply(38, 19), Seq(3, 10))
)
private def createDecimalTables(dataType: DecimalType): Unit = {
@@ -337,7 +337,6 @@ class GlutenClickHouseDecimalSuite
allowPrecisionLoss =>
Range
.inclusive(1, 22)
- .filter(_ != 17) // Ignore Q17 which include avg
.foreach {
sql_num =>
{
diff --git
a/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp
b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp
new file mode 100644
index 000000000..5eb3a0b36
--- /dev/null
+++ b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <AggregateFunctions/AggregateFunctionAvg.h>
+#include <AggregateFunctions/AggregateFunctionFactory.h>
+#include <AggregateFunctions/FactoryHelpers.h>
+#include <AggregateFunctions/Helpers.h>
+#include <DataTypes/DataTypeTuple.h>
+
+#include <algorithm>
+
+#include <Common/CHUtil.h>
+#include <Common/GlutenDecimalUtils.h>
+
namespace DB
{
struct Settings;

namespace ErrorCodes
{
/// Raised by createAggregateFunctionSparkAvg when the argument is not a Decimal.
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
}
+
+namespace local_engine
+{
+using namespace DB;
+
+
+DataTypePtr getSparkAvgReturnType(const DataTypePtr & arg_type)
+{
+ const UInt32 precision_value =
std::min<size_t>(getDecimalPrecision(*arg_type) + 4,
DecimalUtils::max_precision<Decimal128>);
+ const auto scale_value = std::min(getDecimalScale(*arg_type) + 4,
precision_value);
+ return createDecimal<DataTypeDecimal>(precision_value, scale_value);
+}
+
+template <typename T>
+requires is_decimal<T>
+class AggregateFunctionSparkAvg final : public AggregateFunctionAvg<T>
+{
+public:
+ using Base = AggregateFunctionAvg<T>;
+
+ explicit AggregateFunctionSparkAvg(const DataTypes & argument_types_,
UInt32 num_scale_, UInt32 round_scale_)
+ : Base(argument_types_, createResultType(argument_types_, num_scale_,
round_scale_), num_scale_)
+ , num_scale(num_scale_)
+ , round_scale(round_scale_)
+ {
+ }
+
+ DataTypePtr createResultType(const DataTypes & argument_types_, UInt32
num_scale_, UInt32 round_scale_)
+ {
+ const DataTypePtr & data_type = argument_types_[0];
+ const UInt32 precision_value =
std::min<size_t>(getDecimalPrecision(*data_type) + 4,
DecimalUtils::max_precision<Decimal128>);
+ const auto scale_value = std::min(num_scale_ + 4, precision_value);
+ return createDecimal<DataTypeDecimal>(precision_value, scale_value);
+ }
+
+ void insertResultInto(AggregateDataPtr __restrict place, IColumn & to,
Arena *) const override
+ {
+ const DataTypePtr & result_type = this->getResultType();
+ auto result_scale = getDecimalScale(*result_type);
+ WhichDataType which(result_type);
+ if (which.isDecimal32())
+ {
+ assert_cast<ColumnDecimal<Decimal32> &>(to).getData().push_back(
+ divideDecimalAndUInt(this->data(place), num_scale,
result_scale, round_scale));
+ }
+ else if (which.isDecimal64())
+ {
+ assert_cast<ColumnDecimal<Decimal64> &>(to).getData().push_back(
+ divideDecimalAndUInt(this->data(place), num_scale, result_scale,
round_scale));
+ }
+ else if (which.isDecimal128())
+ {
+ assert_cast<ColumnDecimal<Decimal128> &>(to).getData().push_back(
+ divideDecimalAndUInt(this->data(place), num_scale,
result_scale, round_scale));
+ }
+ else
+ {
+ assert_cast<ColumnDecimal<Decimal256> &>(to).getData().push_back(
+ divideDecimalAndUInt(this->data(place), num_scale,
result_scale, round_scale));
+ }
+ }
+
+ String getName() const override { return "sparkAvg"; }
+
+private:
+ Int128 NO_SANITIZE_UNDEFINED
+ divideDecimalAndUInt(AvgFraction<AvgFieldType<T>, UInt64> avg, UInt32
num_scale, UInt32 result_scale, UInt32 round_scale) const
+ {
+ auto value = avg.numerator.value;
+ if (result_scale > num_scale)
+ {
+ auto diff =
DecimalUtils::scaleMultiplier<AvgFieldType<T>>(result_scale - num_scale);
+ value = value * diff;
+ }
+ else if (result_scale < num_scale)
+ {
+ auto diff =
DecimalUtils::scaleMultiplier<AvgFieldType<T>>(num_scale - result_scale);
+ value = value / diff;
+ }
+
+ auto result = value / avg.denominator;
+
+ if (round_scale > result_scale)
+ return result;
+
+ auto round_diff =
DecimalUtils::scaleMultiplier<AvgFieldType<T>>(result_scale - round_scale);
+ return (result + round_diff / 2) / round_diff * round_diff;
+ }
+
+private:
+ UInt32 num_scale;
+ UInt32 round_scale;
+};
+
+AggregateFunctionPtr
+createAggregateFunctionSparkAvg(const std::string & name, const DataTypes &
argument_types, const Array & parameters, const Settings * settings)
+{
+ assertNoParameters(name, parameters);
+ assertUnary(name, argument_types);
+
+ AggregateFunctionPtr res;
+ const DataTypePtr & data_type = argument_types[0];
+ if (!isDecimal(data_type))
+ throw Exception(
+ ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument
for aggregate function {}", data_type->getName(), name);
+
+ bool allowPrecisionLoss =
settings->get(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS).get<bool>();
+ const UInt32 p1 = DB::getDecimalPrecision(*data_type);
+ const UInt32 s1 = DB::getDecimalScale(*data_type);
+ auto [p2, s2] = GlutenDecimalUtils::LONG_DECIMAL;
+ auto [_, round_scale] = GlutenDecimalUtils::dividePrecisionScale(p1, s1,
p2, s2, allowPrecisionLoss);
+
+ res.reset(createWithDecimalType<AggregateFunctionSparkAvg>(*data_type,
argument_types, getDecimalScale(*data_type), round_scale));
+ return res;
+}
+
+void registerAggregateFunctionSparkAvg(AggregateFunctionFactory & factory)
+{
+ factory.registerFunction("sparkAvg", createAggregateFunctionSparkAvg);
+}
+
+}
diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp
b/cpp-ch/local-engine/Common/CHUtil.cpp
index ae3f6dbd5..588cc1cb2 100644
--- a/cpp-ch/local-engine/Common/CHUtil.cpp
+++ b/cpp-ch/local-engine/Common/CHUtil.cpp
@@ -626,6 +626,7 @@ void
BackendInitializerUtil::initSettings(std::map<std::string, std::string> & b
settings.set("date_time_input_format", "best_effort");
settings.set(MERGETREE_MERGE_AFTER_INSERT, true);
settings.set(MERGETREE_INSERT_WITHOUT_LOCAL_STORAGE, false);
+ settings.set(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS, true);
for (const auto & [key, value] : backend_conf_map)
{
@@ -665,6 +666,11 @@ void
BackendInitializerUtil::initSettings(std::map<std::string, std::string> & b
settings.set("session_timezone", time_zone_val);
LOG_DEBUG(&Poco::Logger::get("CHUtil"), "Set settings key:{}
value:{}", "session_timezone", time_zone_val);
}
+ else if (key == DECIMAL_OPERATIONS_ALLOW_PREC_LOSS)
+ {
+ settings.set(key, toField(key, value));
+ LOG_DEBUG(&Poco::Logger::get("CHUtil"), "Set settings key:{}
value:{}", key, value);
+ }
}
/// Finally apply some fixed kvs to settings.
@@ -788,6 +794,7 @@ void BackendInitializerUtil::updateNewSettings(const
DB::ContextMutablePtr & con
extern void
registerAggregateFunctionCombinatorPartialMerge(AggregateFunctionCombinatorFactory
&);
extern void registerAggregateFunctionsBloomFilter(AggregateFunctionFactory &);
+extern void registerAggregateFunctionSparkAvg(AggregateFunctionFactory &);
extern void registerFunctions(FunctionFactory &);
void registerAllFunctions()
@@ -797,7 +804,7 @@ void registerAllFunctions()
DB::registerAggregateFunctions();
auto & agg_factory = AggregateFunctionFactory::instance();
registerAggregateFunctionsBloomFilter(agg_factory);
-
+ registerAggregateFunctionSparkAvg(agg_factory);
{
/// register aggregate function combinators from local_engine
auto & factory = AggregateFunctionCombinatorFactory::instance();
diff --git a/cpp-ch/local-engine/Common/CHUtil.h
b/cpp-ch/local-engine/Common/CHUtil.h
index 245d7b3d1..0321d410a 100644
--- a/cpp-ch/local-engine/Common/CHUtil.h
+++ b/cpp-ch/local-engine/Common/CHUtil.h
@@ -37,7 +37,10 @@ namespace local_engine
{
static const String MERGETREE_INSERT_WITHOUT_LOCAL_STORAGE =
"mergetree.insert_without_local_storage";
static const String MERGETREE_MERGE_AFTER_INSERT =
"mergetree.merge_after_insert";
-static const std::unordered_set<String>
BOOL_VALUE_SETTINGS{MERGETREE_MERGE_AFTER_INSERT,
MERGETREE_INSERT_WITHOUT_LOCAL_STORAGE};
+static const std::string DECIMAL_OPERATIONS_ALLOW_PREC_LOSS =
"spark.sql.decimalOperations.allowPrecisionLoss";
+
+static const std::unordered_set<String> BOOL_VALUE_SETTINGS{
+ MERGETREE_MERGE_AFTER_INSERT, MERGETREE_INSERT_WITHOUT_LOCAL_STORAGE,
DECIMAL_OPERATIONS_ALLOW_PREC_LOSS};
static const std::unordered_set<String> LONG_VALUE_SETTINGS{
"optimize.maxfilesize", "optimize.minFileSize",
"mergetree.max_num_part_per_merge_task"};
diff --git a/cpp-ch/local-engine/Common/GlutenDecimalUtils.h
b/cpp-ch/local-engine/Common/GlutenDecimalUtils.h
new file mode 100644
index 000000000..32af66ec5
--- /dev/null
+++ b/cpp-ch/local-engine/Common/GlutenDecimalUtils.h
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
#pragma once

#include <algorithm>
#include <cstddef>
#include <tuple>

namespace local_engine
{

/// Decimal precision/scale rules mirroring Spark's `DecimalType` and `DecimalPrecision`.
/// Used to derive the result types Spark would assign to decimal arithmetic.
/// All precisions and scales are unsigned and returned as (precision, scale) tuples.
class GlutenDecimalUtils
{
public:
    static constexpr size_t MAX_PRECISION = 38;
    static constexpr size_t MAX_SCALE = 38;
    static constexpr auto system_Default = std::tuple(MAX_PRECISION, 18);
    static constexpr auto user_Default = std::tuple(10, 0);
    static constexpr size_t MINIMUM_ADJUSTED_SCALE = 6;

    // The decimal types compatible with other numeric types
    static constexpr auto BOOLEAN_DECIMAL = std::tuple(1, 0);
    static constexpr auto BYTE_DECIMAL = std::tuple(3, 0);
    static constexpr auto SHORT_DECIMAL = std::tuple(5, 0);
    static constexpr auto INT_DECIMAL = std::tuple(10, 0);
    static constexpr auto LONG_DECIMAL = std::tuple(20, 0);
    static constexpr auto FLOAT_DECIMAL = std::tuple(14, 7);
    static constexpr auto DOUBLE_DECIMAL = std::tuple(30, 15);
    static constexpr auto BIGINT_DECIMAL = std::tuple(MAX_PRECISION, 0);

    /// Spark `DecimalType.adjustPrecisionScale`: caps precision at MAX_PRECISION. When capping is needed,
    /// integer digits are kept and the scale shrinks, but never below min(scale, MINIMUM_ADJUSTED_SCALE).
    /// Assumes precision >= scale.
    static std::tuple<size_t, size_t> adjustPrecisionScale(size_t precision, size_t scale)
    {
        // Adjustment only needed when we exceed max precision
        if (precision <= MAX_PRECISION)
            return std::tuple(precision, scale);

        // Spark also special-cases negative scales here (SPARK-24468). Scales are unsigned in this
        // implementation, so that case cannot occur.

        // Precision/scale exceed maximum precision. Result must be adjusted to MAX_PRECISION.
        const size_t int_digits = precision - scale;
        // If original scale is less than MINIMUM_ADJUSTED_SCALE, use original scale value; otherwise
        // preserve at least MINIMUM_ADJUSTED_SCALE fractional digits
        const size_t min_scale_value = std::min(scale, MINIMUM_ADJUSTED_SCALE);
        // Spark computes max(MAX_PRECISION - intDigits, minScaleValue) in signed arithmetic. Clamp the
        // subtraction at 0 so more than MAX_PRECISION integer digits cannot wrap around to a huge scale.
        const size_t available_scale = int_digits < MAX_PRECISION ? MAX_PRECISION - int_digits : 0;
        return std::tuple(MAX_PRECISION, std::max(available_scale, min_scale_value));
    }

    /// Result (precision, scale) of Decimal(p1, s1) / Decimal(p2, s2), following Spark's
    /// `DecimalPrecision` rule for Divide, with or without
    /// spark.sql.decimalOperations.allowPrecisionLoss.
    static std::tuple<size_t, size_t> dividePrecisionScale(size_t p1, size_t s1, size_t p2, size_t s2, bool allowPrecisionLoss)
    {
        if (allowPrecisionLoss)
        {
            // Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1)
            // Scale: max(6, s1 + p2 + 1)
            const size_t int_dig = p1 - s1 + s2;
            const size_t scale = std::max(MINIMUM_ADJUSTED_SCALE, s1 + p2 + 1);
            return adjustPrecisionScale(int_dig + scale, scale);
        }

        size_t int_dig = std::min(MAX_SCALE, p1 - s1 + s2);
        size_t dec_dig = std::min(MAX_SCALE, std::max(MINIMUM_ADJUSTED_SCALE, s1 + p2 + 1));
        // Shrink only when the digits do not fit. Compare before subtracting: with unsigned arithmetic,
        // `(int_dig + dec_dig) - MAX_SCALE` would wrap around whenever the sum is below MAX_SCALE.
        if (int_dig + dec_dig > MAX_SCALE)
        {
            const size_t diff = int_dig + dec_dig - MAX_SCALE;
            dec_dig -= diff / 2 + 1;
            int_dig = MAX_SCALE - dec_dig;
        }
        return std::tuple(int_dig + dec_dig, dec_dig);
    }

    /// Smallest (precision, scale) able to hold both Decimal(p1, s1) and Decimal(p2, s2):
    /// (max(s1, s2) + max(p1 - s1, p2 - s2), max(s1, s2)). Assumes p >= s for both inputs.
    static std::tuple<size_t, size_t> widerDecimalType(const size_t p1, const size_t s1, const size_t p2, const size_t s2)
    {
        const size_t scale = std::max(s1, s2);
        const size_t range = std::max(p1 - s1, p2 - s2);
        return std::tuple(range + scale, scale);
    }
};

}
diff --git a/cpp-ch/local-engine/Parser/RelParser.cpp
b/cpp-ch/local-engine/Parser/RelParser.cpp
index 7fc807827..282339c4d 100644
--- a/cpp-ch/local-engine/Parser/RelParser.cpp
+++ b/cpp-ch/local-engine/Parser/RelParser.cpp
@@ -15,12 +15,16 @@
* limitations under the License.
*/
#include "RelParser.h"
+
#include <string>
+#include <google/protobuf/wrappers.pb.h>
+
#include <AggregateFunctions/AggregateFunctionFactory.h>
+#include <DataTypes/DataTypeAggregateFunction.h>
#include <DataTypes/IDataType.h>
-#include <Common/Exception.h>
-#include <google/protobuf/wrappers.pb.h>
#include <Poco/StringTokenizer.h>
+#include <Common/Exception.h>
+
namespace DB
{
@@ -38,7 +42,20 @@ AggregateFunctionPtr RelParser::getAggregateFunction(
{
auto & factory = AggregateFunctionFactory::instance();
auto action = NullsAction::EMPTY;
- return factory.get(name, action, arg_types, parameters, properties);
+
+ String function_name = name;
+ if (name == "avg" && isDecimal(removeNullable(arg_types[0])))
+ function_name = "sparkAvg";
+ else if (name == "avgPartialMerge")
+ {
+ if (auto agg_func = typeid_cast<const DataTypeAggregateFunction
*>(arg_types[0].get());
+ !agg_func->getArgumentsDataTypes().empty() &&
isDecimal(removeNullable(agg_func->getArgumentsDataTypes()[0])))
+ {
+ function_name = "sparkAvgPartialMerge";
+ }
+ }
+
+ return factory.get(function_name, action, arg_types, parameters,
properties);
}
std::optional<String> RelParser::parseSignatureFunctionName(UInt32
function_ref)
diff --git a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
index 148e8cdc0..4b4e29e7d 100644
--- a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
@@ -718,7 +718,9 @@ object GlutenConfig {
GLUTEN_OFFHEAP_SIZE_IN_BYTES_KEY,
GLUTEN_TASK_OFFHEAP_SIZE_IN_BYTES_KEY,
- GLUTEN_OFFHEAP_ENABLED
+ GLUTEN_OFFHEAP_ENABLED,
+ SESSION_LOCAL_TIMEZONE.key,
+ DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key
)
nativeConfMap.putAll(conf.filter(e => keys.contains(e._1)).asJava)
@@ -735,10 +737,6 @@ object GlutenConfig {
.filter(_._1.startsWith(SPARK_ABFS_ACCOUNT_KEY))
.foreach(entry => nativeConfMap.put(entry._1, entry._2))
- conf
- .filter(_._1.startsWith(SQLConf.SESSION_LOCAL_TIMEZONE.key))
- .foreach(entry => nativeConfMap.put(entry._1, entry._2))
-
// return
nativeConfMap
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]