This is an automated email from the ASF dual-hosted git repository.
liuneng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 10a663c2b [GLUTEN-6156][CH]Fix least diff (#6155)
10a663c2b is described below
commit 10a663c2b86c73490cdaee1d94177cb485c9fe31
Author: KevinyhZou <[email protected]>
AuthorDate: Wed Jun 26 17:02:00 2024 +0800
[GLUTEN-6156][CH]Fix least diff (#6155)
What changes were proposed in this pull request?
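Add a `sparkLeast` function for the ClickHouse backend and map Spark's `least` to it
instead of ClickHouse's built-in `least`. The implementation previously specific to
`sparkGreatest` is moved into a shared template header (FunctionGreatestLeast.h) used by
both functions.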
(Fixes: #6156)
How was this patch tested?
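Tested by unit test: the existing GLUTEN-5896 case in GlutenClickHouseTPCHSaltNullParquetSuite
now also selects `least(x1, x2, x3)` and compares the result against vanilla Spark.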
---
.../GlutenClickHouseTPCHSaltNullParquetSuite.scala | 4 +-
...unctionGreatest.cpp => FunctionGreatestLeast.h} | 30 +++++++-------
.../Functions/SparkFunctionGreatest.cpp | 47 +++-------------------
.../local-engine/Functions/SparkFunctionLeast.cpp | 38 +++++++++++++++++
cpp-ch/local-engine/Parser/SerializedPlanParser.h | 2 +-
5 files changed, 62 insertions(+), 59 deletions(-)
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
index 118f84186..188995f11 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
@@ -2575,12 +2575,12 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends
GlutenClickHouseTPCHAbstr
spark.sql("drop table test_tbl_5096")
}
- test("GLUTEN-5896: Bug fix greatest diff") {
+ test("GLUTEN-5896: Bug fix greatest/least diff") {
val tbl_create_sql =
"create table test_tbl_5896(id bigint, x1 int, x2 int, x3 int) using
parquet"
val tbl_insert_sql =
"insert into test_tbl_5896 values(1, 12, NULL, 13), (2, NULL, NULL,
NULL), (3, 11, NULL, NULL), (4, 10, 9, 8)"
- val select_sql = "select id, greatest(x1, x2, x3) from test_tbl_5896"
+ val select_sql = "select id, greatest(x1, x2, x3), least(x1, x2, x3) from
test_tbl_5896"
spark.sql(tbl_create_sql)
spark.sql(tbl_insert_sql)
compareResultsAgainstVanillaSpark(select_sql, true, { _ => })
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionGreatest.cpp
b/cpp-ch/local-engine/Functions/FunctionGreatestLeast.h
similarity index 75%
copy from cpp-ch/local-engine/Functions/SparkFunctionGreatest.cpp
copy to cpp-ch/local-engine/Functions/FunctionGreatestLeast.h
index 9577d65ec..6930c1d75 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionGreatest.cpp
+++ b/cpp-ch/local-engine/Functions/FunctionGreatestLeast.h
@@ -25,23 +25,20 @@ namespace ErrorCodes
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
}
-
namespace local_engine
{
-class SparkFunctionGreatest : public
DB::FunctionLeastGreatestGeneric<DB::LeastGreatest::Greatest>
+template <DB::LeastGreatest kind>
+class FunctionGreatestestLeast : public DB::FunctionLeastGreatestGeneric<kind>
{
public:
- static constexpr auto name = "sparkGreatest";
- static DB::FunctionPtr create(DB::ContextPtr) { return
std::make_shared<SparkFunctionGreatest>(); }
- SparkFunctionGreatest() = default;
- ~SparkFunctionGreatest() override = default;
bool useDefaultImplementationForNulls() const override { return false; }
+ virtual String getName() const = 0;
private:
DB::DataTypePtr getReturnTypeImpl(const DB::DataTypes & types) const
override
{
if (types.empty())
- throw
DB::Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {}
cannot be called without arguments", name);
+ throw
DB::Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {}
cannot be called without arguments", getName());
return makeNullable(getLeastSupertype(types));
}
@@ -58,9 +55,18 @@ private:
size_t best_arg = 0;
for (size_t arg = 1; arg < num_arguments; ++arg)
{
- auto cmp_result = converted_columns[arg]->compareAt(row_num,
row_num, *converted_columns[best_arg], -1);
- if (cmp_result > 0)
- best_arg = arg;
+ if constexpr (kind == DB::LeastGreatest::Greatest)
+ {
+ auto cmp_result =
converted_columns[arg]->compareAt(row_num, row_num,
*converted_columns[best_arg], -1);
+ if (cmp_result > 0)
+ best_arg = arg;
+ }
+ else
+ {
+ auto cmp_result =
converted_columns[arg]->compareAt(row_num, row_num,
*converted_columns[best_arg], 1);
+ if (cmp_result < 0)
+ best_arg = arg;
+ }
}
result_column->insertFrom(*converted_columns[best_arg], row_num);
}
@@ -68,8 +74,4 @@ private:
}
};
-REGISTER_FUNCTION(SparkGreatest)
-{
- factory.registerFunction<SparkFunctionGreatest>();
-}
}
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionGreatest.cpp
b/cpp-ch/local-engine/Functions/SparkFunctionGreatest.cpp
index 9577d65ec..920fe1b9c 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionGreatest.cpp
+++ b/cpp-ch/local-engine/Functions/SparkFunctionGreatest.cpp
@@ -14,58 +14,21 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-#include <Functions/LeastGreatestGeneric.h>
-#include <DataTypes/getLeastSupertype.h>
-#include <DataTypes/DataTypeNullable.h>
-
-namespace DB
-{
-namespace ErrorCodes
-{
- extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
-}
-}
+#include <Functions/FunctionGreatestLeast.h>
namespace local_engine
{
-class SparkFunctionGreatest : public
DB::FunctionLeastGreatestGeneric<DB::LeastGreatest::Greatest>
+class SparkFunctionGreatest : public
FunctionGreatestestLeast<DB::LeastGreatest::Greatest>
{
public:
static constexpr auto name = "sparkGreatest";
static DB::FunctionPtr create(DB::ContextPtr) { return
std::make_shared<SparkFunctionGreatest>(); }
SparkFunctionGreatest() = default;
~SparkFunctionGreatest() override = default;
- bool useDefaultImplementationForNulls() const override { return false; }
-
-private:
- DB::DataTypePtr getReturnTypeImpl(const DB::DataTypes & types) const
override
- {
- if (types.empty())
- throw
DB::Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {}
cannot be called without arguments", name);
- return makeNullable(getLeastSupertype(types));
- }
-
- DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments,
const DB::DataTypePtr & result_type, size_t input_rows_count) const override
+ String getName() const override
{
- size_t num_arguments = arguments.size();
- DB::Columns converted_columns(num_arguments);
- for (size_t arg = 0; arg < num_arguments; ++arg)
- converted_columns[arg] = castColumn(arguments[arg],
result_type)->convertToFullColumnIfConst();
- auto result_column = result_type->createColumn();
- result_column->reserve(input_rows_count);
- for (size_t row_num = 0; row_num < input_rows_count; ++row_num)
- {
- size_t best_arg = 0;
- for (size_t arg = 1; arg < num_arguments; ++arg)
- {
- auto cmp_result = converted_columns[arg]->compareAt(row_num,
row_num, *converted_columns[best_arg], -1);
- if (cmp_result > 0)
- best_arg = arg;
- }
- result_column->insertFrom(*converted_columns[best_arg], row_num);
- }
- return result_column;
- }
+ return name;
+ }
};
REGISTER_FUNCTION(SparkGreatest)
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionLeast.cpp
b/cpp-ch/local-engine/Functions/SparkFunctionLeast.cpp
new file mode 100644
index 000000000..70aafdf07
--- /dev/null
+++ b/cpp-ch/local-engine/Functions/SparkFunctionLeast.cpp
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <Functions/FunctionGreatestLeast.h>
+
+namespace local_engine
+{
+class SparkFunctionLeast : public
FunctionGreatestestLeast<DB::LeastGreatest::Least>
+{
+public:
+ static constexpr auto name = "sparkLeast";
+ static DB::FunctionPtr create(DB::ContextPtr) { return
std::make_shared<SparkFunctionLeast>(); }
+ SparkFunctionLeast() = default;
+ ~SparkFunctionLeast() override = default;
+ String getName() const override
+ {
+ return name;
+ }
+};
+
+REGISTER_FUNCTION(SparkLeast)
+{
+ factory.registerFunction<SparkFunctionLeast>();
+}
+}
diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h
b/cpp-ch/local-engine/Parser/SerializedPlanParser.h
index 6ce92b558..184065836 100644
--- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h
+++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h
@@ -105,7 +105,7 @@ static const std::map<std::string, std::string>
SCALAR_FUNCTIONS
{"sign", "sign"},
{"radians", "radians"},
{"greatest", "sparkGreatest"},
- {"least", "least"},
+ {"least", "sparkLeast"},
{"shiftleft", "bitShiftLeft"},
{"shiftright", "bitShiftRight"},
{"check_overflow", "checkDecimalOverflowSpark"},
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]