This is an automated email from the ASF dual-hosted git repository.

changchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new d90bc4c00 [GLUTEN-4956][CH] Fix parsing string with blank prefix/suffix to number (#5022)
d90bc4c00 is described below

commit d90bc4c00baff9526c89db993a9b55ff46f1e5be
Author: 李扬 <[email protected]>
AuthorDate: Thu Mar 21 15:03:16 2024 +0800

    [GLUTEN-4956][CH] Fix parsing string with blank prefix/suffix to number (#5022)
    
    * Fix parsing strings (with leading/trailing blanks) to integers
    
    * Add unit tests
    
    * Fix failing unit tests
---
 .../execution/GlutenFunctionValidateSuite.scala    |   6 +
 .../local-engine/Parser/SerializedPlanParser.cpp   | 206 +++++++++++----------
 cpp-ch/local-engine/Parser/SerializedPlanParser.h  |   2 +-
 3 files changed, 113 insertions(+), 101 deletions(-)

diff --git 
a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenFunctionValidateSuite.scala
 
b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenFunctionValidateSuite.scala
index 8eb982693..818fe4e5f 100644
--- 
a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenFunctionValidateSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenFunctionValidateSuite.scala
@@ -633,4 +633,10 @@ class GlutenFunctionValidateSuite extends 
GlutenClickHouseWholeStageTransformerS
       """.stripMargin
     runQueryAndCompare(sql)(checkOperatorMatch[ProjectExecTransformer])
   }
+
+  test("test parse string with blank to integer") {
+    // issue https://github.com/apache/incubator-gluten/issues/4956
+    val sql = "select  cast(concat(' ', cast(id as string)) as bigint) from 
range(10)"
+    runQueryAndCompare(sql)(checkOperatorMatch[ProjectExecTransformer])
+  }
 }
diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp 
b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
index 8723489fc..f0715f500 100644
--- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
+++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
@@ -94,14 +94,14 @@ namespace DB
 {
 namespace ErrorCodes
 {
-extern const int LOGICAL_ERROR;
-extern const int UNKNOWN_TYPE;
-extern const int BAD_ARGUMENTS;
-extern const int NO_SUCH_DATA_PART;
-extern const int UNKNOWN_FUNCTION;
-extern const int CANNOT_PARSE_PROTOBUF_SCHEMA;
-extern const int ILLEGAL_TYPE_OF_ARGUMENT;
-extern const int INVALID_JOIN_ON_EXPRESSION;
+    extern const int LOGICAL_ERROR;
+    extern const int UNKNOWN_TYPE;
+    extern const int BAD_ARGUMENTS;
+    extern const int NO_SUCH_DATA_PART;
+    extern const int UNKNOWN_FUNCTION;
+    extern const int CANNOT_PARSE_PROTOBUF_SCHEMA;
+    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
+    extern const int INVALID_JOIN_ON_EXPRESSION;
 }
 }
 
@@ -136,7 +136,7 @@ void logDebugMessage(const google::protobuf::Message & 
message, const char * typ
     }
 }
 
-const ActionsDAG::Node * SerializedPlanParser::addColumn(DB::ActionsDAGPtr 
actions_dag, const DataTypePtr & type, const Field & field)
+const ActionsDAG::Node * SerializedPlanParser::addColumn(ActionsDAGPtr 
actions_dag, const DataTypePtr & type, const Field & field)
 {
     return &actions_dag->addColumn(
         ColumnWithTypeAndName(type->createColumnConst(1, field), type, 
getUniqueName(toString(field).substr(0, 10))));
@@ -157,10 +157,10 @@ void SerializedPlanParser::parseExtensions(
     }
 }
 
-std::shared_ptr<DB::ActionsDAG> SerializedPlanParser::expressionsToActionsDAG(
+std::shared_ptr<ActionsDAG> SerializedPlanParser::expressionsToActionsDAG(
     const std::vector<substrait::Expression> & expressions,
-    const DB::Block & header,
-    const DB::Block & read_schema)
+    const Block & header,
+    const Block & read_schema)
 {
     auto actions_dag = 
std::make_shared<ActionsDAG>(blockToNameAndTypeList(header));
     NamesWithAliases required_columns;
@@ -329,7 +329,7 @@ IQueryPlanStep * 
SerializedPlanParser::addRemoveNullableStep(QueryPlan & plan, c
         return nullptr;
 
     auto remove_nullable_actions_dag = 
std::make_shared<ActionsDAG>(blockToNameAndTypeList(plan.getCurrentDataStream().header));
-    removeNullable(columns, remove_nullable_actions_dag);
+    removeNullableForRequiredColumns(columns, remove_nullable_actions_dag);
     auto expression_step = 
std::make_unique<ExpressionStep>(plan.getCurrentDataStream(), 
remove_nullable_actions_dag);
     expression_step->setStepDescription("Remove nullable properties");
     auto * step_ptr = expression_step.get();
@@ -418,7 +418,7 @@ QueryPlanPtr 
SerializedPlanParser::parse(std::unique_ptr<substrait::Plan> plan)
             auto cols = 
query_plan->getCurrentDataStream().header.getNamesAndTypesList();
             if (cols.getNames().size() != 
static_cast<size_t>(root_rel.root().names_size()))
             {
-                throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Missmatch 
result columns size.");
+                throw Exception(ErrorCodes::LOGICAL_ERROR, "Missmatch result 
columns size.");
             }
             for (int i = 0; i < static_cast<int>(cols.getNames().size()); i++)
             {
@@ -438,17 +438,17 @@ QueryPlanPtr 
SerializedPlanParser::parse(std::unique_ptr<substrait::Plan> plan)
             const auto & original_cols = 
original_header.getColumnsWithTypeAndName();
             if (static_cast<size_t>(output_schema.types_size()) != 
original_cols.size())
             {
-                throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Mismatch 
output schema");
+                throw Exception(ErrorCodes::LOGICAL_ERROR, "Mismatch output 
schema");
             }
             bool need_final_project = false;
-            DB::ColumnsWithTypeAndName final_cols;
+            ColumnsWithTypeAndName final_cols;
             for (int i = 0; i < output_schema.types_size(); ++i)
             {
                 const auto & col = original_cols[i];
                 auto type = TypeParser::parseType(output_schema.types(i));
                 // At present, we only check nullable mismatch.
                 // intermediate aggregate data is special, no check here.
-                if (type->isNullable() != col.type->isNullable() && 
!typeid_cast<const DB::DataTypeAggregateFunction *>(col.type.get()))
+                if (type->isNullable() != col.type->isNullable() && 
!typeid_cast<const DataTypeAggregateFunction *>(col.type.get()))
                 {
                     if (type->isNullable())
                     {
@@ -458,7 +458,7 @@ QueryPlanPtr 
SerializedPlanParser::parse(std::unique_ptr<substrait::Plan> plan)
                     }
                     else
                     {
-                        final_cols.emplace_back(type->createColumn(), 
DB::removeNullable(col.type), col.name);
+                        final_cols.emplace_back(type->createColumn(), 
removeNullable(col.type), col.name);
                         need_final_project = true;
                     }
                 }
@@ -708,7 +708,7 @@ SerializedPlanParser::getFunctionName(const std::string & 
function_signature, co
 }
 
 void SerializedPlanParser::parseArrayJoinArguments(
-    DB::ActionsDAGPtr & actions_dag,
+    ActionsDAGPtr & actions_dag,
     const std::string & function_name,
     const substrait::Expression_ScalarFunction & scalar_function,
     bool position,
@@ -732,7 +732,7 @@ void SerializedPlanParser::parseArrayJoinArguments(
     parseFunctionArguments(actions_dag, parsed_args, function_name_copy, 
scalar_function);
 
     auto arg = parsed_args[0];
-    auto arg_type = DB::removeNullable(arg->result_type);
+    auto arg_type = removeNullable(arg->result_type);
     if (isMap(arg_type))
         is_map = true;
     else if (isArray(arg_type))
@@ -767,7 +767,7 @@ void SerializedPlanParser::parseArrayJoinArguments(
 ActionsDAG::NodeRawConstPtrs SerializedPlanParser::parseArrayJoinWithDAG(
     const substrait::Expression & rel,
     std::vector<String> & result_names,
-    DB::ActionsDAGPtr actions_dag,
+    ActionsDAGPtr actions_dag,
     bool keep_result,
     bool position)
 {
@@ -887,7 +887,7 @@ ActionsDAG::NodeRawConstPtrs 
SerializedPlanParser::parseArrayJoinWithDAG(
 const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG(
     const substrait::Expression & rel,
     std::string & result_name,
-    DB::ActionsDAGPtr actions_dag,
+    ActionsDAGPtr actions_dag,
     bool keep_result)
 {
     if (!rel.has_scalar_function())
@@ -924,7 +924,7 @@ const ActionsDAG::Node * 
SerializedPlanParser::parseFunctionWithDAG(
     /// to avoid exception
     if (ch_func_name == "formatDateTimeInJodaSyntax")
     {
-        if (args.size() > 1 && 
isInteger(DB::removeNullable(args[0]->result_type)))
+        if (args.size() > 1 && isInteger(removeNullable(args[0]->result_type)))
             ch_func_name = "fromUnixTimestampInJodaSyntax";
     }
 
@@ -1007,7 +1007,7 @@ const ActionsDAG::Node * 
SerializedPlanParser::parseFunctionWithDAG(
                 ? local_engine::wrapNullableType(true, result_type)->getName()
                 : local_engine::removeNullable(result_type)->getName(),
                 function_node->result_name,
-                DB::CastType::accurateOrNull);
+                CastType::accurateOrNull);
         }
         else
         {
@@ -1044,12 +1044,14 @@ bool 
SerializedPlanParser::convertBinaryArithmeticFunDecimalArgs(
     {
         /// for divide/plus/minus, we need to convert first arg to result 
precision and scale
         /// for multiply, we need to convert first arg to result precision, 
but keep scale
-        if (isDecimalOrNullableDecimal(args[0]->result_type) && 
isDecimalOrNullableDecimal(args[1]->result_type))
+        auto arg1_type = removeNullable(args[0]->result_type);
+        auto arg2_type = removeNullable(args[1]->result_type);
+        if (isDecimal(arg1_type) && isDecimal(arg2_type))
         {
-            UInt32 p1 = 
getDecimalPrecision(*DB::removeNullable(args[0]->result_type));
-            UInt32 s1 = 
getDecimalScale(*DB::removeNullable(args[0]->result_type));
-            UInt32 p2 = 
getDecimalPrecision(*DB::removeNullable(args[1]->result_type));
-            UInt32 s2 = 
getDecimalScale(*DB::removeNullable(args[1]->result_type));
+            UInt32 p1 = getDecimalPrecision(*arg1_type);
+            UInt32 s1 = getDecimalScale(*arg1_type);
+            UInt32 p2 = getDecimalPrecision(*arg2_type);
+            UInt32 s2 = getDecimalScale(*arg2_type);
 
             UInt32 precision;
             UInt32 scale;
@@ -1100,7 +1102,7 @@ bool 
SerializedPlanParser::convertBinaryArithmeticFunDecimalArgs(
 }
 
 void SerializedPlanParser::parseFunctionArguments(
-    DB::ActionsDAGPtr & actions_dag,
+    ActionsDAGPtr & actions_dag,
     ActionsDAG::NodeRawConstPtrs & parsed_args,
     std::string & function_name,
     const substrait::Expression_ScalarFunction & scalar_function)
@@ -1114,7 +1116,7 @@ void SerializedPlanParser::parseFunctionArguments(
     {
         parseFunctionArgument(actions_dag, parsed_args, function_name, 
args[0]);
         auto data_type = TypeParser::parseType(scalar_function.output_type());
-        parsed_args.emplace_back(addColumn(actions_dag, 
std::make_shared<DB::DataTypeString>(), data_type->getName()));
+        parsed_args.emplace_back(addColumn(actions_dag, 
std::make_shared<DataTypeString>(), data_type->getName()));
     }
     else if (function_name == "sparkTupleElement" || function_name == 
"tupleElement")
     {
@@ -1124,12 +1126,12 @@ void SerializedPlanParser::parseFunctionArguments(
             throw Exception(ErrorCodes::BAD_ARGUMENTS, "get_struct_field's 
second argument must be a literal");
 
         auto [data_type, field] = parseLiteral(args[1].value().literal());
-        if (data_type->getTypeId() != DB::TypeIndex::Int32)
+        if (data_type->getTypeId() != TypeIndex::Int32)
             throw Exception(ErrorCodes::BAD_ARGUMENTS, "get_struct_field's 
second argument must be i32");
 
         // tuple indecies start from 1, in spark, start from 0
         Int32 field_index = static_cast<Int32>(field.get<Int32>() + 1);
-        const auto * index_node = addColumn(actions_dag, 
std::make_shared<DB::DataTypeUInt32>(), field_index);
+        const auto * index_node = addColumn(actions_dag, 
std::make_shared<DataTypeUInt32>(), field_index);
         parsed_args.emplace_back(index_node);
     }
     else if (function_name == "tuple")
@@ -1144,26 +1146,26 @@ void SerializedPlanParser::parseFunctionArguments(
         // repeat. the field index must be unsigned integer in CH, cast the 
signed integer in substrait
         // which must be a positive value into unsigned integer here.
         parseFunctionArgument(actions_dag, parsed_args, function_name, 
args[0]);
-        const DB::ActionsDAG::Node * repeat_times_node = 
parseFunctionArgument(actions_dag, function_name, args[1]);
-        DB::DataTypeNullable 
target_type(std::make_shared<DB::DataTypeUInt32>());
+        const ActionsDAG::Node * repeat_times_node = 
parseFunctionArgument(actions_dag, function_name, args[1]);
+        DataTypeNullable target_type(std::make_shared<DataTypeUInt32>());
         repeat_times_node = ActionsDAGUtil::convertNodeType(actions_dag, 
repeat_times_node, target_type.getName());
         parsed_args.emplace_back(repeat_times_node);
     }
     else if (function_name == "isNaN")
     {
         // the result of isNaN(NULL) is NULL in CH, but false in Spark
-        const DB::ActionsDAG::Node * arg_node = nullptr;
+        const ActionsDAG::Node * arg_node = nullptr;
         if (args[0].value().has_cast())
         {
             arg_node = parseExpression(actions_dag, 
args[0].value().cast().input());
             const auto * res_type = arg_node->result_type.get();
             if (res_type->isNullable())
             {
-                res_type = typeid_cast<const DB::DataTypeNullable 
*>(res_type)->getNestedType().get();
+                res_type = typeid_cast<const DataTypeNullable 
*>(res_type)->getNestedType().get();
             }
             if (isString(*res_type))
             {
-                DB::ActionsDAG::NodeRawConstPtrs cast_func_args = {arg_node};
+                ActionsDAG::NodeRawConstPtrs cast_func_args = {arg_node};
                 arg_node = toFunctionNode(actions_dag, "toFloat64OrZero", 
cast_func_args);
             }
             else
@@ -1176,14 +1178,14 @@ void SerializedPlanParser::parseFunctionArguments(
             arg_node = parseFunctionArgument(actions_dag, function_name, 
args[0]);
         }
 
-        DB::ActionsDAG::NodeRawConstPtrs ifnull_func_args = {arg_node, 
addColumn(actions_dag, std::make_shared<DataTypeInt32>(), 0)};
+        ActionsDAG::NodeRawConstPtrs ifnull_func_args = {arg_node, 
addColumn(actions_dag, std::make_shared<DataTypeInt32>(), 0)};
         parsed_args.emplace_back(toFunctionNode(actions_dag, "IfNull", 
ifnull_func_args));
     }
     else if (function_name == "space")
     {
         // convert space function to repeat
-        const DB::ActionsDAG::Node * repeat_times_node = 
parseFunctionArgument(actions_dag, "repeat", args[0]);
-        const DB::ActionsDAG::Node * space_str_node = addColumn(actions_dag, 
std::make_shared<DataTypeString>(), " ");
+        const ActionsDAG::Node * repeat_times_node = 
parseFunctionArgument(actions_dag, "repeat", args[0]);
+        const ActionsDAG::Node * space_str_node = addColumn(actions_dag, 
std::make_shared<DataTypeString>(), " ");
         function_name = "repeat";
         parsed_args.emplace_back(space_str_node);
         parsed_args.emplace_back(repeat_times_node);
@@ -1225,7 +1227,7 @@ void SerializedPlanParser::parseFunctionArguments(
 }
 
 void SerializedPlanParser::parseFunctionArgument(
-    DB::ActionsDAGPtr & actions_dag,
+    ActionsDAGPtr & actions_dag,
     ActionsDAG::NodeRawConstPtrs & parsed_args,
     const std::string & function_name,
     const substrait::FunctionArgument & arg)
@@ -1233,12 +1235,12 @@ void SerializedPlanParser::parseFunctionArgument(
     parsed_args.emplace_back(parseFunctionArgument(actions_dag, function_name, 
arg));
 }
 
-const DB::ActionsDAG::Node * SerializedPlanParser::parseFunctionArgument(
-    DB::ActionsDAGPtr & actions_dag,
+const ActionsDAG::Node * SerializedPlanParser::parseFunctionArgument(
+    ActionsDAGPtr & actions_dag,
     const std::string & function_name,
     const substrait::FunctionArgument & arg)
 {
-    const DB::ActionsDAG::Node * res;
+    const ActionsDAG::Node * res;
     if (arg.value().has_scalar_function())
     {
         std::string arg_name;
@@ -1254,18 +1256,18 @@ const DB::ActionsDAG::Node * 
SerializedPlanParser::parseFunctionArgument(
 }
 
 // Convert signed integer index into unsigned integer index
-std::pair<DB::DataTypePtr, DB::Field> 
SerializedPlanParser::convertStructFieldType(const DB::DataTypePtr & type, 
const DB::Field & field)
+std::pair<DataTypePtr, Field> 
SerializedPlanParser::convertStructFieldType(const DataTypePtr & type, const 
Field & field)
 {
     // For tupelElement, field index starts from 1, but int substrait plan, it 
starts from 0.
 #define UINT_CONVERT(type_ptr, field, type_name) \
-    if ((type_ptr)->getTypeId() == DB::TypeIndex::type_name) \
+    if ((type_ptr)->getTypeId() == TypeIndex::type_name) \
     { \
-        return {std::make_shared<DB::DataTypeU##type_name>(), 
static_cast<U##type_name>((field).get<type_name>()) + 1}; \
+        return {std::make_shared<DataTypeU##type_name>(), 
static_cast<U##type_name>((field).get<type_name>()) + 1}; \
     }
 
     auto type_id = type->getTypeId();
-    if (type_id == DB::TypeIndex::UInt8 || type_id == DB::TypeIndex::UInt16 || 
type_id == DB::TypeIndex::UInt32
-        || type_id == DB::TypeIndex::UInt64)
+    if (type_id == TypeIndex::UInt8 || type_id == TypeIndex::UInt16 || type_id 
== TypeIndex::UInt32
+        || type_id == TypeIndex::UInt64)
     {
         return {type, field};
     }
@@ -1273,7 +1275,7 @@ std::pair<DB::DataTypePtr, DB::Field> 
SerializedPlanParser::convertStructFieldTy
     UINT_CONVERT(type, field, Int16)
     UINT_CONVERT(type, field, Int32)
     UINT_CONVERT(type, field, Int64)
-    throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Not valid 
interger type: {}", type->getName());
+    throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Not valid interger 
type: {}", type->getName());
 #undef UINT_CONVERT
 }
 
@@ -1349,16 +1351,17 @@ ActionsDAGPtr SerializedPlanParser::parseJsonTuple(
         throw Exception(ErrorCodes::BAD_ARGUMENTS, "The function json_tuple 
should has at least 2 arguments.");
     }
     auto first_arg = args[0].value();
-    const DB::ActionsDAG::Node * json_expr_node = parseExpression(actions_dag, 
first_arg);
+    const ActionsDAG::Node * json_expr_node = parseExpression(actions_dag, 
first_arg);
     std::string extract_expr = "Tuple(";
     for (int i = 1; i < args.size(); i++)
     {
         auto arg_value = args[i].value();
         if (!arg_value.has_literal())
         {
-            throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "The arguments 
of function {} must be string literal", function_name);
+            throw Exception(ErrorCodes::BAD_ARGUMENTS, "The arguments of 
function {} must be string literal", function_name);
         }
-        DB::Field f = arg_value.literal().string();
+
+        Field f = arg_value.literal().string();
         std::string s;
         if (f.tryGet(s))
         {
@@ -1370,7 +1373,7 @@ ActionsDAGPtr SerializedPlanParser::parseJsonTuple(
         }
     }
     extract_expr.append(")");
-    const DB::ActionsDAG::Node * extract_expr_node = addColumn(actions_dag, 
std::make_shared<DataTypeString>(), extract_expr);
+    const ActionsDAG::Node * extract_expr_node = addColumn(actions_dag, 
std::make_shared<DataTypeString>(), extract_expr);
     auto json_extract_builder = FunctionFactory::instance().get("JSONExtract", 
context);
     auto json_extract_result_name = "JSONExtract(" + 
json_expr_node->result_name + "," + extract_expr_node->result_name + ")";
     const ActionsDAG::Node * json_extract_node
@@ -1396,9 +1399,9 @@ ActionsDAGPtr SerializedPlanParser::parseJsonTuple(
 }
 
 const ActionsDAG::Node *
-SerializedPlanParser::toFunctionNode(ActionsDAGPtr actions_dag, const String & 
function, const DB::ActionsDAG::NodeRawConstPtrs & args)
+SerializedPlanParser::toFunctionNode(ActionsDAGPtr actions_dag, const String & 
function, const ActionsDAG::NodeRawConstPtrs & args)
 {
-    auto function_builder = DB::FunctionFactory::instance().get(function, 
context);
+    auto function_builder = FunctionFactory::instance().get(function, context);
     std::string args_name = join(args, ',');
     auto result_name = function + "(" + args_name + ")";
     const auto * function_node = &actions_dag->addFunction(function_builder, 
args, result_name);
@@ -1632,48 +1635,51 @@ const ActionsDAG::Node * 
SerializedPlanParser::parseExpression(ActionsDAGPtr act
         case substrait::Expression::RexTypeCase::kCast: {
             if (!rel.cast().has_type() || !rel.cast().has_input())
                 throw Exception(ErrorCodes::BAD_ARGUMENTS, "Doesn't have type 
or input in cast node.");
-            DB::ActionsDAG::NodeRawConstPtrs args;
+            ActionsDAG::NodeRawConstPtrs args;
 
             const auto & input = rel.cast().input();
             args.emplace_back(parseExpression(actions_dag, input));
 
             const auto & substrait_type = rel.cast().type();
+            const auto & input_type = args[0]->result_type;
+            DataTypePtr non_nullable_input_type = removeNullable(input_type);
+            DataTypePtr output_type = TypeParser::parseType(substrait_type);
+            DataTypePtr non_nullable_output_type = removeNullable(output_type);
+
             const ActionsDAG::Node * function_node = nullptr;
-            if (DB::isString(DB::removeNullable(args.back()->result_type)) && 
substrait_type.has_date())
+            if (substrait_type.has_binary())
             {
-                function_node = toFunctionNode(actions_dag, "sparkToDate", 
args);
+                /// Spark cast(x as BINARY) -> CH reinterpretAsStringSpark(x)
+                function_node = toFunctionNode(actions_dag, 
"reinterpretAsStringSpark", args);
             }
-            else if 
(DB::isString(DB::removeNullable(args.back()->result_type)) && 
substrait_type.has_timestamp())
-            {
+            else if (isString(non_nullable_input_type) && 
isDate32(non_nullable_output_type))
+                function_node = toFunctionNode(actions_dag, "sparkToDate", 
args);
+            else if (isString(non_nullable_input_type) && 
isDateTime64(non_nullable_output_type))
                 function_node = toFunctionNode(actions_dag, "sparkToDateTime", 
args);
+            else if (isDecimal(non_nullable_input_type) && 
isString(non_nullable_output_type))
+            {
+                /// Spark cast(x as STRING) if x is Decimal -> CH 
toDecimalString(x, scale)
+                UInt8 scale = getDecimalScale(*non_nullable_input_type);
+                args.emplace_back(addColumn(actions_dag, 
std::make_shared<DataTypeUInt8>(), Field(scale)));
+                function_node = toFunctionNode(actions_dag, "toDecimalString", 
args);
             }
-            else if (substrait_type.has_binary())
+            else if (isFloat(non_nullable_input_type) && 
isInt(non_nullable_output_type))
             {
-                // Spark cast(x as BINARY) -> CH reinterpretAsStringSpark(x)
-                function_node = toFunctionNode(actions_dag, 
"reinterpretAsStringSpark", args);
+                String function_name = "sparkCastFloatTo" + 
non_nullable_output_type->getName();
+                function_node = toFunctionNode(actions_dag, function_name, 
args);
             }
             else
             {
-                DataTypePtr ch_type = TypeParser::parseType(substrait_type);
-                if (DB::isString(DB::removeNullable(ch_type)) && 
isDecimalOrNullableDecimal(args[0]->result_type))
+                if (isString(non_nullable_input_type) && 
isInt(non_nullable_output_type))
                 {
-                    UInt8 scale = 
getDecimalScale(*DB::removeNullable(args[0]->result_type));
-                    args.emplace_back(addColumn(actions_dag, 
std::make_shared<DataTypeUInt8>(), Field(scale)));
-                    function_node = toFunctionNode(actions_dag, 
"toDecimalString", args);
-                }
-                else
-                {
-                    if (isFloat(DB::removeNullable(args[0]->result_type)) && 
isInt(DB::removeNullable(ch_type)))
-                    {
-                        String function_name = "sparkCastFloatTo" + 
DB::removeNullable(ch_type)->getName();
-                        function_node = toFunctionNode(actions_dag, 
function_name, args);
-                    }
-                    else
-                    {
-                        args.emplace_back(addColumn(actions_dag, 
std::make_shared<DataTypeString>(), ch_type->getName()));
-                        function_node = toFunctionNode(actions_dag, "CAST", 
args);
-                    }
+                    /// Spark cast(x as INT) if x is String -> CH cast(trim(x) 
as INT)
+                    /// Refer to 
https://github.com/apache/incubator-gluten/issues/4956
+                    args[0] = toFunctionNode(actions_dag, "trim", {args[0]});
                 }
+
+                /// Common process
+                args.emplace_back(addColumn(actions_dag, 
std::make_shared<DataTypeString>(), output_type->getName()));
+                function_node = toFunctionNode(actions_dag, "CAST", args);
             }
 
             actions_dag->addOrReplaceInOutputs(*function_node);
@@ -1682,13 +1688,13 @@ const ActionsDAG::Node * 
SerializedPlanParser::parseExpression(ActionsDAGPtr act
 
         case substrait::Expression::RexTypeCase::kIfThen: {
             const auto & if_then = rel.if_then();
-            DB::FunctionOverloadResolverPtr function_ptr = nullptr;
+            FunctionOverloadResolverPtr function_ptr = nullptr;
             auto condition_nums = if_then.ifs_size();
             if (condition_nums == 1)
-                function_ptr = DB::FunctionFactory::instance().get("if", 
context);
+                function_ptr = FunctionFactory::instance().get("if", context);
             else
-                function_ptr = DB::FunctionFactory::instance().get("multiIf", 
context);
-            DB::ActionsDAG::NodeRawConstPtrs args;
+                function_ptr = FunctionFactory::instance().get("multiIf", 
context);
+            ActionsDAG::NodeRawConstPtrs args;
 
             for (int i = 0; i < condition_nums; ++i)
             {
@@ -1727,7 +1733,7 @@ const ActionsDAG::Node * 
SerializedPlanParser::parseExpression(ActionsDAGPtr act
             if (!options[0].has_literal())
                 throw Exception(ErrorCodes::LOGICAL_ERROR, "Options of 
SingularOrList must have literal type");
 
-            DB::ActionsDAG::NodeRawConstPtrs args;
+            ActionsDAG::NodeRawConstPtrs args;
             args.emplace_back(parseExpression(actions_dag, 
rel.singular_or_list().value()));
 
             bool nullable = false;
@@ -1781,7 +1787,7 @@ const ActionsDAG::Node * 
SerializedPlanParser::parseExpression(ActionsDAGPtr act
                 /// In CH: return `false`
                 /// So we used if(a, b, c) cast `false` to `null` if sets has 
`null`
                 auto type = wrapNullableType(true, function_node->result_type);
-                DB::ActionsDAG::NodeRawConstPtrs cast_args(
+                ActionsDAG::NodeRawConstPtrs cast_args(
                     {function_node, addColumn(actions_dag, type, true), 
addColumn(actions_dag, type, Field())});
                 auto cast = FunctionFactory::instance().get("if", context);
                 function_node = toFunctionNode(actions_dag, "if", cast_args);
@@ -2004,7 +2010,7 @@ ASTPtr ASTParser::parseArgumentToAST(const Names & names, 
const substrait::Expre
             const auto & if_then = rel.if_then();
             auto condition_nums = if_then.ifs_size();
             std::string ch_function_name = condition_nums == 1 ? "if" : 
"multiIf";
-            auto function_multi_if = 
DB::FunctionFactory::instance().get(ch_function_name, context);
+            auto function_multi_if = 
FunctionFactory::instance().get(ch_function_name, context);
             ASTs args;
 
             for (int i = 0; i < condition_nums; ++i)
@@ -2088,7 +2094,7 @@ ASTPtr ASTParser::parseArgumentToAST(const Names & names, 
const substrait::Expre
     }
 }
 
-void SerializedPlanParser::removeNullable(const std::set<String> & 
require_columns, ActionsDAGPtr actions_dag)
+void SerializedPlanParser::removeNullableForRequiredColumns(const 
std::set<String> & require_columns, ActionsDAGPtr actions_dag)
 {
     for (const auto & item : require_columns)
     {
@@ -2140,14 +2146,14 @@ void LocalExecutor::execute(QueryPlanPtr query_plan)
     current_query_plan = std::move(query_plan);
     auto * logger = &Poco::Logger::get("LocalExecutor");
 
-    DB::QueryPriorities priorities;
-    auto query_status = std::make_shared<DB::QueryStatus>(
+    QueryPriorities priorities;
+    auto query_status = std::make_shared<QueryStatus>(
         context,
         "",
         context->getClientInfo(),
         priorities.insert(static_cast<int>(settings.priority)),
-        DB::CurrentThread::getGroup(),
-        DB::IAST::QueryKind::Select,
+        CurrentThread::getGroup(),
+        IAST::QueryKind::Select,
         settings,
         0);
 
@@ -2201,7 +2207,7 @@ bool LocalExecutor::hasNext()
             has_next = true;
         }
     }
-    catch (DB::Exception & e)
+    catch (Exception & e)
     {
         LOG_ERROR(
             &Poco::Logger::get("LocalExecutor"),
@@ -2263,7 +2269,7 @@ std::string LocalExecutor::dumpPipeline()
     const auto & processors = query_pipeline.getProcessors();
     for (auto & processor : processors)
     {
-        DB::WriteBufferFromOwnString buffer;
+        WriteBufferFromOwnString buffer;
         auto data_stats = processor->getProcessorDataStats();
         buffer << "(";
         buffer << "\nexcution time: " << processor->getElapsedUs() << " us.";
@@ -2276,13 +2282,13 @@ std::string LocalExecutor::dumpPipeline()
         buffer << ")";
         processor->setDescription(buffer.str());
     }
-    DB::WriteBufferFromOwnString out;
-    DB::printPipeline(processors, out);
+    WriteBufferFromOwnString out;
+    printPipeline(processors, out);
     return out.str();
 }
 
 NonNullableColumnsResolver::NonNullableColumnsResolver(
-    const DB::Block & header_,
+    const Block & header_,
     SerializedPlanParser & parser_,
     const substrait::Expression & cond_rel_)
     : header(header_)
diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h 
b/cpp-ch/local-engine/Parser/SerializedPlanParser.h
index 1ecce8722..87c86f027 100644
--- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h
+++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h
@@ -384,7 +384,7 @@ private:
     const ActionsDAG::Node *
     toFunctionNode(ActionsDAGPtr actions_dag, const String & function, const 
DB::ActionsDAG::NodeRawConstPtrs & args);
     // remove nullable after isNotNull
-    void removeNullable(const std::set<String> & require_columns, 
ActionsDAGPtr actions_dag);
+    void removeNullableForRequiredColumns(const std::set<String> & 
require_columns, ActionsDAGPtr actions_dag);
     std::string getUniqueName(const std::string & name) { return name + "_" + 
std::to_string(name_no++); }
     static std::pair<DataTypePtr, Field> parseLiteral(const 
substrait::Expression_Literal & literal);
     void wrapNullable(


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to