This is an automated email from the ASF dual-hosted git repository.
changchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 26cde39963 [GLUTEN-8921][GLUTEN-8922][CH] Fix checkDecimalOverflowSparkOrNull and lead function (#8929)
26cde39963 is described below
commit 26cde39963af7179afffa038c201ca6d40f0042c
Author: Wenzheng Liu <[email protected]>
AuthorDate: Sat Mar 8 10:48:32 2025 +0800
[GLUTEN-8921][GLUTEN-8922][CH] Fix checkDecimalOverflowSparkOrNull and lead function (#8929)
---
.../GlutenClickhouseFunctionSuite.scala | 34 ++++++++++++++++++++++
.../SparkFunctionCheckDecimalOverflow.cpp | 21 +++++++++----
.../aggregate_function_parser/LeadLagParser.cpp | 8 ++---
3 files changed, 53 insertions(+), 10 deletions(-)
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseFunctionSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseFunctionSuite.scala
index faccfa105c..bf84c46c71 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseFunctionSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseFunctionSuite.scala
@@ -442,4 +442,38 @@ class GlutenClickhouseFunctionSuite extends GlutenClickHouseTPCHAbstractSuite {
}
}
+ test("GLUTEN-8921: Type mismatch at checkDecimalOverflowSparkOrNull") {
+ compareResultsAgainstVanillaSpark(
+ """
+ |select l_shipdate, avg(l_quantity), count(0) over() COU,
+ |SUM(-1.1) over() SU, AVG(-2) over() AV,
+ |max(-1.1) over() MA, min(-3) over() MI
+ |from lineitem
+ |where l_shipdate <= date'1998-09-02'
+ |group by l_shipdate
+ |order by l_shipdate
+ """.stripMargin,
+ true,
+ { _ => }
+ )
+ }
+
+ test("GLUTEN-8922: Incorrect result in lead function with constant col") {
+ compareResultsAgainstVanillaSpark(
+ """
+ |select l_shipdate,
+ |FIRST_VALUE(-2) over() FI,
+ |LAST_VALUE(-2) over() LA,
+ |lag(-2) over(order by l_shipdate) lag0,
+ |lead(-2) over(order by l_shipdate) lead0
+ |from lineitem
+ |where l_shipdate <= date'1998-09-02'
+ |group by l_shipdate
+ |order by l_shipdate
+ """.stripMargin,
+ true,
+ { _ => }
+ )
+ }
+
}
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionCheckDecimalOverflow.cpp b/cpp-ch/local-engine/Functions/SparkFunctionCheckDecimalOverflow.cpp
index 3b870b22d7..a709f332b2 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionCheckDecimalOverflow.cpp
+++ b/cpp-ch/local-engine/Functions/SparkFunctionCheckDecimalOverflow.cpp
@@ -100,14 +100,20 @@ public:
UInt32 precision = extractArgument(arguments[1]);
UInt32 scale = extractArgument(arguments[2]);
auto return_type = createDecimal<DataTypeDecimal>(precision, scale);
- if constexpr (exception_mode == CheckExceptionMode::Null)
- {
- if (!arguments[0].type->isNullable())
- return std::make_shared<DataTypeNullable>(return_type);
- }
+ if (isReturnTypeNullable(arguments[0]))
+ return std::make_shared<DataTypeNullable>(return_type);
return return_type;
}
+ bool isReturnTypeNullable(const ColumnWithTypeAndName & arg) const
+ {
+ if constexpr (exception_mode == CheckExceptionMode::Null)
+ return true;
+ if (arg.type->isNullable())
+ return true;
+ return false;
+ }
+
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
{
UInt32 to_precision = extractArgument(arguments[1]);
@@ -133,7 +139,10 @@ public:
auto from_scale = getDecimalScale(*src_col.type);
if (from_precision == to_precision && from_scale == to_scale)
{
- dst_col = src_col.column;
+ if (isReturnTypeNullable(arguments[0]))
+ dst_col = makeNullable(src_col.column);
+ else
+ dst_col = src_col.column;
return true;
}
}
diff --git a/cpp-ch/local-engine/Parser/aggregate_function_parser/LeadLagParser.cpp b/cpp-ch/local-engine/Parser/aggregate_function_parser/LeadLagParser.cpp
index 5984f5a27a..5dc6a734b9 100644
--- a/cpp-ch/local-engine/Parser/aggregate_function_parser/LeadLagParser.cpp
+++ b/cpp-ch/local-engine/Parser/aggregate_function_parser/LeadLagParser.cpp
@@ -33,7 +33,7 @@ LeadParser::parseFunctionArguments(const CommonFunctionInfo & func_info, DB::Act
/// The 3rd arg is default value
/// when it is set to null, the 1st arg must be nullable
const auto & arg2 = func_info.arguments[2].value();
- const auto * arg0_col = actions_dag.getInputs()[arg0.selection().direct_reference().struct_field().field()];
+ const auto * arg0_col = parseExpression(actions_dag, arg0);
auto arg0_col_name = arg0_col->result_name;
auto arg0_col_type= arg0_col->result_type;
const DB::ActionsDAG::Node * node = nullptr;
@@ -41,7 +41,7 @@ LeadParser::parseFunctionArguments(const CommonFunctionInfo & func_info, DB::Act
{
node = ActionsDAGUtil::convertNodeType(
actions_dag,
- &actions_dag.findInOutputs(arg0_col_name),
+ arg0_col,
DB::makeNullable(arg0_col_type),
arg0_col_name);
actions_dag.addOrReplaceInOutputs(*node);
@@ -76,7 +76,7 @@ LagParser::parseFunctionArguments(const CommonFunctionInfo & func_info, DB::Acti
/// The 3rd arg is default value
/// when it is set to null, the 1st arg must be nullable
const auto & arg2 = func_info.arguments[2].value();
- const auto * arg0_col = actions_dag.getInputs()[arg0.selection().direct_reference().struct_field().field()];
+ const auto * arg0_col = parseExpression(actions_dag, arg0);
auto arg0_col_name = arg0_col->result_name;
auto arg0_col_type = arg0_col->result_type;
const DB::ActionsDAG::Node * node = nullptr;
@@ -84,7 +84,7 @@ LagParser::parseFunctionArguments(const CommonFunctionInfo & func_info, DB::Acti
{
node = ActionsDAGUtil::convertNodeType(
actions_dag,
- &actions_dag.findInOutputs(arg0_col_name),
+ arg0_col,
makeNullable(arg0_col_type),
arg0_col_name);
actions_dag.addOrReplaceInOutputs(*node);
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]