This is an automated email from the ASF dual-hosted git repository.
taiyangli pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 983e269df1 [GLUTEN-7796][CH] Fix diff while casting bool to string
(#7804)
983e269df1 is described below
commit 983e269df1b768266b4af346a1d219f3acd08cf7
Author: 李扬 <[email protected]>
AuthorDate: Thu Nov 7 09:07:49 2024 +0800
[GLUTEN-7796][CH] Fix diff while casting bool to string (#7804)
* fix cast bool to string
* fix failed uts
* fix cast bool to string again
* remove std::couts
---
.../execution/GlutenFunctionValidateSuite.scala | 5 ++
cpp-ch/local-engine/Parser/ExpressionParser.cpp | 86 +++++++++++-----------
cpp-ch/local-engine/Parser/FunctionParser.cpp | 69 +++++++++++------
cpp-ch/local-engine/Parser/FunctionParser.h | 1 +
4 files changed, 97 insertions(+), 64 deletions(-)
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
index 08342e5887..dbe8852290 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
@@ -855,4 +855,9 @@ class GlutenFunctionValidateSuite extends
GlutenClickHouseWholeStageTransformerS
compareResultsAgainstVanillaSpark(sql, true, { _ => })
}
}
+
+ test("GLUTEN-7796 cast bool to string") {
+ val sql = "select cast(id % 2 = 1 as string) from range(10)"
+ compareResultsAgainstVanillaSpark(sql, true, { _ => })
+ }
}
diff --git a/cpp-ch/local-engine/Parser/ExpressionParser.cpp
b/cpp-ch/local-engine/Parser/ExpressionParser.cpp
index ee64aff078..30ef92f176 100644
--- a/cpp-ch/local-engine/Parser/ExpressionParser.cpp
+++ b/cpp-ch/local-engine/Parser/ExpressionParser.cpp
@@ -93,7 +93,7 @@ std::pair<DB::DataTypePtr, DB::Field>
LiteralParser::parse(const substrait::Expr
break;
}
case substrait::Expression_Literal::kBoolean: {
- type = std::make_shared<DB::DataTypeUInt8>();
+ type = DB::DataTypeFactory::instance().get("Bool");
field = literal.boolean() ? UInt8(1) : UInt8(0);
break;
}
@@ -288,77 +288,77 @@ const ActionsDAG::Node *
ExpressionParser::parseExpression(ActionsDAG & actions_
case substrait::Expression::RexTypeCase::kCast: {
if (!rel.cast().has_type() || !rel.cast().has_input())
- throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Doesn't
have type or input in cast node.");
- DB::ActionsDAG::NodeRawConstPtrs args;
+ throw Exception(ErrorCodes::BAD_ARGUMENTS, "Doesn't have type
or input in cast node.");
+ ActionsDAG::NodeRawConstPtrs args;
- String cast_function = "CAST";
const auto & input = rel.cast().input();
args.emplace_back(parseExpression(actions_dag, input));
const auto & substrait_type = rel.cast().type();
const auto & input_type = args[0]->result_type;
- DB::DataTypePtr non_nullable_input_type =
DB::removeNullable(input_type);
- DB::DataTypePtr output_type =
TypeParser::parseType(substrait_type);
- DB::DataTypePtr non_nullable_output_type =
DB::removeNullable(output_type);
+ DataTypePtr denull_input_type = removeNullable(input_type);
+ DataTypePtr output_type = TypeParser::parseType(substrait_type);
+ DataTypePtr denull_output_type = removeNullable(output_type);
- const DB::ActionsDAG::Node * function_node = nullptr;
+ const ActionsDAG::Node * result_node = nullptr;
if (substrait_type.has_binary())
{
/// Spark cast(x as BINARY) -> CH reinterpretAsStringSpark(x)
- function_node = toFunctionNode(actions_dag,
"reinterpretAsStringSpark", args);
+ result_node = toFunctionNode(actions_dag,
"reinterpretAsStringSpark", args);
}
- else if (DB::isString(non_nullable_input_type) &&
DB::isDate32(non_nullable_output_type))
- function_node = toFunctionNode(actions_dag, "sparkToDate",
args);
- else if (DB::isString(non_nullable_input_type) &&
DB::isDateTime64(non_nullable_output_type))
- function_node = toFunctionNode(actions_dag, "sparkToDateTime",
args);
- else if (DB::isDecimal(non_nullable_input_type) &&
DB::isString(non_nullable_output_type))
+ else if (isString(denull_input_type) &&
isDate32(denull_output_type))
+ result_node = toFunctionNode(actions_dag, "sparkToDate", args);
+ else if (isString(denull_input_type) &&
isDateTime64(denull_output_type))
+ result_node = toFunctionNode(actions_dag, "sparkToDateTime",
args);
+ else if (isDecimal(denull_input_type) &&
isString(denull_output_type))
{
/// Spark cast(x as STRING) if x is Decimal -> CH
toDecimalString(x, scale)
- UInt8 scale = DB::getDecimalScale(*non_nullable_input_type);
- args.emplace_back(addConstColumn(actions_dag,
std::make_shared<DB::DataTypeUInt8>(), DB::Field(scale)));
- function_node = toFunctionNode(actions_dag, "toDecimalString",
args);
+ UInt8 scale = getDecimalScale(*denull_input_type);
+ args.emplace_back(addConstColumn(actions_dag,
std::make_shared<DataTypeUInt8>(), Field(scale)));
+ result_node = toFunctionNode(actions_dag, "toDecimalString",
args);
}
- else if (DB::isFloat(non_nullable_input_type) &&
DB::isInt(non_nullable_output_type))
+ else if (isFloat(denull_input_type) && isInt(denull_output_type))
{
- String function_name = "sparkCastFloatTo" +
non_nullable_output_type->getName();
- function_node = toFunctionNode(actions_dag, function_name,
args);
+ String function_name = "sparkCastFloatTo" +
denull_output_type->getName();
+ result_node = toFunctionNode(actions_dag, function_name, args);
}
- else if ((isDecimal(non_nullable_input_type) &&
substrait_type.has_decimal()))
+ else if ((isDecimal(denull_input_type) &&
substrait_type.has_decimal()))
{
args.emplace_back(addConstColumn(actions_dag,
std::make_shared<DataTypeInt32>(), substrait_type.decimal().precision()));
args.emplace_back(addConstColumn(actions_dag,
std::make_shared<DataTypeInt32>(), substrait_type.decimal().scale()));
-
- function_node = toFunctionNode(actions_dag,
"checkDecimalOverflowSparkOrNull", args);
+ result_node = toFunctionNode(actions_dag,
"checkDecimalOverflowSparkOrNull", args);
}
- else if (isMap(non_nullable_input_type) &&
isString(non_nullable_output_type))
+ else if (isMap(denull_input_type) && isString(denull_output_type))
{
// ISSUE-7389: spark cast(map to string) has different
behavior with CH cast(map to string)
- auto map_input_type = std::static_pointer_cast<const
DataTypeMap>(non_nullable_input_type);
+ auto map_input_type = std::static_pointer_cast<const
DataTypeMap>(denull_input_type);
args.emplace_back(addConstColumn(actions_dag,
map_input_type->getKeyType(), map_input_type->getKeyType()->getDefault()));
args.emplace_back(addConstColumn(actions_dag,
map_input_type->getValueType(), map_input_type->getValueType()->getDefault()));
- function_node = toFunctionNode(actions_dag,
"sparkCastMapToString", args);
+ result_node = toFunctionNode(actions_dag,
"sparkCastMapToString", args);
+ }
+ else if (isString(denull_input_type) && substrait_type.has_bool_())
+ {
+ /// cast(string to boolean)
+ args.emplace_back(addConstColumn(actions_dag,
std::make_shared<DataTypeString>(), output_type->getName()));
+ result_node = toFunctionNode(actions_dag,
"accurateCastOrNull", args);
+ }
+ else if (isString(denull_input_type) && isInt(denull_output_type))
+ {
+ /// Spark cast(x as INT) if x is String -> CH cast(trim(x) as
INT)
+ /// Refer to
https://github.com/apache/incubator-gluten/issues/4956
+ args[0] = toFunctionNode(actions_dag, "trim", {args[0]});
+ args.emplace_back(addConstColumn(actions_dag,
std::make_shared<DataTypeString>(), output_type->getName()));
+ result_node = toFunctionNode(actions_dag, "CAST", args);
}
else
{
- if (DB::isString(non_nullable_input_type) &&
DB::isInt(non_nullable_output_type))
- {
- /// Spark cast(x as INT) if x is String -> CH cast(trim(x)
as INT)
- /// Refer to
https://github.com/apache/incubator-gluten/issues/4956
- args[0] = toFunctionNode(actions_dag, "trim", {args[0]});
- }
- else if (DB::isString(non_nullable_input_type) &&
substrait_type.has_bool_())
- {
- /// cast(string to boolean)
- cast_function = "accurateCastOrNull";
- }
-
- /// Common process
- args.emplace_back(addConstColumn(actions_dag,
std::make_shared<DB::DataTypeString>(), output_type->getName()));
- function_node = toFunctionNode(actions_dag, cast_function,
args);
+ /// Common process: CAST(input, type)
+ args.emplace_back(addConstColumn(actions_dag,
std::make_shared<DataTypeString>(), output_type->getName()));
+ result_node = toFunctionNode(actions_dag, "CAST", args);
}
- actions_dag.addOrReplaceInOutputs(*function_node);
- return function_node;
+ actions_dag.addOrReplaceInOutputs(*result_node);
+ return result_node;
}
case substrait::Expression::RexTypeCase::kIfThen: {
diff --git a/cpp-ch/local-engine/Parser/FunctionParser.cpp
b/cpp-ch/local-engine/Parser/FunctionParser.cpp
index c8158d3e3e..7e794dabec 100644
--- a/cpp-ch/local-engine/Parser/FunctionParser.cpp
+++ b/cpp-ch/local-engine/Parser/FunctionParser.cpp
@@ -115,33 +115,60 @@ const ActionsDAG::Node *
FunctionParser::convertNodeTypeIfNeeded(
const substrait::Expression_ScalarFunction & substrait_func, const
ActionsDAG::Node * func_node, ActionsDAG & actions_dag) const
{
const auto & output_type = substrait_func.output_type();
- if (!TypeParser::isTypeMatched(output_type, func_node->result_type))
+ const ActionsDAG::Node * result_node = nullptr;
+
+ auto convert_type_if_needed = [&]()
{
- auto result_type = TypeParser::parseType(substrait_func.output_type());
- if (DB::isDecimalOrNullableDecimal(result_type))
+ if (!TypeParser::isTypeMatched(output_type, func_node->result_type))
{
- return ActionsDAGUtil::convertNodeType(
- actions_dag,
- func_node,
- // as stated in isTypeMatched, currently we don't change
nullability of the result type
- func_node->result_type->isNullable() ?
local_engine::wrapNullableType(true, result_type)
- :
local_engine::removeNullable(result_type),
- func_node->result_name,
- CastType::accurateOrNull);
+ auto result_type =
TypeParser::parseType(substrait_func.output_type());
+ if (DB::isDecimalOrNullableDecimal(result_type))
+ {
+ return ActionsDAGUtil::convertNodeType(
+ actions_dag,
+ func_node,
+ // as stated in isTypeMatched, currently we don't change
nullability of the result type
+ func_node->result_type->isNullable() ?
local_engine::wrapNullableType(true, result_type)
+ :
local_engine::removeNullable(result_type),
+ func_node->result_name,
+ CastType::accurateOrNull);
+ }
+ else
+ {
+ return ActionsDAGUtil::convertNodeType(
+ actions_dag,
+ func_node,
+ // as stated in isTypeMatched, currently we don't change
nullability of the result type
+ func_node->result_type->isNullable() ?
local_engine::wrapNullableType(true, TypeParser::parseType(output_type))
+ :
DB::removeNullable(TypeParser::parseType(output_type)),
+ func_node->result_name);
+ }
}
else
+ return func_node;
+ };
+ result_node = convert_type_if_needed();
+
+ /// Notice that in CH Bool and UInt8 have different serialization and
deserialization methods, which will cause issue when executing cast(bool as
string) in spark in spark.
+ auto convert_uint8_to_bool_if_needed = [&]() -> const auto *
+ {
+ auto * mutable_result_node = const_cast<ActionsDAG::Node
*>(result_node);
+ auto denull_result_type = DB::removeNullable(result_node->result_type);
+ if (isUInt8(denull_result_type) && output_type.has_bool_())
{
- return ActionsDAGUtil::convertNodeType(
- actions_dag,
- func_node,
- // as stated in isTypeMatched, currently we don't change
nullability of the result type
- func_node->result_type->isNullable() ?
local_engine::wrapNullableType(true, TypeParser::parseType(output_type))
- :
DB::removeNullable(TypeParser::parseType(output_type)),
- func_node->result_name);
+ auto bool_type = DB::DataTypeFactory::instance().get("Bool");
+ if (result_node->result_type->isNullable())
+ bool_type = DB::makeNullable(bool_type);
+
+ mutable_result_node->result_type = std::move(bool_type);
+ return mutable_result_node;
}
- }
- else
- return func_node;
+ else
+ return result_node;
+ };
+ result_node = convert_uint8_to_bool_if_needed();
+
+ return result_node;
}
void FunctionParserFactory::registerFunctionParser(const String & name, Value
value)
diff --git a/cpp-ch/local-engine/Parser/FunctionParser.h
b/cpp-ch/local-engine/Parser/FunctionParser.h
index d9ca7b5128..23216f2203 100644
--- a/cpp-ch/local-engine/Parser/FunctionParser.h
+++ b/cpp-ch/local-engine/Parser/FunctionParser.h
@@ -60,6 +60,7 @@ protected:
{
return parseFunctionArguments(substrait_func, actions_dag);
}
+
virtual DB::ActionsDAG::NodeRawConstPtrs
parseFunctionArguments(const substrait::Expression_ScalarFunction &
substrait_func, DB::ActionsDAG & actions_dag) const;
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]