This is an automated email from the ASF dual-hosted git repository. taiyangli pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push: new 983e269df1 [GLUTEN-7796][CH] Fix diff while casting bool to string (#7804) 983e269df1 is described below commit 983e269df1b768266b4af346a1d219f3acd08cf7 Author: 李扬 <654010...@qq.com> AuthorDate: Thu Nov 7 09:07:49 2024 +0800 [GLUTEN-7796][CH] Fix diff while casting bool to string (#7804) * fix cast bool to string * fix failed uts * fix cast bool to string again * remove std::couts --- .../execution/GlutenFunctionValidateSuite.scala | 5 ++ cpp-ch/local-engine/Parser/ExpressionParser.cpp | 86 +++++++++++----------- cpp-ch/local-engine/Parser/FunctionParser.cpp | 69 +++++++++++------ cpp-ch/local-engine/Parser/FunctionParser.h | 1 + 4 files changed, 97 insertions(+), 64 deletions(-) diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala index 08342e5887..dbe8852290 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala @@ -855,4 +855,9 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS compareResultsAgainstVanillaSpark(sql, true, { _ => }) } } + + test("GLUTEN-7796 cast bool to string") { + val sql = "select cast(id % 2 = 1 as string) from range(10)" + compareResultsAgainstVanillaSpark(sql, true, { _ => }) + } } diff --git a/cpp-ch/local-engine/Parser/ExpressionParser.cpp b/cpp-ch/local-engine/Parser/ExpressionParser.cpp index ee64aff078..30ef92f176 100644 --- a/cpp-ch/local-engine/Parser/ExpressionParser.cpp +++ b/cpp-ch/local-engine/Parser/ExpressionParser.cpp @@ -93,7 +93,7 @@ std::pair<DB::DataTypePtr, DB::Field> LiteralParser::parse(const substrait::Expr break; } case substrait::Expression_Literal::kBoolean: { - type = 
std::make_shared<DB::DataTypeUInt8>(); + type = DB::DataTypeFactory::instance().get("Bool"); field = literal.boolean() ? UInt8(1) : UInt8(0); break; } @@ -288,77 +288,77 @@ const ActionsDAG::Node * ExpressionParser::parseExpression(ActionsDAG & actions_ case substrait::Expression::RexTypeCase::kCast: { if (!rel.cast().has_type() || !rel.cast().has_input()) - throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Doesn't have type or input in cast node."); - DB::ActionsDAG::NodeRawConstPtrs args; + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Doesn't have type or input in cast node."); + ActionsDAG::NodeRawConstPtrs args; - String cast_function = "CAST"; const auto & input = rel.cast().input(); args.emplace_back(parseExpression(actions_dag, input)); const auto & substrait_type = rel.cast().type(); const auto & input_type = args[0]->result_type; - DB::DataTypePtr non_nullable_input_type = DB::removeNullable(input_type); - DB::DataTypePtr output_type = TypeParser::parseType(substrait_type); - DB::DataTypePtr non_nullable_output_type = DB::removeNullable(output_type); + DataTypePtr denull_input_type = removeNullable(input_type); + DataTypePtr output_type = TypeParser::parseType(substrait_type); + DataTypePtr denull_output_type = removeNullable(output_type); - const DB::ActionsDAG::Node * function_node = nullptr; + const ActionsDAG::Node * result_node = nullptr; if (substrait_type.has_binary()) { /// Spark cast(x as BINARY) -> CH reinterpretAsStringSpark(x) - function_node = toFunctionNode(actions_dag, "reinterpretAsStringSpark", args); + result_node = toFunctionNode(actions_dag, "reinterpretAsStringSpark", args); } - else if (DB::isString(non_nullable_input_type) && DB::isDate32(non_nullable_output_type)) - function_node = toFunctionNode(actions_dag, "sparkToDate", args); - else if (DB::isString(non_nullable_input_type) && DB::isDateTime64(non_nullable_output_type)) - function_node = toFunctionNode(actions_dag, "sparkToDateTime", args); - else if 
(DB::isDecimal(non_nullable_input_type) && DB::isString(non_nullable_output_type)) + else if (isString(denull_input_type) && isDate32(denull_output_type)) + result_node = toFunctionNode(actions_dag, "sparkToDate", args); + else if (isString(denull_input_type) && isDateTime64(denull_output_type)) + result_node = toFunctionNode(actions_dag, "sparkToDateTime", args); + else if (isDecimal(denull_input_type) && isString(denull_output_type)) { /// Spark cast(x as STRING) if x is Decimal -> CH toDecimalString(x, scale) - UInt8 scale = DB::getDecimalScale(*non_nullable_input_type); - args.emplace_back(addConstColumn(actions_dag, std::make_shared<DB::DataTypeUInt8>(), DB::Field(scale))); - function_node = toFunctionNode(actions_dag, "toDecimalString", args); + UInt8 scale = getDecimalScale(*denull_input_type); + args.emplace_back(addConstColumn(actions_dag, std::make_shared<DataTypeUInt8>(), Field(scale))); + result_node = toFunctionNode(actions_dag, "toDecimalString", args); } - else if (DB::isFloat(non_nullable_input_type) && DB::isInt(non_nullable_output_type)) + else if (isFloat(denull_input_type) && isInt(denull_output_type)) { - String function_name = "sparkCastFloatTo" + non_nullable_output_type->getName(); - function_node = toFunctionNode(actions_dag, function_name, args); + String function_name = "sparkCastFloatTo" + denull_output_type->getName(); + result_node = toFunctionNode(actions_dag, function_name, args); } - else if ((isDecimal(non_nullable_input_type) && substrait_type.has_decimal())) + else if ((isDecimal(denull_input_type) && substrait_type.has_decimal())) { args.emplace_back(addConstColumn(actions_dag, std::make_shared<DataTypeInt32>(), substrait_type.decimal().precision())); args.emplace_back(addConstColumn(actions_dag, std::make_shared<DataTypeInt32>(), substrait_type.decimal().scale())); - - function_node = toFunctionNode(actions_dag, "checkDecimalOverflowSparkOrNull", args); + result_node = toFunctionNode(actions_dag, 
"checkDecimalOverflowSparkOrNull", args); } - else if (isMap(non_nullable_input_type) && isString(non_nullable_output_type)) + else if (isMap(denull_input_type) && isString(denull_output_type)) { // ISSUE-7389: spark cast(map to string) has different behavior with CH cast(map to string) - auto map_input_type = std::static_pointer_cast<const DataTypeMap>(non_nullable_input_type); + auto map_input_type = std::static_pointer_cast<const DataTypeMap>(denull_input_type); args.emplace_back(addConstColumn(actions_dag, map_input_type->getKeyType(), map_input_type->getKeyType()->getDefault())); args.emplace_back(addConstColumn(actions_dag, map_input_type->getValueType(), map_input_type->getValueType()->getDefault())); - function_node = toFunctionNode(actions_dag, "sparkCastMapToString", args); + result_node = toFunctionNode(actions_dag, "sparkCastMapToString", args); + } + else if (isString(denull_input_type) && substrait_type.has_bool_()) + { + /// cast(string to boolean) + args.emplace_back(addConstColumn(actions_dag, std::make_shared<DataTypeString>(), output_type->getName())); + result_node = toFunctionNode(actions_dag, "accurateCastOrNull", args); + } + else if (isString(denull_input_type) && isInt(denull_output_type)) + { + /// Spark cast(x as INT) if x is String -> CH cast(trim(x) as INT) + /// Refer to https://github.com/apache/incubator-gluten/issues/4956 + args[0] = toFunctionNode(actions_dag, "trim", {args[0]}); + args.emplace_back(addConstColumn(actions_dag, std::make_shared<DataTypeString>(), output_type->getName())); + result_node = toFunctionNode(actions_dag, "CAST", args); } else { - if (DB::isString(non_nullable_input_type) && DB::isInt(non_nullable_output_type)) - { - /// Spark cast(x as INT) if x is String -> CH cast(trim(x) as INT) - /// Refer to https://github.com/apache/incubator-gluten/issues/4956 - args[0] = toFunctionNode(actions_dag, "trim", {args[0]}); - } - else if (DB::isString(non_nullable_input_type) && substrait_type.has_bool_()) - { - /// 
cast(string to boolean) - cast_function = "accurateCastOrNull"; - } - - /// Common process - args.emplace_back(addConstColumn(actions_dag, std::make_shared<DB::DataTypeString>(), output_type->getName())); - function_node = toFunctionNode(actions_dag, cast_function, args); + /// Common process: CAST(input, type) + args.emplace_back(addConstColumn(actions_dag, std::make_shared<DataTypeString>(), output_type->getName())); + result_node = toFunctionNode(actions_dag, "CAST", args); } - actions_dag.addOrReplaceInOutputs(*function_node); - return function_node; + actions_dag.addOrReplaceInOutputs(*result_node); + return result_node; } case substrait::Expression::RexTypeCase::kIfThen: { diff --git a/cpp-ch/local-engine/Parser/FunctionParser.cpp b/cpp-ch/local-engine/Parser/FunctionParser.cpp index c8158d3e3e..7e794dabec 100644 --- a/cpp-ch/local-engine/Parser/FunctionParser.cpp +++ b/cpp-ch/local-engine/Parser/FunctionParser.cpp @@ -115,33 +115,60 @@ const ActionsDAG::Node * FunctionParser::convertNodeTypeIfNeeded( const substrait::Expression_ScalarFunction & substrait_func, const ActionsDAG::Node * func_node, ActionsDAG & actions_dag) const { const auto & output_type = substrait_func.output_type(); - if (!TypeParser::isTypeMatched(output_type, func_node->result_type)) + const ActionsDAG::Node * result_node = nullptr; + + auto convert_type_if_needed = [&]() { - auto result_type = TypeParser::parseType(substrait_func.output_type()); - if (DB::isDecimalOrNullableDecimal(result_type)) + if (!TypeParser::isTypeMatched(output_type, func_node->result_type)) { - return ActionsDAGUtil::convertNodeType( - actions_dag, - func_node, - // as stated in isTypeMatched, currently we don't change nullability of the result type - func_node->result_type->isNullable() ? 
local_engine::wrapNullableType(true, result_type) - : local_engine::removeNullable(result_type), - func_node->result_name, - CastType::accurateOrNull); + auto result_type = TypeParser::parseType(substrait_func.output_type()); + if (DB::isDecimalOrNullableDecimal(result_type)) + { + return ActionsDAGUtil::convertNodeType( + actions_dag, + func_node, + // as stated in isTypeMatched, currently we don't change nullability of the result type + func_node->result_type->isNullable() ? local_engine::wrapNullableType(true, result_type) + : local_engine::removeNullable(result_type), + func_node->result_name, + CastType::accurateOrNull); + } + else + { + return ActionsDAGUtil::convertNodeType( + actions_dag, + func_node, + // as stated in isTypeMatched, currently we don't change nullability of the result type + func_node->result_type->isNullable() ? local_engine::wrapNullableType(true, TypeParser::parseType(output_type)) + : DB::removeNullable(TypeParser::parseType(output_type)), + func_node->result_name); + } } else + return func_node; + }; + result_node = convert_type_if_needed(); + + /// Notice that in CH Bool and UInt8 have different serialization and deserialization methods, which will cause issue when executing cast(bool as string) in spark. + auto convert_uint8_to_bool_if_needed = [&]() -> const auto * + { + auto * mutable_result_node = const_cast<ActionsDAG::Node *>(result_node); + auto denull_result_type = DB::removeNullable(result_node->result_type); + if (isUInt8(denull_result_type) && output_type.has_bool_()) { - return ActionsDAGUtil::convertNodeType( - actions_dag, - func_node, - // as stated in isTypeMatched, currently we don't change nullability of the result type - func_node->result_type->isNullable() ? 
local_engine::wrapNullableType(true, TypeParser::parseType(output_type)) - : DB::removeNullable(TypeParser::parseType(output_type)), - func_node->result_name); + auto bool_type = DB::DataTypeFactory::instance().get("Bool"); + if (result_node->result_type->isNullable()) + bool_type = DB::makeNullable(bool_type); + + mutable_result_node->result_type = std::move(bool_type); + return mutable_result_node; } - } - else - return func_node; + else + return result_node; + }; + result_node = convert_uint8_to_bool_if_needed(); + + return result_node; } void FunctionParserFactory::registerFunctionParser(const String & name, Value value) diff --git a/cpp-ch/local-engine/Parser/FunctionParser.h b/cpp-ch/local-engine/Parser/FunctionParser.h index d9ca7b5128..23216f2203 100644 --- a/cpp-ch/local-engine/Parser/FunctionParser.h +++ b/cpp-ch/local-engine/Parser/FunctionParser.h @@ -60,6 +60,7 @@ protected: { return parseFunctionArguments(substrait_func, actions_dag); } + virtual DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAG & actions_dag) const; --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@gluten.apache.org For additional commands, e-mail: commits-h...@gluten.apache.org