This is an automated email from the ASF dual-hosted git repository.

taiyangli pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 983e269df1 [GLUTEN-7796][CH] Fix diff while casting bool to string 
(#7804)
983e269df1 is described below

commit 983e269df1b768266b4af346a1d219f3acd08cf7
Author: 李扬 <654010...@qq.com>
AuthorDate: Thu Nov 7 09:07:49 2024 +0800

    [GLUTEN-7796][CH] Fix diff while casting bool to string (#7804)
    
    * fix cast bool to string
    
    * fix failed uts
    
    * fix cast bool to string again
    
    * remove std::couts
---
 .../execution/GlutenFunctionValidateSuite.scala    |  5 ++
 cpp-ch/local-engine/Parser/ExpressionParser.cpp    | 86 +++++++++++-----------
 cpp-ch/local-engine/Parser/FunctionParser.cpp      | 69 +++++++++++------
 cpp-ch/local-engine/Parser/FunctionParser.h        |  1 +
 4 files changed, 97 insertions(+), 64 deletions(-)

diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
index 08342e5887..dbe8852290 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
@@ -855,4 +855,9 @@ class GlutenFunctionValidateSuite extends 
GlutenClickHouseWholeStageTransformerS
       compareResultsAgainstVanillaSpark(sql, true, { _ => })
     }
   }
+
+  test("GLUTEN-7796 cast bool to string") {
+    val sql = "select cast(id % 2 = 1 as string) from range(10)"
+    compareResultsAgainstVanillaSpark(sql, true, { _ => })
+  }
 }
diff --git a/cpp-ch/local-engine/Parser/ExpressionParser.cpp 
b/cpp-ch/local-engine/Parser/ExpressionParser.cpp
index ee64aff078..30ef92f176 100644
--- a/cpp-ch/local-engine/Parser/ExpressionParser.cpp
+++ b/cpp-ch/local-engine/Parser/ExpressionParser.cpp
@@ -93,7 +93,7 @@ std::pair<DB::DataTypePtr, DB::Field> 
LiteralParser::parse(const substrait::Expr
             break;
         }
         case substrait::Expression_Literal::kBoolean: {
-            type = std::make_shared<DB::DataTypeUInt8>();
+            type = DB::DataTypeFactory::instance().get("Bool");
             field = literal.boolean() ? UInt8(1) : UInt8(0);
             break;
         }
@@ -288,77 +288,77 @@ const ActionsDAG::Node * 
ExpressionParser::parseExpression(ActionsDAG & actions_
 
         case substrait::Expression::RexTypeCase::kCast: {
             if (!rel.cast().has_type() || !rel.cast().has_input())
-                throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Doesn't 
have type or input in cast node.");
-            DB::ActionsDAG::NodeRawConstPtrs args;
+                throw Exception(ErrorCodes::BAD_ARGUMENTS, "Doesn't have type 
or input in cast node.");
+            ActionsDAG::NodeRawConstPtrs args;
 
-            String cast_function = "CAST";
             const auto & input = rel.cast().input();
             args.emplace_back(parseExpression(actions_dag, input));
 
             const auto & substrait_type = rel.cast().type();
             const auto & input_type = args[0]->result_type;
-            DB::DataTypePtr non_nullable_input_type = 
DB::removeNullable(input_type);
-            DB::DataTypePtr output_type = 
TypeParser::parseType(substrait_type);
-            DB::DataTypePtr non_nullable_output_type = 
DB::removeNullable(output_type);
+            DataTypePtr denull_input_type = removeNullable(input_type);
+            DataTypePtr output_type = TypeParser::parseType(substrait_type);
+            DataTypePtr denull_output_type = removeNullable(output_type);
 
-            const DB::ActionsDAG::Node * function_node = nullptr;
+            const ActionsDAG::Node * result_node = nullptr;
             if (substrait_type.has_binary())
             {
                 /// Spark cast(x as BINARY) -> CH reinterpretAsStringSpark(x)
-                function_node = toFunctionNode(actions_dag, 
"reinterpretAsStringSpark", args);
+                result_node = toFunctionNode(actions_dag, 
"reinterpretAsStringSpark", args);
             }
-            else if (DB::isString(non_nullable_input_type) && 
DB::isDate32(non_nullable_output_type))
-                function_node = toFunctionNode(actions_dag, "sparkToDate", 
args);
-            else if (DB::isString(non_nullable_input_type) && 
DB::isDateTime64(non_nullable_output_type))
-                function_node = toFunctionNode(actions_dag, "sparkToDateTime", 
args);
-            else if (DB::isDecimal(non_nullable_input_type) && 
DB::isString(non_nullable_output_type))
+            else if (isString(denull_input_type) && 
isDate32(denull_output_type))
+                result_node = toFunctionNode(actions_dag, "sparkToDate", args);
+            else if (isString(denull_input_type) && 
isDateTime64(denull_output_type))
+                result_node = toFunctionNode(actions_dag, "sparkToDateTime", 
args);
+            else if (isDecimal(denull_input_type) && 
isString(denull_output_type))
             {
                 /// Spark cast(x as STRING) if x is Decimal -> CH 
toDecimalString(x, scale)
-                UInt8 scale = DB::getDecimalScale(*non_nullable_input_type);
-                args.emplace_back(addConstColumn(actions_dag, 
std::make_shared<DB::DataTypeUInt8>(), DB::Field(scale)));
-                function_node = toFunctionNode(actions_dag, "toDecimalString", 
args);
+                UInt8 scale = getDecimalScale(*denull_input_type);
+                args.emplace_back(addConstColumn(actions_dag, 
std::make_shared<DataTypeUInt8>(), Field(scale)));
+                result_node = toFunctionNode(actions_dag, "toDecimalString", 
args);
             }
-            else if (DB::isFloat(non_nullable_input_type) && 
DB::isInt(non_nullable_output_type))
+            else if (isFloat(denull_input_type) && isInt(denull_output_type))
             {
-                String function_name = "sparkCastFloatTo" + 
non_nullable_output_type->getName();
-                function_node = toFunctionNode(actions_dag, function_name, 
args);
+                String function_name = "sparkCastFloatTo" + 
denull_output_type->getName();
+                result_node = toFunctionNode(actions_dag, function_name, args);
             }
-            else if ((isDecimal(non_nullable_input_type) && 
substrait_type.has_decimal()))
+            else if ((isDecimal(denull_input_type) && 
substrait_type.has_decimal()))
             {
                 args.emplace_back(addConstColumn(actions_dag, 
std::make_shared<DataTypeInt32>(), substrait_type.decimal().precision()));
                 args.emplace_back(addConstColumn(actions_dag, 
std::make_shared<DataTypeInt32>(), substrait_type.decimal().scale()));
-
-                function_node = toFunctionNode(actions_dag, 
"checkDecimalOverflowSparkOrNull", args);
+                result_node = toFunctionNode(actions_dag, 
"checkDecimalOverflowSparkOrNull", args);
             }
-            else if (isMap(non_nullable_input_type) && 
isString(non_nullable_output_type))
+            else if (isMap(denull_input_type) && isString(denull_output_type))
             {
                 // ISSUE-7389: spark cast(map to string) has different 
behavior with CH cast(map to string)
-                auto map_input_type = std::static_pointer_cast<const 
DataTypeMap>(non_nullable_input_type);
+                auto map_input_type = std::static_pointer_cast<const 
DataTypeMap>(denull_input_type);
                 args.emplace_back(addConstColumn(actions_dag, 
map_input_type->getKeyType(), map_input_type->getKeyType()->getDefault()));
                 args.emplace_back(addConstColumn(actions_dag, 
map_input_type->getValueType(), map_input_type->getValueType()->getDefault()));
-                function_node = toFunctionNode(actions_dag, 
"sparkCastMapToString", args);
+                result_node = toFunctionNode(actions_dag, 
"sparkCastMapToString", args);
+            }
+            else if (isString(denull_input_type) && substrait_type.has_bool_())
+            {
+                /// cast(string to boolean)
+                args.emplace_back(addConstColumn(actions_dag, 
std::make_shared<DataTypeString>(), output_type->getName()));
+                result_node = toFunctionNode(actions_dag, 
"accurateCastOrNull", args);
+            }
+            else if (isString(denull_input_type) && isInt(denull_output_type))
+            {
+                /// Spark cast(x as INT) if x is String -> CH cast(trim(x) as 
INT)
+                /// Refer to 
https://github.com/apache/incubator-gluten/issues/4956
+                args[0] = toFunctionNode(actions_dag, "trim", {args[0]});
+                args.emplace_back(addConstColumn(actions_dag, 
std::make_shared<DataTypeString>(), output_type->getName()));
+                result_node = toFunctionNode(actions_dag, "CAST", args);
             }
             else
             {
-                if (DB::isString(non_nullable_input_type) && 
DB::isInt(non_nullable_output_type))
-                {
-                    /// Spark cast(x as INT) if x is String -> CH cast(trim(x) 
as INT)
-                    /// Refer to 
https://github.com/apache/incubator-gluten/issues/4956
-                    args[0] = toFunctionNode(actions_dag, "trim", {args[0]});
-                }
-                else if (DB::isString(non_nullable_input_type) && 
substrait_type.has_bool_())
-                {
-                    /// cast(string to boolean)
-                    cast_function = "accurateCastOrNull";
-                }
-
-                /// Common process
-                args.emplace_back(addConstColumn(actions_dag, 
std::make_shared<DB::DataTypeString>(), output_type->getName()));
-                function_node = toFunctionNode(actions_dag, cast_function, 
args);
+                /// Common process: CAST(input, type)
+                args.emplace_back(addConstColumn(actions_dag, 
std::make_shared<DataTypeString>(), output_type->getName()));
+                result_node = toFunctionNode(actions_dag, "CAST", args);
             }
 
-            actions_dag.addOrReplaceInOutputs(*function_node);
-            return function_node;
+            actions_dag.addOrReplaceInOutputs(*result_node);
+            return result_node;
         }
 
         case substrait::Expression::RexTypeCase::kIfThen: {
diff --git a/cpp-ch/local-engine/Parser/FunctionParser.cpp 
b/cpp-ch/local-engine/Parser/FunctionParser.cpp
index c8158d3e3e..7e794dabec 100644
--- a/cpp-ch/local-engine/Parser/FunctionParser.cpp
+++ b/cpp-ch/local-engine/Parser/FunctionParser.cpp
@@ -115,33 +115,60 @@ const ActionsDAG::Node * 
FunctionParser::convertNodeTypeIfNeeded(
     const substrait::Expression_ScalarFunction & substrait_func, const 
ActionsDAG::Node * func_node, ActionsDAG & actions_dag) const
 {
     const auto & output_type = substrait_func.output_type();
-    if (!TypeParser::isTypeMatched(output_type, func_node->result_type))
+    const ActionsDAG::Node * result_node = nullptr;
+
+    auto convert_type_if_needed = [&]()
     {
-        auto result_type = TypeParser::parseType(substrait_func.output_type());
-        if (DB::isDecimalOrNullableDecimal(result_type))
+        if (!TypeParser::isTypeMatched(output_type, func_node->result_type))
         {
-            return ActionsDAGUtil::convertNodeType(
-                actions_dag,
-                func_node,
-                // as stated in isTypeMatched, currently we don't change 
nullability of the result type
-                func_node->result_type->isNullable() ? 
local_engine::wrapNullableType(true, result_type)
-                                                     : 
local_engine::removeNullable(result_type),
-                func_node->result_name,
-                CastType::accurateOrNull);
+            auto result_type = 
TypeParser::parseType(substrait_func.output_type());
+            if (DB::isDecimalOrNullableDecimal(result_type))
+            {
+                return ActionsDAGUtil::convertNodeType(
+                    actions_dag,
+                    func_node,
+                    // as stated in isTypeMatched, currently we don't change 
nullability of the result type
+                    func_node->result_type->isNullable() ? 
local_engine::wrapNullableType(true, result_type)
+                                                         : 
local_engine::removeNullable(result_type),
+                    func_node->result_name,
+                    CastType::accurateOrNull);
+            }
+            else
+            {
+                return ActionsDAGUtil::convertNodeType(
+                    actions_dag,
+                    func_node,
+                    // as stated in isTypeMatched, currently we don't change 
nullability of the result type
+                    func_node->result_type->isNullable() ? 
local_engine::wrapNullableType(true, TypeParser::parseType(output_type))
+                                                         : 
DB::removeNullable(TypeParser::parseType(output_type)),
+                    func_node->result_name);
+            }
         }
         else
+            return func_node;
+    };
+    result_node = convert_type_if_needed();
+
+    /// Notice that in CH Bool and UInt8 have different serialization and 
deserialization methods, which will cause issues when executing cast(bool as 
string) in Spark.
+    auto convert_uint8_to_bool_if_needed = [&]() -> const auto *
+    {
+        auto * mutable_result_node = const_cast<ActionsDAG::Node 
*>(result_node);
+        auto denull_result_type = DB::removeNullable(result_node->result_type);
+        if (isUInt8(denull_result_type) && output_type.has_bool_())
         {
-            return ActionsDAGUtil::convertNodeType(
-                actions_dag,
-                func_node,
-                // as stated in isTypeMatched, currently we don't change 
nullability of the result type
-                func_node->result_type->isNullable() ? 
local_engine::wrapNullableType(true, TypeParser::parseType(output_type))
-                                                     : 
DB::removeNullable(TypeParser::parseType(output_type)),
-                func_node->result_name);
+            auto bool_type = DB::DataTypeFactory::instance().get("Bool");
+            if (result_node->result_type->isNullable())
+                bool_type = DB::makeNullable(bool_type);
+
+            mutable_result_node->result_type = std::move(bool_type);
+            return mutable_result_node;
         }
-    }
-    else
-        return func_node;
+        else
+            return result_node;
+    };
+    result_node = convert_uint8_to_bool_if_needed();
+
+    return result_node;
 }
 
 void FunctionParserFactory::registerFunctionParser(const String & name, Value 
value)
diff --git a/cpp-ch/local-engine/Parser/FunctionParser.h 
b/cpp-ch/local-engine/Parser/FunctionParser.h
index d9ca7b5128..23216f2203 100644
--- a/cpp-ch/local-engine/Parser/FunctionParser.h
+++ b/cpp-ch/local-engine/Parser/FunctionParser.h
@@ -60,6 +60,7 @@ protected:
     {
         return parseFunctionArguments(substrait_func, actions_dag);
     }
+
     virtual DB::ActionsDAG::NodeRawConstPtrs
     parseFunctionArguments(const substrait::Expression_ScalarFunction & 
substrait_func, DB::ActionsDAG & actions_dag) const;
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@gluten.apache.org
For additional commands, e-mail: commits-h...@gluten.apache.org

Reply via email to