This is an automated email from the ASF dual-hosted git repository.

changchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 4c52976e4 [GLUTEN-6176][CH] Support aggregate avg return decimal (#6177)
4c52976e4 is described below

commit 4c52976e4fce98e861da210f13a85a74d45f386e
Author: Shuai li <[email protected]>
AuthorDate: Tue Jun 25 10:28:39 2024 +0800

    [GLUTEN-6176][CH] Support aggregate avg return decimal (#6177)
    
    * Support aggregate avg return decimal
    
    * update version
    
    * fix rebase
    
    * add ut
---
 .../execution/GlutenClickHouseDecimalSuite.scala   |   5 +-
 .../AggregateFunctionSparkAvg.cpp                  | 158 +++++++++++++++++++++
 cpp-ch/local-engine/Common/CHUtil.cpp              |   9 +-
 cpp-ch/local-engine/Common/CHUtil.h                |   5 +-
 cpp-ch/local-engine/Common/GlutenDecimalUtils.h    | 108 ++++++++++++++
 cpp-ch/local-engine/Parser/RelParser.cpp           |  23 ++-
 .../scala/org/apache/gluten/GlutenConfig.scala     |   8 +-
 7 files changed, 303 insertions(+), 13 deletions(-)

diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala
index 088487101..7320b7c05 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala
@@ -67,9 +67,9 @@ class GlutenClickHouseDecimalSuite
   private val decimalTPCHTables: Seq[(DecimalType, Seq[Int])] = Seq.apply(
     (DecimalType.apply(9, 4), Seq()),
     // 1: ch decimal avg is float
-    (DecimalType.apply(18, 8), Seq(1)),
+    (DecimalType.apply(18, 8), Seq()),
     // 1: ch decimal avg is float, 3/10: all value is null and compare with 
limit
-    (DecimalType.apply(38, 19), Seq(1, 3, 10))
+    (DecimalType.apply(38, 19), Seq(3, 10))
   )
 
   private def createDecimalTables(dataType: DecimalType): Unit = {
@@ -337,7 +337,6 @@ class GlutenClickHouseDecimalSuite
     allowPrecisionLoss =>
       Range
         .inclusive(1, 22)
-        .filter(_ != 17) // Ignore Q17 which include avg
         .foreach {
           sql_num =>
             {
diff --git 
a/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp 
b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp
new file mode 100644
index 000000000..5eb3a0b36
--- /dev/null
+++ b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <AggregateFunctions/AggregateFunctionAvg.h>
+#include <AggregateFunctions/AggregateFunctionFactory.h>
+#include <AggregateFunctions/FactoryHelpers.h>
+#include <AggregateFunctions/Helpers.h>
+#include <DataTypes/DataTypeTuple.h>
+
+#include <algorithm>
+
+#include <Common/CHUtil.h>
+#include <Common/GlutenDecimalUtils.h>
+
+namespace DB
+{
+struct Settings;
+
+namespace ErrorCodes
+{
+
+}
+}
+
+namespace local_engine
+{
+using namespace DB;
+
+
+DataTypePtr getSparkAvgReturnType(const DataTypePtr & arg_type)
+{
+    const UInt32 precision_value = 
std::min<size_t>(getDecimalPrecision(*arg_type) + 4, 
DecimalUtils::max_precision<Decimal128>);
+    const auto scale_value = std::min(getDecimalScale(*arg_type) + 4, 
precision_value);
+    return createDecimal<DataTypeDecimal>(precision_value, scale_value);
+}
+
+template <typename T>
+requires is_decimal<T>
+class AggregateFunctionSparkAvg final : public AggregateFunctionAvg<T>
+{
+public:
+    using Base = AggregateFunctionAvg<T>;
+
+    explicit AggregateFunctionSparkAvg(const DataTypes & argument_types_, 
UInt32 num_scale_, UInt32 round_scale_)
+        : Base(argument_types_, createResultType(argument_types_, num_scale_, 
round_scale_), num_scale_)
+        , num_scale(num_scale_)
+        , round_scale(round_scale_)
+    {
+    }
+
+    DataTypePtr createResultType(const DataTypes & argument_types_, UInt32 
num_scale_, UInt32 round_scale_)
+    {
+        const DataTypePtr & data_type = argument_types_[0];
+        const UInt32 precision_value = 
std::min<size_t>(getDecimalPrecision(*data_type) + 4, 
DecimalUtils::max_precision<Decimal128>);
+        const auto scale_value = std::min(num_scale_ + 4, precision_value);
+        return createDecimal<DataTypeDecimal>(precision_value, scale_value);
+    }
+
+    void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, 
Arena *) const override
+    {
+        const DataTypePtr & result_type = this->getResultType();
+        auto result_scale = getDecimalScale(*result_type);
+        WhichDataType which(result_type);
+        if (which.isDecimal32())
+        {
+            assert_cast<ColumnDecimal<Decimal32> &>(to).getData().push_back(
+                divideDecimalAndUInt(this->data(place), num_scale, 
result_scale, round_scale));
+        }
+        else if (which.isDecimal64())
+        {
+            assert_cast<ColumnDecimal<Decimal64> &>(to).getData().push_back(
+          divideDecimalAndUInt(this->data(place), num_scale, result_scale, 
round_scale));
+        }
+        else if (which.isDecimal128())
+        {
+            assert_cast<ColumnDecimal<Decimal128> &>(to).getData().push_back(
+                divideDecimalAndUInt(this->data(place), num_scale, 
result_scale, round_scale));
+        }
+        else
+        {
+            assert_cast<ColumnDecimal<Decimal256> &>(to).getData().push_back(
+                divideDecimalAndUInt(this->data(place), num_scale, 
result_scale, round_scale));
+        }
+    }
+
+    String getName() const override { return "sparkAvg"; }
+
+private:
+    Int128 NO_SANITIZE_UNDEFINED
+    divideDecimalAndUInt(AvgFraction<AvgFieldType<T>, UInt64> avg, UInt32 
num_scale, UInt32 result_scale, UInt32 round_scale) const
+    {
+        auto value = avg.numerator.value;
+        if (result_scale > num_scale)
+        {
+            auto diff = 
DecimalUtils::scaleMultiplier<AvgFieldType<T>>(result_scale - num_scale);
+            value = value * diff;
+        }
+        else if (result_scale < num_scale)
+        {
+            auto diff = 
DecimalUtils::scaleMultiplier<AvgFieldType<T>>(num_scale - result_scale);
+            value = value / diff;
+        }
+
+        auto result = value / avg.denominator;
+
+        if (round_scale > result_scale)
+            return result;
+
+        auto round_diff = 
DecimalUtils::scaleMultiplier<AvgFieldType<T>>(result_scale - round_scale);
+        return (result + round_diff / 2) / round_diff * round_diff;
+    }
+
+private:
+    UInt32 num_scale;
+    UInt32 round_scale;
+};
+
+AggregateFunctionPtr
+createAggregateFunctionSparkAvg(const std::string & name, const DataTypes & 
argument_types, const Array & parameters, const Settings * settings)
+{
+    assertNoParameters(name, parameters);
+    assertUnary(name, argument_types);
+
+    AggregateFunctionPtr res;
+    const DataTypePtr & data_type = argument_types[0];
+    if (!isDecimal(data_type))
+        throw Exception(
+            ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument 
for aggregate function {}", data_type->getName(), name);
+
+    bool allowPrecisionLoss = 
settings->get(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS).get<bool>();
+    const UInt32 p1 = DB::getDecimalPrecision(*data_type);
+    const UInt32 s1 = DB::getDecimalScale(*data_type);
+    auto [p2, s2] = GlutenDecimalUtils::LONG_DECIMAL;
+    auto [_, round_scale] = GlutenDecimalUtils::dividePrecisionScale(p1, s1, 
p2, s2, allowPrecisionLoss);
+
+    res.reset(createWithDecimalType<AggregateFunctionSparkAvg>(*data_type, 
argument_types, getDecimalScale(*data_type), round_scale));
+    return res;
+}
+
+void registerAggregateFunctionSparkAvg(AggregateFunctionFactory & factory)
+{
+    factory.registerFunction("sparkAvg", createAggregateFunctionSparkAvg);
+}
+
+}
diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp 
b/cpp-ch/local-engine/Common/CHUtil.cpp
index ae3f6dbd5..588cc1cb2 100644
--- a/cpp-ch/local-engine/Common/CHUtil.cpp
+++ b/cpp-ch/local-engine/Common/CHUtil.cpp
@@ -626,6 +626,7 @@ void 
BackendInitializerUtil::initSettings(std::map<std::string, std::string> & b
     settings.set("date_time_input_format", "best_effort");
     settings.set(MERGETREE_MERGE_AFTER_INSERT, true);
     settings.set(MERGETREE_INSERT_WITHOUT_LOCAL_STORAGE, false);
+    settings.set(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS, true);
 
     for (const auto & [key, value] : backend_conf_map)
     {
@@ -665,6 +666,11 @@ void 
BackendInitializerUtil::initSettings(std::map<std::string, std::string> & b
             settings.set("session_timezone", time_zone_val);
             LOG_DEBUG(&Poco::Logger::get("CHUtil"), "Set settings key:{} 
value:{}", "session_timezone", time_zone_val);
         }
+        else if (key == DECIMAL_OPERATIONS_ALLOW_PREC_LOSS)
+        {
+            settings.set(key, toField(key, value));
+            LOG_DEBUG(&Poco::Logger::get("CHUtil"), "Set settings key:{} 
value:{}", key, value);
+        }
     }
 
     /// Finally apply some fixed kvs to settings.
@@ -788,6 +794,7 @@ void BackendInitializerUtil::updateNewSettings(const 
DB::ContextMutablePtr & con
 
 extern void 
registerAggregateFunctionCombinatorPartialMerge(AggregateFunctionCombinatorFactory
 &);
 extern void registerAggregateFunctionsBloomFilter(AggregateFunctionFactory &);
+extern void registerAggregateFunctionSparkAvg(AggregateFunctionFactory &);
 extern void registerFunctions(FunctionFactory &);
 
 void registerAllFunctions()
@@ -797,7 +804,7 @@ void registerAllFunctions()
     DB::registerAggregateFunctions();
     auto & agg_factory = AggregateFunctionFactory::instance();
     registerAggregateFunctionsBloomFilter(agg_factory);
-
+    registerAggregateFunctionSparkAvg(agg_factory);
     {
         /// register aggregate function combinators from local_engine
         auto & factory = AggregateFunctionCombinatorFactory::instance();
diff --git a/cpp-ch/local-engine/Common/CHUtil.h 
b/cpp-ch/local-engine/Common/CHUtil.h
index 245d7b3d1..0321d410a 100644
--- a/cpp-ch/local-engine/Common/CHUtil.h
+++ b/cpp-ch/local-engine/Common/CHUtil.h
@@ -37,7 +37,10 @@ namespace local_engine
 {
 static const String MERGETREE_INSERT_WITHOUT_LOCAL_STORAGE = 
"mergetree.insert_without_local_storage";
 static const String MERGETREE_MERGE_AFTER_INSERT = 
"mergetree.merge_after_insert";
-static const std::unordered_set<String> 
BOOL_VALUE_SETTINGS{MERGETREE_MERGE_AFTER_INSERT, 
MERGETREE_INSERT_WITHOUT_LOCAL_STORAGE};
+static const std::string DECIMAL_OPERATIONS_ALLOW_PREC_LOSS = 
"spark.sql.decimalOperations.allowPrecisionLoss";
+
+static const std::unordered_set<String> BOOL_VALUE_SETTINGS{
+    MERGETREE_MERGE_AFTER_INSERT, MERGETREE_INSERT_WITHOUT_LOCAL_STORAGE, 
DECIMAL_OPERATIONS_ALLOW_PREC_LOSS};
 static const std::unordered_set<String> LONG_VALUE_SETTINGS{
     "optimize.maxfilesize", "optimize.minFileSize", 
"mergetree.max_num_part_per_merge_task"};
 
diff --git a/cpp-ch/local-engine/Common/GlutenDecimalUtils.h 
b/cpp-ch/local-engine/Common/GlutenDecimalUtils.h
new file mode 100644
index 000000000..32af66ec5
--- /dev/null
+++ b/cpp-ch/local-engine/Common/GlutenDecimalUtils.h
@@ -0,0 +1,108 @@
+/*
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+
+namespace local_engine
+{
+
+class GlutenDecimalUtils
+{
+public:
+    static constexpr size_t MAX_PRECISION = 38;
+    static constexpr size_t MAX_SCALE = 38;
+    static constexpr auto system_Default = std::tuple(MAX_PRECISION, 18);
+    static constexpr auto user_Default = std::tuple(10, 0);
+    static constexpr size_t MINIMUM_ADJUSTED_SCALE = 6;
+
+    // The decimal types compatible with other numeric types
+    static constexpr auto BOOLEAN_DECIMAL = std::tuple(1, 0);
+    static constexpr auto BYTE_DECIMAL = std::tuple(3, 0);
+    static constexpr auto SHORT_DECIMAL = std::tuple(5, 0);
+    static constexpr auto INT_DECIMAL = std::tuple(10, 0);
+    static constexpr auto LONG_DECIMAL = std::tuple(20, 0);
+    static constexpr auto FLOAT_DECIMAL = std::tuple(14, 7);
+    static constexpr auto DOUBLE_DECIMAL = std::tuple(30, 15);
+    static constexpr auto BIGINT_DECIMAL = std::tuple(MAX_PRECISION, 0);
+
+    static std::tuple<size_t, size_t> adjustPrecisionScale(size_t precision, 
size_t scale)
+    {
+        if (precision <= MAX_PRECISION)
+        {
+            // Adjustment only needed when we exceed max precision
+            return std::tuple(precision, scale);
+        }
+        else if (scale < 0)
+        {
+            // Decimal can have negative scale (SPARK-24468). In this case, we 
cannot allow a precision
+            // loss since we would cause a loss of digits in the integer part.
+            // In this case, we are likely to meet an overflow.
+            return std::tuple(GlutenDecimalUtils::MAX_PRECISION, scale);
+        }
+        else
+        {
+            // Precision/scale exceed maximum precision. Result must be 
adjusted to MAX_PRECISION.
+            auto intDigits = precision - scale;
+            // If original scale is less than MINIMUM_ADJUSTED_SCALE, use 
original scale value; otherwise
+            // preserve at least MINIMUM_ADJUSTED_SCALE fractional digits
+            auto minScaleValue = std::min(scale, 
GlutenDecimalUtils::MINIMUM_ADJUSTED_SCALE);
+            // The resulting scale is the maximum between what is available 
without causing a loss of
+            // digits for the integer part of the decimal and the minimum 
guaranteed scale, which is
+            // computed above
+            auto adjustedScale = std::max(GlutenDecimalUtils::MAX_PRECISION - 
intDigits, minScaleValue);
+
+            return std::tuple(GlutenDecimalUtils::MAX_PRECISION, 
adjustedScale);
+        }
+    }
+
+    static std::tuple<size_t, size_t> dividePrecisionScale(size_t p1, size_t 
s1, size_t p2, size_t s2, bool allowPrecisionLoss)
+    {
+        if (allowPrecisionLoss)
+        {
+            // Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1)
+            // Scale: max(6, s1 + p2 + 1)
+            const size_t intDig = p1 - s1 + s2;
+            const size_t scale = std::max(MINIMUM_ADJUSTED_SCALE, s1 + p2 + 1);
+            const size_t precision = intDig + scale;
+            return adjustPrecisionScale(precision, scale);
+        }
+        else
+        {
+            auto intDig = std::min(MAX_SCALE, p1 - s1 + s2);
+            auto decDig = std::min(MAX_SCALE, std::max(static_cast<size_t>(6), 
s1 + p2 + 1));
+            auto diff = (intDig + decDig) - MAX_SCALE;
+            if (diff > 0)
+            {
+                decDig -= diff / 2 + 1;
+                intDig = MAX_SCALE - decDig;
+            }
+            return std::tuple(intDig + decDig, decDig);
+        }
+    }
+
+    static std::tuple<size_t, size_t> widerDecimalType(const size_t p1, const 
size_t s1, const size_t p2, const size_t s2)
+    {
+        // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)
+        auto scale = std::max(s1, s2);
+        auto range = std::max(p1 - s1, p2 - s2);
+        return std::tuple(range + scale, scale);
+    }
+
+};
+
+}
diff --git a/cpp-ch/local-engine/Parser/RelParser.cpp 
b/cpp-ch/local-engine/Parser/RelParser.cpp
index 7fc807827..282339c4d 100644
--- a/cpp-ch/local-engine/Parser/RelParser.cpp
+++ b/cpp-ch/local-engine/Parser/RelParser.cpp
@@ -15,12 +15,16 @@
  * limitations under the License.
  */
 #include "RelParser.h"
+
 #include <string>
+#include <google/protobuf/wrappers.pb.h>
+
 #include <AggregateFunctions/AggregateFunctionFactory.h>
+#include <DataTypes/DataTypeAggregateFunction.h>
 #include <DataTypes/IDataType.h>
-#include <Common/Exception.h>
-#include <google/protobuf/wrappers.pb.h>
 #include <Poco/StringTokenizer.h>
+#include <Common/Exception.h>
+
 
 namespace DB
 {
@@ -38,7 +42,20 @@ AggregateFunctionPtr RelParser::getAggregateFunction(
 {
     auto & factory = AggregateFunctionFactory::instance();
     auto action = NullsAction::EMPTY;
-    return factory.get(name, action, arg_types, parameters, properties);
+
+    String function_name = name;
+    if (name == "avg" && isDecimal(removeNullable(arg_types[0])))
+        function_name = "sparkAvg";
+    else if (name == "avgPartialMerge")
+    {
+        if (auto agg_func = typeid_cast<const DataTypeAggregateFunction 
*>(arg_types[0].get());
+            !agg_func->getArgumentsDataTypes().empty() && 
isDecimal(removeNullable(agg_func->getArgumentsDataTypes()[0])))
+        {
+            function_name = "sparkAvgPartialMerge";
+        }
+    }
+
+    return factory.get(function_name, action, arg_types, parameters, 
properties);
 }
 
 std::optional<String> RelParser::parseSignatureFunctionName(UInt32 
function_ref)
diff --git a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala 
b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
index 148e8cdc0..4b4e29e7d 100644
--- a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
@@ -718,7 +718,9 @@ object GlutenConfig {
 
       GLUTEN_OFFHEAP_SIZE_IN_BYTES_KEY,
       GLUTEN_TASK_OFFHEAP_SIZE_IN_BYTES_KEY,
-      GLUTEN_OFFHEAP_ENABLED
+      GLUTEN_OFFHEAP_ENABLED,
+      SESSION_LOCAL_TIMEZONE.key,
+      DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key
     )
     nativeConfMap.putAll(conf.filter(e => keys.contains(e._1)).asJava)
 
@@ -735,10 +737,6 @@ object GlutenConfig {
       .filter(_._1.startsWith(SPARK_ABFS_ACCOUNT_KEY))
       .foreach(entry => nativeConfMap.put(entry._1, entry._2))
 
-    conf
-      .filter(_._1.startsWith(SQLConf.SESSION_LOCAL_TIMEZONE.key))
-      .foreach(entry => nativeConfMap.put(entry._1, entry._2))
-
     // return
     nativeConfMap
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to