This is an automated email from the ASF dual-hosted git repository.

changchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 26cde39963 [GLUTEN-8921][GLUTEN-8922][CH] Fix checkDecimalOverflowSparkOrNull and lead function (#8929)
26cde39963 is described below

commit 26cde39963af7179afffa038c201ca6d40f0042c
Author: Wenzheng Liu <[email protected]>
AuthorDate: Sat Mar 8 10:48:32 2025 +0800

    [GLUTEN-8921][GLUTEN-8922][CH] Fix checkDecimalOverflowSparkOrNull and lead function (#8929)
---
 .../GlutenClickhouseFunctionSuite.scala            | 34 ++++++++++++++++++++++
 .../SparkFunctionCheckDecimalOverflow.cpp          | 21 +++++++++----
 .../aggregate_function_parser/LeadLagParser.cpp    |  8 ++---
 3 files changed, 53 insertions(+), 10 deletions(-)

diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseFunctionSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseFunctionSuite.scala
index faccfa105c..bf84c46c71 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseFunctionSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseFunctionSuite.scala
@@ -442,4 +442,38 @@ class GlutenClickhouseFunctionSuite extends 
GlutenClickHouseTPCHAbstractSuite {
     }
   }
 
+  test("GLUTEN-8921: Type mismatch at checkDecimalOverflowSparkOrNull") {
+    compareResultsAgainstVanillaSpark(
+      """
+        |select l_shipdate, avg(l_quantity), count(0) over() COU,
+        |SUM(-1.1) over() SU, AVG(-2) over() AV,
+        |max(-1.1) over() MA, min(-3) over() MI
+        |from lineitem
+        |where l_shipdate <= date'1998-09-02'
+        |group by l_shipdate
+        |order by l_shipdate
+      """.stripMargin,
+      true,
+      { _ => }
+    )
+  }
+
+  test("GLUTEN-8922: Incorrect result in lead function with constant col") {
+    compareResultsAgainstVanillaSpark(
+      """
+        |select l_shipdate,
+        |FIRST_VALUE(-2) over() FI,
+        |LAST_VALUE(-2) over() LA,
+        |lag(-2) over(order by l_shipdate) lag0,
+        |lead(-2) over(order by l_shipdate) lead0
+        |from lineitem
+        |where l_shipdate <= date'1998-09-02'
+        |group by l_shipdate
+        |order by l_shipdate
+      """.stripMargin,
+      true,
+      { _ => }
+    )
+  }
+
 }
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionCheckDecimalOverflow.cpp b/cpp-ch/local-engine/Functions/SparkFunctionCheckDecimalOverflow.cpp
index 3b870b22d7..a709f332b2 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionCheckDecimalOverflow.cpp
+++ b/cpp-ch/local-engine/Functions/SparkFunctionCheckDecimalOverflow.cpp
@@ -100,14 +100,20 @@ public:
         UInt32 precision = extractArgument(arguments[1]);
         UInt32 scale = extractArgument(arguments[2]);
         auto return_type = createDecimal<DataTypeDecimal>(precision, scale);
-        if constexpr (exception_mode == CheckExceptionMode::Null)
-        {
-            if (!arguments[0].type->isNullable())
-                return std::make_shared<DataTypeNullable>(return_type);
-        }
+        if (isReturnTypeNullable(arguments[0]))
+            return std::make_shared<DataTypeNullable>(return_type);
         return return_type;
     }
 
+    bool isReturnTypeNullable(const ColumnWithTypeAndName & arg) const
+    {
+        if constexpr (exception_mode == CheckExceptionMode::Null)
+            return true;
+        if (arg.type->isNullable())
+            return true;
+        return false;
+    }
+
     ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const 
DataTypePtr &, size_t input_rows_count) const override
     {
         UInt32 to_precision = extractArgument(arguments[1]);
@@ -133,7 +139,10 @@ public:
                     auto from_scale = getDecimalScale(*src_col.type);
                     if (from_precision == to_precision && from_scale == 
to_scale)
                     {
-                        dst_col = src_col.column;
+                        if (isReturnTypeNullable(arguments[0]))
+                            dst_col = makeNullable(src_col.column);
+                        else
+                            dst_col = src_col.column;
                         return true;
                     }
                 }
diff --git a/cpp-ch/local-engine/Parser/aggregate_function_parser/LeadLagParser.cpp b/cpp-ch/local-engine/Parser/aggregate_function_parser/LeadLagParser.cpp
index 5984f5a27a..5dc6a734b9 100644
--- a/cpp-ch/local-engine/Parser/aggregate_function_parser/LeadLagParser.cpp
+++ b/cpp-ch/local-engine/Parser/aggregate_function_parser/LeadLagParser.cpp
@@ -33,7 +33,7 @@ LeadParser::parseFunctionArguments(const CommonFunctionInfo & 
func_info, DB::Act
     /// The 3rd arg is default value
     /// when it is set to null, the 1st arg must be nullable
     const auto & arg2 = func_info.arguments[2].value();
-    const auto * arg0_col = 
actions_dag.getInputs()[arg0.selection().direct_reference().struct_field().field()];
+    const auto * arg0_col = parseExpression(actions_dag, arg0);
     auto arg0_col_name = arg0_col->result_name;
     auto arg0_col_type= arg0_col->result_type;
     const DB::ActionsDAG::Node * node = nullptr;
@@ -41,7 +41,7 @@ LeadParser::parseFunctionArguments(const CommonFunctionInfo & 
func_info, DB::Act
     {
         node = ActionsDAGUtil::convertNodeType(
             actions_dag,
-            &actions_dag.findInOutputs(arg0_col_name),
+            arg0_col,
             DB::makeNullable(arg0_col_type),
             arg0_col_name);
         actions_dag.addOrReplaceInOutputs(*node);
@@ -76,7 +76,7 @@ LagParser::parseFunctionArguments(const CommonFunctionInfo & 
func_info, DB::Acti
     /// The 3rd arg is default value
     /// when it is set to null, the 1st arg must be nullable
     const auto & arg2 = func_info.arguments[2].value();
-    const auto * arg0_col = 
actions_dag.getInputs()[arg0.selection().direct_reference().struct_field().field()];
+    const auto * arg0_col = parseExpression(actions_dag, arg0);
     auto arg0_col_name = arg0_col->result_name;
     auto arg0_col_type = arg0_col->result_type;
     const DB::ActionsDAG::Node * node = nullptr;
@@ -84,7 +84,7 @@ LagParser::parseFunctionArguments(const CommonFunctionInfo & 
func_info, DB::Acti
     {
         node = ActionsDAGUtil::convertNodeType(
             actions_dag,
-            &actions_dag.findInOutputs(arg0_col_name),
+            arg0_col,
             makeNullable(arg0_col_type),
             arg0_col_name);
         actions_dag.addOrReplaceInOutputs(*node);


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to