This is an automated email from the ASF dual-hosted git repository.

liuneng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 10a663c2b [GLUTEN-6156][CH]Fix least diff (#6155)
10a663c2b is described below

commit 10a663c2b86c73490cdaee1d94177cb485c9fe31
Author: KevinyhZou <[email protected]>
AuthorDate: Wed Jun 26 17:02:00 2024 +0800

    [GLUTEN-6156][CH]Fix least diff (#6155)
    
    What changes were proposed in this pull request?
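    Map Spark's `least` to a new ClickHouse function `sparkLeast` instead of
    the built-in `least`, so that results for rows containing NULL arguments
    match vanilla Spark (the same fix previously applied to `greatest` via
    `sparkGreatest`). The implementation that was private to `sparkGreatest`
    is moved into a `FunctionGreatestestLeast<kind>` template in
    `FunctionGreatestLeast.h`, which disables ClickHouse's default NULL
    handling and picks the greatest or least argument per row depending on
    `kind`. Both `sparkGreatest` and the new `sparkLeast`
    (`SparkFunctionLeast.cpp`) derive from this template.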
    
    (Fixes: #6156)
    
    How was this patch tested?
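    Tested by unit test: the existing test in
    `GlutenClickHouseTPCHSaltNullParquetSuite` (renamed to
    "GLUTEN-5896: Bug fix greatest/least diff") now also selects
    `least(x1, x2, x3)` on rows containing NULLs and compares the result
    against vanilla Spark.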
---
 .../GlutenClickHouseTPCHSaltNullParquetSuite.scala |  4 +-
 ...unctionGreatest.cpp => FunctionGreatestLeast.h} | 30 +++++++-------
 .../Functions/SparkFunctionGreatest.cpp            | 47 +++-------------------
 .../local-engine/Functions/SparkFunctionLeast.cpp  | 38 +++++++++++++++++
 cpp-ch/local-engine/Parser/SerializedPlanParser.h  |  2 +-
 5 files changed, 62 insertions(+), 59 deletions(-)

diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
index 118f84186..188995f11 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
@@ -2575,12 +2575,12 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends 
GlutenClickHouseTPCHAbstr
     spark.sql("drop table test_tbl_5096")
   }
 
-  test("GLUTEN-5896: Bug fix greatest diff") {
+  test("GLUTEN-5896: Bug fix greatest/least diff") {
     val tbl_create_sql =
       "create table test_tbl_5896(id bigint, x1 int, x2 int, x3 int) using 
parquet"
     val tbl_insert_sql =
       "insert into test_tbl_5896 values(1, 12, NULL, 13), (2, NULL, NULL, 
NULL), (3, 11, NULL, NULL), (4, 10, 9, 8)"
-    val select_sql = "select id, greatest(x1, x2, x3) from test_tbl_5896"
+    val select_sql = "select id, greatest(x1, x2, x3), least(x1, x2, x3) from 
test_tbl_5896"
     spark.sql(tbl_create_sql)
     spark.sql(tbl_insert_sql)
     compareResultsAgainstVanillaSpark(select_sql, true, { _ => })
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionGreatest.cpp 
b/cpp-ch/local-engine/Functions/FunctionGreatestLeast.h
similarity index 75%
copy from cpp-ch/local-engine/Functions/SparkFunctionGreatest.cpp
copy to cpp-ch/local-engine/Functions/FunctionGreatestLeast.h
index 9577d65ec..6930c1d75 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionGreatest.cpp
+++ b/cpp-ch/local-engine/Functions/FunctionGreatestLeast.h
@@ -25,23 +25,20 @@ namespace ErrorCodes
     extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
 }
 }
-
 namespace local_engine
 {
-class SparkFunctionGreatest : public 
DB::FunctionLeastGreatestGeneric<DB::LeastGreatest::Greatest>
+template <DB::LeastGreatest kind>
+class FunctionGreatestestLeast : public DB::FunctionLeastGreatestGeneric<kind>
 {
 public:
-    static constexpr auto name = "sparkGreatest";
-    static DB::FunctionPtr create(DB::ContextPtr) { return 
std::make_shared<SparkFunctionGreatest>(); }
-    SparkFunctionGreatest() = default;
-    ~SparkFunctionGreatest() override = default;
     bool useDefaultImplementationForNulls() const override { return false; }
+    virtual String getName() const = 0;
 
 private:
     DB::DataTypePtr getReturnTypeImpl(const DB::DataTypes & types) const 
override
     {
         if (types.empty())
-            throw 
DB::Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} 
cannot be called without arguments", name);
+            throw 
DB::Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} 
cannot be called without arguments", getName());
         return makeNullable(getLeastSupertype(types));
     }
 
@@ -58,9 +55,18 @@ private:
             size_t best_arg = 0;
             for (size_t arg = 1; arg < num_arguments; ++arg)
             {
-                auto cmp_result = converted_columns[arg]->compareAt(row_num, 
row_num, *converted_columns[best_arg], -1);
-                if (cmp_result > 0)
-                    best_arg = arg;
+                if constexpr (kind == DB::LeastGreatest::Greatest)
+                {
+                    auto cmp_result = 
converted_columns[arg]->compareAt(row_num, row_num, 
*converted_columns[best_arg], -1);
+                    if (cmp_result > 0)
+                        best_arg = arg;
+                }
+                else
+                {
+                    auto cmp_result = 
converted_columns[arg]->compareAt(row_num, row_num, 
*converted_columns[best_arg], 1);
+                    if (cmp_result <  0)
+                        best_arg = arg;
+                }
             }
             result_column->insertFrom(*converted_columns[best_arg], row_num);
         }
@@ -68,8 +74,4 @@ private:
     }
 };
 
-REGISTER_FUNCTION(SparkGreatest)
-{
-    factory.registerFunction<SparkFunctionGreatest>();
-}
 }
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionGreatest.cpp 
b/cpp-ch/local-engine/Functions/SparkFunctionGreatest.cpp
index 9577d65ec..920fe1b9c 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionGreatest.cpp
+++ b/cpp-ch/local-engine/Functions/SparkFunctionGreatest.cpp
@@ -14,58 +14,21 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include <Functions/LeastGreatestGeneric.h>
-#include <DataTypes/getLeastSupertype.h>
-#include <DataTypes/DataTypeNullable.h>
-
-namespace DB
-{
-namespace ErrorCodes
-{
-    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
-}
-}
+#include <Functions/FunctionGreatestLeast.h>
 
 namespace local_engine
 {
-class SparkFunctionGreatest : public 
DB::FunctionLeastGreatestGeneric<DB::LeastGreatest::Greatest>
+class SparkFunctionGreatest : public 
FunctionGreatestestLeast<DB::LeastGreatest::Greatest>
 {
 public:
     static constexpr auto name = "sparkGreatest";
     static DB::FunctionPtr create(DB::ContextPtr) { return 
std::make_shared<SparkFunctionGreatest>(); }
     SparkFunctionGreatest() = default;
     ~SparkFunctionGreatest() override = default;
-    bool useDefaultImplementationForNulls() const override { return false; }
-
-private:
-    DB::DataTypePtr getReturnTypeImpl(const DB::DataTypes & types) const 
override
-    {
-        if (types.empty())
-            throw 
DB::Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} 
cannot be called without arguments", name);
-        return makeNullable(getLeastSupertype(types));
-    }
-
-    DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, 
const DB::DataTypePtr & result_type, size_t input_rows_count) const override
+    String getName() const override
     {
-        size_t num_arguments = arguments.size();
-        DB::Columns converted_columns(num_arguments);
-        for (size_t arg = 0; arg < num_arguments; ++arg)
-            converted_columns[arg] = castColumn(arguments[arg], 
result_type)->convertToFullColumnIfConst();
-        auto result_column = result_type->createColumn();
-        result_column->reserve(input_rows_count);
-        for (size_t row_num = 0; row_num < input_rows_count; ++row_num)
-        {
-            size_t best_arg = 0;
-            for (size_t arg = 1; arg < num_arguments; ++arg)
-            {
-                auto cmp_result = converted_columns[arg]->compareAt(row_num, 
row_num, *converted_columns[best_arg], -1);
-                if (cmp_result > 0)
-                    best_arg = arg;
-            }
-            result_column->insertFrom(*converted_columns[best_arg], row_num);
-        }
-        return result_column;
-    }
+        return name;
+    } 
 };
 
 REGISTER_FUNCTION(SparkGreatest)
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionLeast.cpp 
b/cpp-ch/local-engine/Functions/SparkFunctionLeast.cpp
new file mode 100644
index 000000000..70aafdf07
--- /dev/null
+++ b/cpp-ch/local-engine/Functions/SparkFunctionLeast.cpp
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <Functions/FunctionGreatestLeast.h>
+
+namespace local_engine
+{
+class SparkFunctionLeast : public 
FunctionGreatestestLeast<DB::LeastGreatest::Least>
+{
+public:
+    static constexpr auto name = "sparkLeast";
+    static DB::FunctionPtr create(DB::ContextPtr) { return 
std::make_shared<SparkFunctionLeast>(); }
+    SparkFunctionLeast() = default;
+    ~SparkFunctionLeast() override = default;
+    String getName() const override
+    {
+        return name;
+    } 
+};
+
+REGISTER_FUNCTION(SparkLeast)
+{
+    factory.registerFunction<SparkFunctionLeast>();
+}
+}
diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h 
b/cpp-ch/local-engine/Parser/SerializedPlanParser.h
index 6ce92b558..184065836 100644
--- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h
+++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h
@@ -105,7 +105,7 @@ static const std::map<std::string, std::string> 
SCALAR_FUNCTIONS
        {"sign", "sign"},
        {"radians", "radians"},
        {"greatest", "sparkGreatest"},
-       {"least", "least"},
+       {"least", "sparkLeast"},
        {"shiftleft", "bitShiftLeft"},
        {"shiftright", "bitShiftRight"},
        {"check_overflow", "checkDecimalOverflowSpark"},


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to