This is an automated email from the ASF dual-hosted git repository.
changchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new d90bc4c00 [GLUTEN-4956][CH] Fix parsing string with blank prefix/suffix to number (#5022)
d90bc4c00 is described below
commit d90bc4c00baff9526c89db993a9b55ff46f1e5be
Author: 李扬 <[email protected]>
AuthorDate: Thu Mar 21 15:03:16 2024 +0800
[GLUTEN-4956][CH] Fix parsing string with blank prefix/suffix to number (#5022)
* fix parse string to int
* add uts
* fix failed uts
---
.../execution/GlutenFunctionValidateSuite.scala | 6 +
.../local-engine/Parser/SerializedPlanParser.cpp | 206 +++++++++++----------
cpp-ch/local-engine/Parser/SerializedPlanParser.h | 2 +-
3 files changed, 113 insertions(+), 101 deletions(-)
diff --git
a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenFunctionValidateSuite.scala
b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenFunctionValidateSuite.scala
index 8eb982693..818fe4e5f 100644
---
a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenFunctionValidateSuite.scala
+++
b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenFunctionValidateSuite.scala
@@ -633,4 +633,10 @@ class GlutenFunctionValidateSuite extends
GlutenClickHouseWholeStageTransformerS
""".stripMargin
runQueryAndCompare(sql)(checkOperatorMatch[ProjectExecTransformer])
}
+
+ test("test parse string with blank to integer") {
+ // issue https://github.com/apache/incubator-gluten/issues/4956
+ val sql = "select cast(concat(' ', cast(id as string)) as bigint) from
range(10)"
+ runQueryAndCompare(sql)(checkOperatorMatch[ProjectExecTransformer])
+ }
}
diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
index 8723489fc..f0715f500 100644
--- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
+++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
@@ -94,14 +94,14 @@ namespace DB
{
namespace ErrorCodes
{
-extern const int LOGICAL_ERROR;
-extern const int UNKNOWN_TYPE;
-extern const int BAD_ARGUMENTS;
-extern const int NO_SUCH_DATA_PART;
-extern const int UNKNOWN_FUNCTION;
-extern const int CANNOT_PARSE_PROTOBUF_SCHEMA;
-extern const int ILLEGAL_TYPE_OF_ARGUMENT;
-extern const int INVALID_JOIN_ON_EXPRESSION;
+ extern const int LOGICAL_ERROR;
+ extern const int UNKNOWN_TYPE;
+ extern const int BAD_ARGUMENTS;
+ extern const int NO_SUCH_DATA_PART;
+ extern const int UNKNOWN_FUNCTION;
+ extern const int CANNOT_PARSE_PROTOBUF_SCHEMA;
+ extern const int ILLEGAL_TYPE_OF_ARGUMENT;
+ extern const int INVALID_JOIN_ON_EXPRESSION;
}
}
@@ -136,7 +136,7 @@ void logDebugMessage(const google::protobuf::Message &
message, const char * typ
}
}
-const ActionsDAG::Node * SerializedPlanParser::addColumn(DB::ActionsDAGPtr
actions_dag, const DataTypePtr & type, const Field & field)
+const ActionsDAG::Node * SerializedPlanParser::addColumn(ActionsDAGPtr
actions_dag, const DataTypePtr & type, const Field & field)
{
return &actions_dag->addColumn(
ColumnWithTypeAndName(type->createColumnConst(1, field), type,
getUniqueName(toString(field).substr(0, 10))));
@@ -157,10 +157,10 @@ void SerializedPlanParser::parseExtensions(
}
}
-std::shared_ptr<DB::ActionsDAG> SerializedPlanParser::expressionsToActionsDAG(
+std::shared_ptr<ActionsDAG> SerializedPlanParser::expressionsToActionsDAG(
const std::vector<substrait::Expression> & expressions,
- const DB::Block & header,
- const DB::Block & read_schema)
+ const Block & header,
+ const Block & read_schema)
{
auto actions_dag =
std::make_shared<ActionsDAG>(blockToNameAndTypeList(header));
NamesWithAliases required_columns;
@@ -329,7 +329,7 @@ IQueryPlanStep *
SerializedPlanParser::addRemoveNullableStep(QueryPlan & plan, c
return nullptr;
auto remove_nullable_actions_dag =
std::make_shared<ActionsDAG>(blockToNameAndTypeList(plan.getCurrentDataStream().header));
- removeNullable(columns, remove_nullable_actions_dag);
+ removeNullableForRequiredColumns(columns, remove_nullable_actions_dag);
auto expression_step =
std::make_unique<ExpressionStep>(plan.getCurrentDataStream(),
remove_nullable_actions_dag);
expression_step->setStepDescription("Remove nullable properties");
auto * step_ptr = expression_step.get();
@@ -418,7 +418,7 @@ QueryPlanPtr
SerializedPlanParser::parse(std::unique_ptr<substrait::Plan> plan)
auto cols =
query_plan->getCurrentDataStream().header.getNamesAndTypesList();
if (cols.getNames().size() !=
static_cast<size_t>(root_rel.root().names_size()))
{
- throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Missmatch
result columns size.");
+ throw Exception(ErrorCodes::LOGICAL_ERROR, "Missmatch result
columns size.");
}
for (int i = 0; i < static_cast<int>(cols.getNames().size()); i++)
{
@@ -438,17 +438,17 @@ QueryPlanPtr
SerializedPlanParser::parse(std::unique_ptr<substrait::Plan> plan)
const auto & original_cols =
original_header.getColumnsWithTypeAndName();
if (static_cast<size_t>(output_schema.types_size()) !=
original_cols.size())
{
- throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Mismatch
output schema");
+ throw Exception(ErrorCodes::LOGICAL_ERROR, "Mismatch output
schema");
}
bool need_final_project = false;
- DB::ColumnsWithTypeAndName final_cols;
+ ColumnsWithTypeAndName final_cols;
for (int i = 0; i < output_schema.types_size(); ++i)
{
const auto & col = original_cols[i];
auto type = TypeParser::parseType(output_schema.types(i));
// At present, we only check nullable mismatch.
// intermediate aggregate data is special, no check here.
- if (type->isNullable() != col.type->isNullable() &&
!typeid_cast<const DB::DataTypeAggregateFunction *>(col.type.get()))
+ if (type->isNullable() != col.type->isNullable() &&
!typeid_cast<const DataTypeAggregateFunction *>(col.type.get()))
{
if (type->isNullable())
{
@@ -458,7 +458,7 @@ QueryPlanPtr
SerializedPlanParser::parse(std::unique_ptr<substrait::Plan> plan)
}
else
{
- final_cols.emplace_back(type->createColumn(),
DB::removeNullable(col.type), col.name);
+ final_cols.emplace_back(type->createColumn(),
removeNullable(col.type), col.name);
need_final_project = true;
}
}
@@ -708,7 +708,7 @@ SerializedPlanParser::getFunctionName(const std::string &
function_signature, co
}
void SerializedPlanParser::parseArrayJoinArguments(
- DB::ActionsDAGPtr & actions_dag,
+ ActionsDAGPtr & actions_dag,
const std::string & function_name,
const substrait::Expression_ScalarFunction & scalar_function,
bool position,
@@ -732,7 +732,7 @@ void SerializedPlanParser::parseArrayJoinArguments(
parseFunctionArguments(actions_dag, parsed_args, function_name_copy,
scalar_function);
auto arg = parsed_args[0];
- auto arg_type = DB::removeNullable(arg->result_type);
+ auto arg_type = removeNullable(arg->result_type);
if (isMap(arg_type))
is_map = true;
else if (isArray(arg_type))
@@ -767,7 +767,7 @@ void SerializedPlanParser::parseArrayJoinArguments(
ActionsDAG::NodeRawConstPtrs SerializedPlanParser::parseArrayJoinWithDAG(
const substrait::Expression & rel,
std::vector<String> & result_names,
- DB::ActionsDAGPtr actions_dag,
+ ActionsDAGPtr actions_dag,
bool keep_result,
bool position)
{
@@ -887,7 +887,7 @@ ActionsDAG::NodeRawConstPtrs
SerializedPlanParser::parseArrayJoinWithDAG(
const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG(
const substrait::Expression & rel,
std::string & result_name,
- DB::ActionsDAGPtr actions_dag,
+ ActionsDAGPtr actions_dag,
bool keep_result)
{
if (!rel.has_scalar_function())
@@ -924,7 +924,7 @@ const ActionsDAG::Node *
SerializedPlanParser::parseFunctionWithDAG(
/// to avoid exception
if (ch_func_name == "formatDateTimeInJodaSyntax")
{
- if (args.size() > 1 &&
isInteger(DB::removeNullable(args[0]->result_type)))
+ if (args.size() > 1 && isInteger(removeNullable(args[0]->result_type)))
ch_func_name = "fromUnixTimestampInJodaSyntax";
}
@@ -1007,7 +1007,7 @@ const ActionsDAG::Node *
SerializedPlanParser::parseFunctionWithDAG(
? local_engine::wrapNullableType(true, result_type)->getName()
: local_engine::removeNullable(result_type)->getName(),
function_node->result_name,
- DB::CastType::accurateOrNull);
+ CastType::accurateOrNull);
}
else
{
@@ -1044,12 +1044,14 @@ bool
SerializedPlanParser::convertBinaryArithmeticFunDecimalArgs(
{
/// for divide/plus/minus, we need to convert first arg to result
precision and scale
/// for multiply, we need to convert first arg to result precision,
but keep scale
- if (isDecimalOrNullableDecimal(args[0]->result_type) &&
isDecimalOrNullableDecimal(args[1]->result_type))
+ auto arg1_type = removeNullable(args[0]->result_type);
+ auto arg2_type = removeNullable(args[1]->result_type);
+ if (isDecimal(arg1_type) && isDecimal(arg2_type))
{
- UInt32 p1 =
getDecimalPrecision(*DB::removeNullable(args[0]->result_type));
- UInt32 s1 =
getDecimalScale(*DB::removeNullable(args[0]->result_type));
- UInt32 p2 =
getDecimalPrecision(*DB::removeNullable(args[1]->result_type));
- UInt32 s2 =
getDecimalScale(*DB::removeNullable(args[1]->result_type));
+ UInt32 p1 = getDecimalPrecision(*arg1_type);
+ UInt32 s1 = getDecimalScale(*arg1_type);
+ UInt32 p2 = getDecimalPrecision(*arg2_type);
+ UInt32 s2 = getDecimalScale(*arg2_type);
UInt32 precision;
UInt32 scale;
@@ -1100,7 +1102,7 @@ bool
SerializedPlanParser::convertBinaryArithmeticFunDecimalArgs(
}
void SerializedPlanParser::parseFunctionArguments(
- DB::ActionsDAGPtr & actions_dag,
+ ActionsDAGPtr & actions_dag,
ActionsDAG::NodeRawConstPtrs & parsed_args,
std::string & function_name,
const substrait::Expression_ScalarFunction & scalar_function)
@@ -1114,7 +1116,7 @@ void SerializedPlanParser::parseFunctionArguments(
{
parseFunctionArgument(actions_dag, parsed_args, function_name,
args[0]);
auto data_type = TypeParser::parseType(scalar_function.output_type());
- parsed_args.emplace_back(addColumn(actions_dag,
std::make_shared<DB::DataTypeString>(), data_type->getName()));
+ parsed_args.emplace_back(addColumn(actions_dag,
std::make_shared<DataTypeString>(), data_type->getName()));
}
else if (function_name == "sparkTupleElement" || function_name ==
"tupleElement")
{
@@ -1124,12 +1126,12 @@ void SerializedPlanParser::parseFunctionArguments(
throw Exception(ErrorCodes::BAD_ARGUMENTS, "get_struct_field's
second argument must be a literal");
auto [data_type, field] = parseLiteral(args[1].value().literal());
- if (data_type->getTypeId() != DB::TypeIndex::Int32)
+ if (data_type->getTypeId() != TypeIndex::Int32)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "get_struct_field's
second argument must be i32");
// tuple indecies start from 1, in spark, start from 0
Int32 field_index = static_cast<Int32>(field.get<Int32>() + 1);
- const auto * index_node = addColumn(actions_dag,
std::make_shared<DB::DataTypeUInt32>(), field_index);
+ const auto * index_node = addColumn(actions_dag,
std::make_shared<DataTypeUInt32>(), field_index);
parsed_args.emplace_back(index_node);
}
else if (function_name == "tuple")
@@ -1144,26 +1146,26 @@ void SerializedPlanParser::parseFunctionArguments(
// repeat. the field index must be unsigned integer in CH, cast the
signed integer in substrait
// which must be a positive value into unsigned integer here.
parseFunctionArgument(actions_dag, parsed_args, function_name,
args[0]);
- const DB::ActionsDAG::Node * repeat_times_node =
parseFunctionArgument(actions_dag, function_name, args[1]);
- DB::DataTypeNullable
target_type(std::make_shared<DB::DataTypeUInt32>());
+ const ActionsDAG::Node * repeat_times_node =
parseFunctionArgument(actions_dag, function_name, args[1]);
+ DataTypeNullable target_type(std::make_shared<DataTypeUInt32>());
repeat_times_node = ActionsDAGUtil::convertNodeType(actions_dag,
repeat_times_node, target_type.getName());
parsed_args.emplace_back(repeat_times_node);
}
else if (function_name == "isNaN")
{
// the result of isNaN(NULL) is NULL in CH, but false in Spark
- const DB::ActionsDAG::Node * arg_node = nullptr;
+ const ActionsDAG::Node * arg_node = nullptr;
if (args[0].value().has_cast())
{
arg_node = parseExpression(actions_dag,
args[0].value().cast().input());
const auto * res_type = arg_node->result_type.get();
if (res_type->isNullable())
{
- res_type = typeid_cast<const DB::DataTypeNullable
*>(res_type)->getNestedType().get();
+ res_type = typeid_cast<const DataTypeNullable
*>(res_type)->getNestedType().get();
}
if (isString(*res_type))
{
- DB::ActionsDAG::NodeRawConstPtrs cast_func_args = {arg_node};
+ ActionsDAG::NodeRawConstPtrs cast_func_args = {arg_node};
arg_node = toFunctionNode(actions_dag, "toFloat64OrZero",
cast_func_args);
}
else
@@ -1176,14 +1178,14 @@ void SerializedPlanParser::parseFunctionArguments(
arg_node = parseFunctionArgument(actions_dag, function_name,
args[0]);
}
- DB::ActionsDAG::NodeRawConstPtrs ifnull_func_args = {arg_node,
addColumn(actions_dag, std::make_shared<DataTypeInt32>(), 0)};
+ ActionsDAG::NodeRawConstPtrs ifnull_func_args = {arg_node,
addColumn(actions_dag, std::make_shared<DataTypeInt32>(), 0)};
parsed_args.emplace_back(toFunctionNode(actions_dag, "IfNull",
ifnull_func_args));
}
else if (function_name == "space")
{
// convert space function to repeat
- const DB::ActionsDAG::Node * repeat_times_node =
parseFunctionArgument(actions_dag, "repeat", args[0]);
- const DB::ActionsDAG::Node * space_str_node = addColumn(actions_dag,
std::make_shared<DataTypeString>(), " ");
+ const ActionsDAG::Node * repeat_times_node =
parseFunctionArgument(actions_dag, "repeat", args[0]);
+ const ActionsDAG::Node * space_str_node = addColumn(actions_dag,
std::make_shared<DataTypeString>(), " ");
function_name = "repeat";
parsed_args.emplace_back(space_str_node);
parsed_args.emplace_back(repeat_times_node);
@@ -1225,7 +1227,7 @@ void SerializedPlanParser::parseFunctionArguments(
}
void SerializedPlanParser::parseFunctionArgument(
- DB::ActionsDAGPtr & actions_dag,
+ ActionsDAGPtr & actions_dag,
ActionsDAG::NodeRawConstPtrs & parsed_args,
const std::string & function_name,
const substrait::FunctionArgument & arg)
@@ -1233,12 +1235,12 @@ void SerializedPlanParser::parseFunctionArgument(
parsed_args.emplace_back(parseFunctionArgument(actions_dag, function_name,
arg));
}
-const DB::ActionsDAG::Node * SerializedPlanParser::parseFunctionArgument(
- DB::ActionsDAGPtr & actions_dag,
+const ActionsDAG::Node * SerializedPlanParser::parseFunctionArgument(
+ ActionsDAGPtr & actions_dag,
const std::string & function_name,
const substrait::FunctionArgument & arg)
{
- const DB::ActionsDAG::Node * res;
+ const ActionsDAG::Node * res;
if (arg.value().has_scalar_function())
{
std::string arg_name;
@@ -1254,18 +1256,18 @@ const DB::ActionsDAG::Node *
SerializedPlanParser::parseFunctionArgument(
}
// Convert signed integer index into unsigned integer index
-std::pair<DB::DataTypePtr, DB::Field>
SerializedPlanParser::convertStructFieldType(const DB::DataTypePtr & type,
const DB::Field & field)
+std::pair<DataTypePtr, Field>
SerializedPlanParser::convertStructFieldType(const DataTypePtr & type, const
Field & field)
{
// For tupelElement, field index starts from 1, but int substrait plan, it
starts from 0.
#define UINT_CONVERT(type_ptr, field, type_name) \
- if ((type_ptr)->getTypeId() == DB::TypeIndex::type_name) \
+ if ((type_ptr)->getTypeId() == TypeIndex::type_name) \
{ \
- return {std::make_shared<DB::DataTypeU##type_name>(),
static_cast<U##type_name>((field).get<type_name>()) + 1}; \
+ return {std::make_shared<DataTypeU##type_name>(),
static_cast<U##type_name>((field).get<type_name>()) + 1}; \
}
auto type_id = type->getTypeId();
- if (type_id == DB::TypeIndex::UInt8 || type_id == DB::TypeIndex::UInt16 ||
type_id == DB::TypeIndex::UInt32
- || type_id == DB::TypeIndex::UInt64)
+ if (type_id == TypeIndex::UInt8 || type_id == TypeIndex::UInt16 || type_id
== TypeIndex::UInt32
+ || type_id == TypeIndex::UInt64)
{
return {type, field};
}
@@ -1273,7 +1275,7 @@ std::pair<DB::DataTypePtr, DB::Field>
SerializedPlanParser::convertStructFieldTy
UINT_CONVERT(type, field, Int16)
UINT_CONVERT(type, field, Int32)
UINT_CONVERT(type, field, Int64)
- throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Not valid
interger type: {}", type->getName());
+ throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Not valid interger
type: {}", type->getName());
#undef UINT_CONVERT
}
@@ -1349,16 +1351,17 @@ ActionsDAGPtr SerializedPlanParser::parseJsonTuple(
throw Exception(ErrorCodes::BAD_ARGUMENTS, "The function json_tuple
should has at least 2 arguments.");
}
auto first_arg = args[0].value();
- const DB::ActionsDAG::Node * json_expr_node = parseExpression(actions_dag,
first_arg);
+ const ActionsDAG::Node * json_expr_node = parseExpression(actions_dag,
first_arg);
std::string extract_expr = "Tuple(";
for (int i = 1; i < args.size(); i++)
{
auto arg_value = args[i].value();
if (!arg_value.has_literal())
{
- throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "The arguments
of function {} must be string literal", function_name);
+ throw Exception(ErrorCodes::BAD_ARGUMENTS, "The arguments of
function {} must be string literal", function_name);
}
- DB::Field f = arg_value.literal().string();
+
+ Field f = arg_value.literal().string();
std::string s;
if (f.tryGet(s))
{
@@ -1370,7 +1373,7 @@ ActionsDAGPtr SerializedPlanParser::parseJsonTuple(
}
}
extract_expr.append(")");
- const DB::ActionsDAG::Node * extract_expr_node = addColumn(actions_dag,
std::make_shared<DataTypeString>(), extract_expr);
+ const ActionsDAG::Node * extract_expr_node = addColumn(actions_dag,
std::make_shared<DataTypeString>(), extract_expr);
auto json_extract_builder = FunctionFactory::instance().get("JSONExtract",
context);
auto json_extract_result_name = "JSONExtract(" +
json_expr_node->result_name + "," + extract_expr_node->result_name + ")";
const ActionsDAG::Node * json_extract_node
@@ -1396,9 +1399,9 @@ ActionsDAGPtr SerializedPlanParser::parseJsonTuple(
}
const ActionsDAG::Node *
-SerializedPlanParser::toFunctionNode(ActionsDAGPtr actions_dag, const String &
function, const DB::ActionsDAG::NodeRawConstPtrs & args)
+SerializedPlanParser::toFunctionNode(ActionsDAGPtr actions_dag, const String &
function, const ActionsDAG::NodeRawConstPtrs & args)
{
- auto function_builder = DB::FunctionFactory::instance().get(function,
context);
+ auto function_builder = FunctionFactory::instance().get(function, context);
std::string args_name = join(args, ',');
auto result_name = function + "(" + args_name + ")";
const auto * function_node = &actions_dag->addFunction(function_builder,
args, result_name);
@@ -1632,48 +1635,51 @@ const ActionsDAG::Node *
SerializedPlanParser::parseExpression(ActionsDAGPtr act
case substrait::Expression::RexTypeCase::kCast: {
if (!rel.cast().has_type() || !rel.cast().has_input())
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Doesn't have type
or input in cast node.");
- DB::ActionsDAG::NodeRawConstPtrs args;
+ ActionsDAG::NodeRawConstPtrs args;
const auto & input = rel.cast().input();
args.emplace_back(parseExpression(actions_dag, input));
const auto & substrait_type = rel.cast().type();
+ const auto & input_type = args[0]->result_type;
+ DataTypePtr non_nullable_input_type = removeNullable(input_type);
+ DataTypePtr output_type = TypeParser::parseType(substrait_type);
+ DataTypePtr non_nullable_output_type = removeNullable(output_type);
+
const ActionsDAG::Node * function_node = nullptr;
- if (DB::isString(DB::removeNullable(args.back()->result_type)) &&
substrait_type.has_date())
+ if (substrait_type.has_binary())
{
- function_node = toFunctionNode(actions_dag, "sparkToDate",
args);
+ /// Spark cast(x as BINARY) -> CH reinterpretAsStringSpark(x)
+ function_node = toFunctionNode(actions_dag,
"reinterpretAsStringSpark", args);
}
- else if
(DB::isString(DB::removeNullable(args.back()->result_type)) &&
substrait_type.has_timestamp())
- {
+ else if (isString(non_nullable_input_type) &&
isDate32(non_nullable_output_type))
+ function_node = toFunctionNode(actions_dag, "sparkToDate",
args);
+ else if (isString(non_nullable_input_type) &&
isDateTime64(non_nullable_output_type))
function_node = toFunctionNode(actions_dag, "sparkToDateTime",
args);
+ else if (isDecimal(non_nullable_input_type) &&
isString(non_nullable_output_type))
+ {
+ /// Spark cast(x as STRING) if x is Decimal -> CH
toDecimalString(x, scale)
+ UInt8 scale = getDecimalScale(*non_nullable_input_type);
+ args.emplace_back(addColumn(actions_dag,
std::make_shared<DataTypeUInt8>(), Field(scale)));
+ function_node = toFunctionNode(actions_dag, "toDecimalString",
args);
}
- else if (substrait_type.has_binary())
+ else if (isFloat(non_nullable_input_type) &&
isInt(non_nullable_output_type))
{
- // Spark cast(x as BINARY) -> CH reinterpretAsStringSpark(x)
- function_node = toFunctionNode(actions_dag,
"reinterpretAsStringSpark", args);
+ String function_name = "sparkCastFloatTo" +
non_nullable_output_type->getName();
+ function_node = toFunctionNode(actions_dag, function_name,
args);
}
else
{
- DataTypePtr ch_type = TypeParser::parseType(substrait_type);
- if (DB::isString(DB::removeNullable(ch_type)) &&
isDecimalOrNullableDecimal(args[0]->result_type))
+ if (isString(non_nullable_input_type) &&
isInt(non_nullable_output_type))
{
- UInt8 scale =
getDecimalScale(*DB::removeNullable(args[0]->result_type));
- args.emplace_back(addColumn(actions_dag,
std::make_shared<DataTypeUInt8>(), Field(scale)));
- function_node = toFunctionNode(actions_dag,
"toDecimalString", args);
- }
- else
- {
- if (isFloat(DB::removeNullable(args[0]->result_type)) &&
isInt(DB::removeNullable(ch_type)))
- {
- String function_name = "sparkCastFloatTo" +
DB::removeNullable(ch_type)->getName();
- function_node = toFunctionNode(actions_dag,
function_name, args);
- }
- else
- {
- args.emplace_back(addColumn(actions_dag,
std::make_shared<DataTypeString>(), ch_type->getName()));
- function_node = toFunctionNode(actions_dag, "CAST",
args);
- }
+ /// Spark cast(x as INT) if x is String -> CH cast(trim(x)
as INT)
+ /// Refer to
https://github.com/apache/incubator-gluten/issues/4956
+ args[0] = toFunctionNode(actions_dag, "trim", {args[0]});
}
+
+ /// Common process
+ args.emplace_back(addColumn(actions_dag,
std::make_shared<DataTypeString>(), output_type->getName()));
+ function_node = toFunctionNode(actions_dag, "CAST", args);
}
actions_dag->addOrReplaceInOutputs(*function_node);
@@ -1682,13 +1688,13 @@ const ActionsDAG::Node *
SerializedPlanParser::parseExpression(ActionsDAGPtr act
case substrait::Expression::RexTypeCase::kIfThen: {
const auto & if_then = rel.if_then();
- DB::FunctionOverloadResolverPtr function_ptr = nullptr;
+ FunctionOverloadResolverPtr function_ptr = nullptr;
auto condition_nums = if_then.ifs_size();
if (condition_nums == 1)
- function_ptr = DB::FunctionFactory::instance().get("if",
context);
+ function_ptr = FunctionFactory::instance().get("if", context);
else
- function_ptr = DB::FunctionFactory::instance().get("multiIf",
context);
- DB::ActionsDAG::NodeRawConstPtrs args;
+ function_ptr = FunctionFactory::instance().get("multiIf",
context);
+ ActionsDAG::NodeRawConstPtrs args;
for (int i = 0; i < condition_nums; ++i)
{
@@ -1727,7 +1733,7 @@ const ActionsDAG::Node *
SerializedPlanParser::parseExpression(ActionsDAGPtr act
if (!options[0].has_literal())
throw Exception(ErrorCodes::LOGICAL_ERROR, "Options of
SingularOrList must have literal type");
- DB::ActionsDAG::NodeRawConstPtrs args;
+ ActionsDAG::NodeRawConstPtrs args;
args.emplace_back(parseExpression(actions_dag,
rel.singular_or_list().value()));
bool nullable = false;
@@ -1781,7 +1787,7 @@ const ActionsDAG::Node *
SerializedPlanParser::parseExpression(ActionsDAGPtr act
/// In CH: return `false`
/// So we used if(a, b, c) cast `false` to `null` if sets has
`null`
auto type = wrapNullableType(true, function_node->result_type);
- DB::ActionsDAG::NodeRawConstPtrs cast_args(
+ ActionsDAG::NodeRawConstPtrs cast_args(
{function_node, addColumn(actions_dag, type, true),
addColumn(actions_dag, type, Field())});
auto cast = FunctionFactory::instance().get("if", context);
function_node = toFunctionNode(actions_dag, "if", cast_args);
@@ -2004,7 +2010,7 @@ ASTPtr ASTParser::parseArgumentToAST(const Names & names,
const substrait::Expre
const auto & if_then = rel.if_then();
auto condition_nums = if_then.ifs_size();
std::string ch_function_name = condition_nums == 1 ? "if" :
"multiIf";
- auto function_multi_if =
DB::FunctionFactory::instance().get(ch_function_name, context);
+ auto function_multi_if =
FunctionFactory::instance().get(ch_function_name, context);
ASTs args;
for (int i = 0; i < condition_nums; ++i)
@@ -2088,7 +2094,7 @@ ASTPtr ASTParser::parseArgumentToAST(const Names & names,
const substrait::Expre
}
}
-void SerializedPlanParser::removeNullable(const std::set<String> &
require_columns, ActionsDAGPtr actions_dag)
+void SerializedPlanParser::removeNullableForRequiredColumns(const
std::set<String> & require_columns, ActionsDAGPtr actions_dag)
{
for (const auto & item : require_columns)
{
@@ -2140,14 +2146,14 @@ void LocalExecutor::execute(QueryPlanPtr query_plan)
current_query_plan = std::move(query_plan);
auto * logger = &Poco::Logger::get("LocalExecutor");
- DB::QueryPriorities priorities;
- auto query_status = std::make_shared<DB::QueryStatus>(
+ QueryPriorities priorities;
+ auto query_status = std::make_shared<QueryStatus>(
context,
"",
context->getClientInfo(),
priorities.insert(static_cast<int>(settings.priority)),
- DB::CurrentThread::getGroup(),
- DB::IAST::QueryKind::Select,
+ CurrentThread::getGroup(),
+ IAST::QueryKind::Select,
settings,
0);
@@ -2201,7 +2207,7 @@ bool LocalExecutor::hasNext()
has_next = true;
}
}
- catch (DB::Exception & e)
+ catch (Exception & e)
{
LOG_ERROR(
&Poco::Logger::get("LocalExecutor"),
@@ -2263,7 +2269,7 @@ std::string LocalExecutor::dumpPipeline()
const auto & processors = query_pipeline.getProcessors();
for (auto & processor : processors)
{
- DB::WriteBufferFromOwnString buffer;
+ WriteBufferFromOwnString buffer;
auto data_stats = processor->getProcessorDataStats();
buffer << "(";
buffer << "\nexcution time: " << processor->getElapsedUs() << " us.";
@@ -2276,13 +2282,13 @@ std::string LocalExecutor::dumpPipeline()
buffer << ")";
processor->setDescription(buffer.str());
}
- DB::WriteBufferFromOwnString out;
- DB::printPipeline(processors, out);
+ WriteBufferFromOwnString out;
+ printPipeline(processors, out);
return out.str();
}
NonNullableColumnsResolver::NonNullableColumnsResolver(
- const DB::Block & header_,
+ const Block & header_,
SerializedPlanParser & parser_,
const substrait::Expression & cond_rel_)
: header(header_)
diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h
b/cpp-ch/local-engine/Parser/SerializedPlanParser.h
index 1ecce8722..87c86f027 100644
--- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h
+++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h
@@ -384,7 +384,7 @@ private:
const ActionsDAG::Node *
toFunctionNode(ActionsDAGPtr actions_dag, const String & function, const
DB::ActionsDAG::NodeRawConstPtrs & args);
// remove nullable after isNotNull
- void removeNullable(const std::set<String> & require_columns,
ActionsDAGPtr actions_dag);
+ void removeNullableForRequiredColumns(const std::set<String> &
require_columns, ActionsDAGPtr actions_dag);
std::string getUniqueName(const std::string & name) { return name + "_" +
std::to_string(name_no++); }
static std::pair<DataTypePtr, Field> parseLiteral(const
substrait::Expression_Literal & literal);
void wrapNullable(
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]