This is an automated email from the ASF dual-hosted git repository.

liuneng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new a7f64b9803 [GLUTEN-8148][CH] Fix corr with NaN (#8150)
a7f64b9803 is described below

commit a7f64b980345574316d943ddccd8869926efeee6
Author: Shuai li <[email protected]>
AuthorDate: Thu Dec 5 16:35:22 2024 +0800

    [GLUTEN-8148][CH] Fix corr with NaN (#8150)
    
    What changes were proposed in this pull request?
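    Convert a NaN result of corr to NULL when the result type is nullable, so the output matches vanilla Spark.
    (Fixes: #8148)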
    
    How was this patch tested?
    Tested by a unit test ("GLUTEN-8148: Fix corr with NaN" in GlutenClickhouseFunctionSuite).
---
 .../GlutenClickhouseFunctionSuite.scala             | 14 ++++++++++++++
 .../local-engine/Parser/AggregateFunctionParser.cpp | 21 +++++++++++++++++++++
 .../local-engine/Parser/AggregateFunctionParser.h   |  3 +++
 3 files changed, 38 insertions(+)

diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseFunctionSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseFunctionSuite.scala
index e35f6bf65b..2437ffd035 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseFunctionSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseFunctionSuite.scala
@@ -401,4 +401,18 @@ class GlutenClickhouseFunctionSuite extends 
GlutenClickHouseTPCHAbstractSuite {
     }
   }
 
+  test("GLUTEN-8148: Fix corr with NaN") {
+    withTable("corr_nan") {
+      sql("create table if not exists corr_nan (x double, y double) using 
parquet")
+      sql("insert into corr_nan values(0,1)")
+      compareResultsAgainstVanillaSpark(
+        """
+          |select corr(x,y), corr(y,x) from corr_nan
+        """.stripMargin,
+        true,
+        { _ => }
+      )
+    }
+  }
+
 }
diff --git a/cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp 
b/cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp
index 42c4230e4a..eb05b26dcf 100644
--- a/cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp
+++ b/cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp
@@ -195,6 +195,8 @@ const DB::ActionsDAG::Node * 
AggregateFunctionParser::convertNodeTypeIfNeeded(
         actions_dag.addOrReplaceInOutputs(*func_node);
     }
 
+    func_node = convertNanToNullIfNeed(func_info, func_node, actions_dag);
+
     if (output_type.has_decimal())
     {
         String checkDecimalOverflowSparkOrNull = 
"checkDecimalOverflowSparkOrNull";
@@ -209,6 +211,25 @@ const DB::ActionsDAG::Node * 
AggregateFunctionParser::convertNodeTypeIfNeeded(
     return func_node;
 }
 
+const DB::ActionsDAG::Node * AggregateFunctionParser::convertNanToNullIfNeed(
+    const CommonFunctionInfo & func_info, const DB::ActionsDAG::Node * 
func_node, DB::ActionsDAG & actions_dag) const
+{
+    if (getCHFunctionName(func_info) != "corr" || 
!func_node->result_type->isNullable())
+        return func_node;
+
+    /// Only reached for "corr" with a nullable result type (see the guard above).
+    /// Rewrite the result as if(isNaN(result), <default of the nullable type, i.e. NULL>, result), keeping the same result name.
+    auto is_nan_func_node = toFunctionNode(actions_dag, "isNaN", 
getUniqueName("isNaN"), {func_node});
+    auto nullable_col = func_node->result_type->createColumn();
+    nullable_col->insertDefault();
+    const auto * null_node
+        = 
&actions_dag.addColumn(DB::ColumnWithTypeAndName(std::move(nullable_col), 
func_node->result_type, getUniqueName("null")));
+    DB::ActionsDAG::NodeRawConstPtrs convert_nan_func_args = 
{is_nan_func_node, null_node, func_node};
+    func_node = toFunctionNode(actions_dag, "if", func_node->result_name, 
convert_nan_func_args);
+    actions_dag.addOrReplaceInOutputs(*func_node);
+    return func_node;
+}
+
 AggregateFunctionParserFactory & AggregateFunctionParserFactory::instance()
 {
     static AggregateFunctionParserFactory factory;
diff --git a/cpp-ch/local-engine/Parser/AggregateFunctionParser.h 
b/cpp-ch/local-engine/Parser/AggregateFunctionParser.h
index 02b09fc256..a41b3e3ad9 100644
--- a/cpp-ch/local-engine/Parser/AggregateFunctionParser.h
+++ b/cpp-ch/local-engine/Parser/AggregateFunctionParser.h
@@ -144,6 +144,9 @@ protected:
 
     std::pair<DataTypePtr, Field> parseLiteral(const 
substrait::Expression_Literal & literal) const;
 
+    const DB::ActionsDAG::Node * convertNanToNullIfNeed(
+        const CommonFunctionInfo & func_info, const DB::ActionsDAG::Node * 
func_node, DB::ActionsDAG & actions_dag) const;
+
     ParserContextPtr parser_context;
     std::unique_ptr<ExpressionParser> expression_parser;
     Poco::Logger * logger = 
&Poco::Logger::get("AggregateFunctionParserFactory");


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to