This is an automated email from the ASF dual-hosted git repository.
liuneng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new a7f64b9803 [GLUTEN-8148][CH] Fix corr with NaN (#8150)
a7f64b9803 is described below
commit a7f64b980345574316d943ddccd8869926efeee6
Author: Shuai li <[email protected]>
AuthorDate: Thu Dec 5 16:35:22 2024 +0800
[GLUTEN-8148][CH] Fix corr with NaN (#8150)
What changes were proposed in this pull request?
(Fixes: #8148)
How was this patch tested?
Tested by unit test (UT).
---
.../GlutenClickhouseFunctionSuite.scala | 14 ++++++++++++++
.../local-engine/Parser/AggregateFunctionParser.cpp | 21 +++++++++++++++++++++
.../local-engine/Parser/AggregateFunctionParser.h | 3 +++
3 files changed, 38 insertions(+)
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseFunctionSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseFunctionSuite.scala
index e35f6bf65b..2437ffd035 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseFunctionSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseFunctionSuite.scala
@@ -401,4 +401,18 @@ class GlutenClickhouseFunctionSuite extends
GlutenClickHouseTPCHAbstractSuite {
}
}
+ test("GLUTEN-8148: Fix corr with NaN") {
+ withTable("corr_nan") {
+ sql("create table if not exists corr_nan (x double, y double) using
parquet")
+ sql("insert into corr_nan values(0,1)")
+ compareResultsAgainstVanillaSpark(
+ """
+ |select corr(x,y), corr(y,x) from corr_nan
+ """.stripMargin,
+ true,
+ { _ => }
+ )
+ }
+ }
+
}
diff --git a/cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp
b/cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp
index 42c4230e4a..eb05b26dcf 100644
--- a/cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp
+++ b/cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp
@@ -195,6 +195,8 @@ const DB::ActionsDAG::Node *
AggregateFunctionParser::convertNodeTypeIfNeeded(
actions_dag.addOrReplaceInOutputs(*func_node);
}
+ func_node = convertNanToNullIfNeed(func_info, func_node, actions_dag);
+
if (output_type.has_decimal())
{
String checkDecimalOverflowSparkOrNull =
"checkDecimalOverflowSparkOrNull";
@@ -209,6 +211,25 @@ const DB::ActionsDAG::Node *
AggregateFunctionParser::convertNodeTypeIfNeeded(
return func_node;
}
+const DB::ActionsDAG::Node * AggregateFunctionParser::convertNanToNullIfNeed(
+ const CommonFunctionInfo & func_info, const DB::ActionsDAG::Node *
func_node, DB::ActionsDAG & actions_dag) const
+{
+ if (getCHFunctionName(func_info) != "corr" ||
!func_node->result_type->isNullable())
+ return func_node;
+
+ /// result is nullable.
+ /// if result is NaN, convert it to NULL.
+ auto is_nan_func_node = toFunctionNode(actions_dag, "isNaN",
getUniqueName("isNaN"), {func_node});
+ auto nullable_col = func_node->result_type->createColumn();
+ nullable_col->insertDefault();
+ const auto * null_node
+ =
&actions_dag.addColumn(DB::ColumnWithTypeAndName(std::move(nullable_col),
func_node->result_type, getUniqueName("null")));
+ DB::ActionsDAG::NodeRawConstPtrs convert_nan_func_args =
{is_nan_func_node, null_node, func_node};
+ func_node = toFunctionNode(actions_dag, "if", func_node->result_name,
convert_nan_func_args);
+ actions_dag.addOrReplaceInOutputs(*func_node);
+ return func_node;
+}
+
AggregateFunctionParserFactory & AggregateFunctionParserFactory::instance()
{
static AggregateFunctionParserFactory factory;
diff --git a/cpp-ch/local-engine/Parser/AggregateFunctionParser.h
b/cpp-ch/local-engine/Parser/AggregateFunctionParser.h
index 02b09fc256..a41b3e3ad9 100644
--- a/cpp-ch/local-engine/Parser/AggregateFunctionParser.h
+++ b/cpp-ch/local-engine/Parser/AggregateFunctionParser.h
@@ -144,6 +144,9 @@ protected:
std::pair<DataTypePtr, Field> parseLiteral(const
substrait::Expression_Literal & literal) const;
+ const DB::ActionsDAG::Node * convertNanToNullIfNeed(
+ const CommonFunctionInfo & func_info, const DB::ActionsDAG::Node *
func_node, DB::ActionsDAG & actions_dag) const;
+
ParserContextPtr parser_context;
std::unique_ptr<ExpressionParser> expression_parser;
Poco::Logger * logger =
&Poco::Logger::get("AggregateFunctionParserFactory");
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]