This is an automated email from the ASF dual-hosted git repository.

lgbo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 5c000cd923 cast for values built from nothing types (#9042)
5c000cd923 is described below

commit 5c000cd923c1f5dec7397bf5b6916edc6df52d6f
Author: lgbo <[email protected]>
AuthorDate: Wed Mar 19 12:13:55 2025 +0800

    cast for values built from nothing types (#9042)
---
 .../GlutenClickHouseTPCHSaltNullParquetSuite.scala |  67 +++++++
 cpp-ch/local-engine/Parser/ExpressionParser.cpp    | 216 +++++++++++++--------
 cpp-ch/local-engine/Parser/ExpressionParser.h      |   5 +
 3 files changed, 205 insertions(+), 83 deletions(-)

diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
index 3ce03565e8..affc35bc27 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
@@ -3396,5 +3396,72 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends 
GlutenClickHouseTPCHAbstr
     compareResultsAgainstVanillaSpark(select_sql, true, { _ => })
   }
 
+  test("GLUTEN-9032 default values from nothing types") {
+    val sql1 =
+      """
+        |select a, b, c from (
+        |  select
+        |    n_regionkey as a, n_nationkey as b, array() as c
+        |  from nation where n_nationkey % 2 = 1
+        |  union all
+        |  select n_nationkey as a, n_regionkey as b, array('123') as c
+        |  from nation where n_nationkey % 2 = 0
+        |)
+      """.stripMargin
+    compareResultsAgainstVanillaSpark(sql1, true, { _ => })
+
+    val sql2 =
+      """
+        |select a, b, c from (
+        |  select
+        |    n_regionkey as a, n_nationkey as b, array() as c
+        |  from nation where n_nationkey % 2 = 1
+        |  union all
+        |  select n_nationkey as a, n_regionkey as b, array('123', null) as c
+        |  from nation where n_nationkey % 2 = 0
+        |)
+      """.stripMargin
+    compareResultsAgainstVanillaSpark(sql2, true, { _ => })
+
+    val sql3 =
+      """
+        |select a, b, c from (
+        |  select
+        |    n_regionkey as a, n_nationkey as b, array() as c
+        |  from nation where n_nationkey % 2 = 1
+        |  union all
+        |  select n_nationkey as a, n_regionkey as b, array(null) as c
+        |  from nation where n_nationkey % 2 = 0
+        |)
+      """.stripMargin
+    compareResultsAgainstVanillaSpark(sql3, true, { _ => })
+
+    val sql4 =
+      """
+        |select a, b, c from (
+        |  select
+        |    n_regionkey as a, n_nationkey as b, map() as c
+        |  from nation where n_nationkey % 2 = 1
+        |  union all
+        |  select n_nationkey as a, n_regionkey as b, map('123', 1) as c
+        |  from nation where n_nationkey % 2 = 0
+        |)
+      """.stripMargin
+    compareResultsAgainstVanillaSpark(sql4, true, { _ => })
+
+    val sql5 =
+      """
+        |select a, b, c from (
+        |  select
+        |    n_regionkey as a, n_nationkey as b, map() as c
+        |  from nation where n_nationkey % 2 = 1
+        |  union all
+        |  select n_nationkey as a, n_regionkey as b, map('123', null) as c
+        |  from nation where n_nationkey % 2 = 0
+        |)
+      """.stripMargin
+    compareResultsAgainstVanillaSpark(sql5, true, { _ => })
+
+  }
 }
 // scalastyle:on line.size.limit
diff --git a/cpp-ch/local-engine/Parser/ExpressionParser.cpp 
b/cpp-ch/local-engine/Parser/ExpressionParser.cpp
index 4261241f25..f297f94fc5 100644
--- a/cpp-ch/local-engine/Parser/ExpressionParser.cpp
+++ b/cpp-ch/local-engine/Parser/ExpressionParser.cpp
@@ -305,89 +305,7 @@ ExpressionParser::NodeRawConstPtr 
ExpressionParser::parseExpression(ActionsDAG &
         case substrait::Expression::RexTypeCase::kCast: {
             if (!rel.cast().has_type() || !rel.cast().has_input())
                 throw Exception(ErrorCodes::BAD_ARGUMENTS, "Doesn't have type 
or input in cast node.");
-            ActionsDAG::NodeRawConstPtrs args;
-
-            const auto & input = rel.cast().input();
-            args.emplace_back(parseExpression(actions_dag, input));
-
-            const auto & substrait_type = rel.cast().type();
-            const auto & input_type = args[0]->result_type;
-            DataTypePtr denull_input_type = removeNullable(input_type);
-            DataTypePtr output_type = TypeParser::parseType(substrait_type);
-            DataTypePtr denull_output_type = removeNullable(output_type);
-            const ActionsDAG::Node * result_node = nullptr;
-            if (substrait_type.has_binary())
-            {
-                /// Spark cast(x as BINARY) -> CH reinterpretAsStringSpark(x)
-                result_node = toFunctionNode(actions_dag, 
"reinterpretAsStringSpark", args);
-            }
-            else if (isString(denull_input_type) && 
isDate32(denull_output_type))
-                result_node = toFunctionNode(actions_dag, "sparkToDate", args);
-            else if (isString(denull_input_type) && 
isDateTime64(denull_output_type))
-                result_node = toFunctionNode(actions_dag, "sparkToDateTime", 
args);
-            else if (isDecimal(denull_input_type) && 
isString(denull_output_type))
-            {
-                /// Spark cast(x as STRING) if x is Decimal -> CH 
toDecimalString(x, scale)
-                UInt8 scale = getDecimalScale(*denull_input_type);
-                args.emplace_back(addConstColumn(actions_dag, 
std::make_shared<DataTypeUInt8>(), Field(scale)));
-                result_node = toFunctionNode(actions_dag, "toDecimalString", 
args);
-            }
-            else if (isFloat(denull_input_type) && isInt(denull_output_type))
-            {
-                String function_name = "sparkCastFloatTo" + 
denull_output_type->getName();
-                result_node = toFunctionNode(actions_dag, function_name, args);
-            }
-            else if (isFloat(denull_input_type) && 
isString(denull_output_type))
-                result_node = toFunctionNode(actions_dag, 
"sparkCastFloatToString", args);
-            else if ((isDecimal(denull_input_type) || 
isNativeNumber(denull_input_type)) && substrait_type.has_decimal())
-            {
-                int precision = substrait_type.decimal().precision();
-                int scale = substrait_type.decimal().scale();
-                if (precision)
-                {
-                    args.emplace_back(addConstColumn(actions_dag, 
std::make_shared<DataTypeInt32>(), precision));
-                    args.emplace_back(addConstColumn(actions_dag, 
std::make_shared<DataTypeInt32>(), scale));
-                    result_node = toFunctionNode(actions_dag, 
"checkDecimalOverflowSparkOrNull", args);
-                }
-            }
-            else if (isMap(denull_input_type) && isString(denull_output_type))
-            {
-                // ISSUE-7389: spark cast(map to string) has different 
behavior with CH cast(map to string)
-                auto map_input_type = std::static_pointer_cast<const 
DataTypeMap>(denull_input_type);
-                args.emplace_back(addConstColumn(actions_dag, 
map_input_type->getKeyType(), map_input_type->getKeyType()->getDefault()));
-                args.emplace_back(
-                    addConstColumn(actions_dag, 
map_input_type->getValueType(), map_input_type->getValueType()->getDefault()));
-                result_node = toFunctionNode(actions_dag, 
"sparkCastMapToString", args);
-            }
-            else if (isArray(denull_input_type) && 
isString(denull_output_type))
-            {
-                // ISSUE-7602: spark cast(array to string) has different 
result with CH cast(array to string)
-                result_node = toFunctionNode(actions_dag, 
"sparkCastArrayToString", args);
-            }
-            else if (isString(denull_input_type) && substrait_type.has_bool_())
-            {
-                /// cast(string to boolean)
-                args.emplace_back(addConstColumn(actions_dag, 
std::make_shared<DataTypeString>(), output_type->getName()));
-                result_node = toFunctionNode(actions_dag, 
"accurateCastOrNull", args);
-            }
-            else if (isString(denull_input_type) && isInt(denull_output_type))
-            {
-                /// Spark cast(x as INT) if x is String -> CH cast(trim(x) as 
INT)
-                /// Refer to 
https://github.com/apache/incubator-gluten/issues/4956 and 
https://github.com/apache/incubator-gluten/issues/8598
-                auto trim_str_arg = addConstColumn(actions_dag, 
std::make_shared<DataTypeString>(), " \t\n\r\f");
-                args[0] = toFunctionNode(actions_dag, "trimBothSpark", 
{args[0], trim_str_arg});
-                args.emplace_back(addConstColumn(actions_dag, 
std::make_shared<DataTypeString>(), output_type->getName()));
-                result_node = toFunctionNode(actions_dag, "CAST", args);
-            }
-            else
-            {
-                /// Common process: CAST(input, type)
-                args.emplace_back(addConstColumn(actions_dag, 
std::make_shared<DataTypeString>(), output_type->getName()));
-                result_node = toFunctionNode(actions_dag, "CAST", args);
-            }
-
-            actions_dag.addOrReplaceInOutputs(*result_node);
-            return result_node;
+            return parseCast(actions_dag, rel);
         }
 
         case substrait::Expression::RexTypeCase::kIfThen: {
@@ -516,6 +434,138 @@ ExpressionParser::NodeRawConstPtr 
ExpressionParser::parseExpression(ActionsDAG &
     }
 }
 
+bool ExpressionParser::isValueFromNothingType(const substrait::Expression & 
expr) const
+{
+    const auto & cast_input = expr.cast().input();
+    // null literal
+    if (cast_input.has_literal() && cast_input.literal().has_null() && 
cast_input.literal().null().has_nothing())
+        return true;
+    else if (cast_input.has_scalar_function())
+    {
+        auto function_name = 
getFunctionNameInSignature(cast_input.scalar_function());
+        // empty map
+        if (cast_input.scalar_function().output_type().has_map())
+        {
+            const auto & map_type = 
cast_input.scalar_function().output_type().map();
+            if (map_type.key().has_nothing() && map_type.value().has_nothing())
+                return true;
+        }
+        // empty array
+        else if (cast_input.scalar_function().output_type().has_list())
+        {
+            const auto & list_type = 
cast_input.scalar_function().output_type().list();
+            if (list_type.type().has_nothing())
+                return true;
+        }
+    }
+    return false;
+}
+
+ExpressionParser::NodeRawConstPtr ExpressionParser::parseCast(DB::ActionsDAG & 
actions_dag, const substrait::Expression & cast_expr) const
+{
+    if (isValueFromNothingType(cast_expr))
+        return parseNothingValuesCast(actions_dag, cast_expr);
+    return parseNormalValuesCast(actions_dag, cast_expr);
+}
+
+// Build a default value from the output type of `cast` when the `cast`'s 
input is built from `nothing` type.
+// `nothing` type is wrapped in nullable in `TypeParser`, which could cause a 
nullability mismatch.
+ExpressionParser::NodeRawConstPtr
+ExpressionParser::parseNothingValuesCast(DB::ActionsDAG & actions_dag, const 
substrait::Expression & cast_expr) const
+{
+    auto ch_type = TypeParser::parseType(cast_expr.cast().type());
+    // use the target type to create the default value.
+    auto default_value = ch_type->getDefault();
+    return addConstColumn(actions_dag, ch_type, default_value);
+}
+
+ExpressionParser::NodeRawConstPtr
+ExpressionParser::parseNormalValuesCast(DB::ActionsDAG & actions_dag, const 
substrait::Expression & cast_expr) const
+{
+    ActionsDAG::NodeRawConstPtrs args;
+    const auto & input = cast_expr.cast().input();
+    args.emplace_back(parseExpression(actions_dag, input));
+
+    const auto & substrait_type = cast_expr.cast().type();
+    const auto & input_type = args[0]->result_type;
+    DataTypePtr denull_input_type = removeNullable(input_type);
+    DataTypePtr output_type = TypeParser::parseType(substrait_type);
+    DataTypePtr denull_output_type = removeNullable(output_type);
+    const ActionsDAG::Node * result_node = nullptr;
+    if (substrait_type.has_binary())
+    {
+        /// Spark cast(x as BINARY) -> CH reinterpretAsStringSpark(x)
+        result_node = toFunctionNode(actions_dag, "reinterpretAsStringSpark", 
args);
+    }
+    else if (isString(denull_input_type) && isDate32(denull_output_type))
+        result_node = toFunctionNode(actions_dag, "sparkToDate", args);
+    else if (isString(denull_input_type) && isDateTime64(denull_output_type))
+        result_node = toFunctionNode(actions_dag, "sparkToDateTime", args);
+    else if (isDecimal(denull_input_type) && isString(denull_output_type))
+    {
+        /// Spark cast(x as STRING) if x is Decimal -> CH toDecimalString(x, 
scale)
+        UInt8 scale = getDecimalScale(*denull_input_type);
+        args.emplace_back(addConstColumn(actions_dag, 
std::make_shared<DataTypeUInt8>(), Field(scale)));
+        result_node = toFunctionNode(actions_dag, "toDecimalString", args);
+    }
+    else if (isFloat(denull_input_type) && isInt(denull_output_type))
+    {
+        String function_name = "sparkCastFloatTo" + 
denull_output_type->getName();
+        result_node = toFunctionNode(actions_dag, function_name, args);
+    }
+    else if (isFloat(denull_input_type) && isString(denull_output_type))
+        result_node = toFunctionNode(actions_dag, "sparkCastFloatToString", 
args);
+    else if ((isDecimal(denull_input_type) || 
isNativeNumber(denull_input_type)) && substrait_type.has_decimal())
+    {
+        int precision = substrait_type.decimal().precision();
+        int scale = substrait_type.decimal().scale();
+        if (precision)
+        {
+            args.emplace_back(addConstColumn(actions_dag, 
std::make_shared<DataTypeInt32>(), precision));
+            args.emplace_back(addConstColumn(actions_dag, 
std::make_shared<DataTypeInt32>(), scale));
+            result_node = toFunctionNode(actions_dag, 
"checkDecimalOverflowSparkOrNull", args);
+        }
+    }
+    else if (isMap(denull_input_type) && isString(denull_output_type))
+    {
+        // ISSUE-7389: spark cast(map to string) has different behavior with 
CH cast(map to string)
+        auto map_input_type = std::static_pointer_cast<const 
DataTypeMap>(denull_input_type);
+        args.emplace_back(addConstColumn(actions_dag, 
map_input_type->getKeyType(), map_input_type->getKeyType()->getDefault()));
+        args.emplace_back(addConstColumn(actions_dag, 
map_input_type->getValueType(), map_input_type->getValueType()->getDefault()));
+        result_node = toFunctionNode(actions_dag, "sparkCastMapToString", 
args);
+    }
+    else if (isArray(denull_input_type) && isString(denull_output_type))
+    {
+        // ISSUE-7602: spark cast(array to string) has different result with 
CH cast(array to string)
+        result_node = toFunctionNode(actions_dag, "sparkCastArrayToString", 
args);
+    }
+    else if (isString(denull_input_type) && substrait_type.has_bool_())
+    {
+        /// cast(string to boolean)
+        args.emplace_back(addConstColumn(actions_dag, 
std::make_shared<DataTypeString>(), output_type->getName()));
+        result_node = toFunctionNode(actions_dag, "accurateCastOrNull", args);
+    }
+    else if (isString(denull_input_type) && isInt(denull_output_type))
+    {
+        /// Spark cast(x as INT) if x is String -> CH cast(trim(x) as INT)
+        /// Refer to https://github.com/apache/incubator-gluten/issues/4956 
and https://github.com/apache/incubator-gluten/issues/8598
+        auto trim_str_arg = addConstColumn(actions_dag, 
std::make_shared<DataTypeString>(), " \t\n\r\f");
+        args[0] = toFunctionNode(actions_dag, "trimBothSpark", {args[0], 
trim_str_arg});
+        args.emplace_back(addConstColumn(actions_dag, 
std::make_shared<DataTypeString>(), output_type->getName()));
+        result_node = toFunctionNode(actions_dag, "CAST", args);
+    }
+    else
+    {
+        /// Common process: CAST(input, type)
+        args.emplace_back(addConstColumn(actions_dag, 
std::make_shared<DataTypeString>(), output_type->getName()));
+        result_node = toFunctionNode(actions_dag, "CAST", args);
+    }
+
+    actions_dag.addOrReplaceInOutputs(*result_node);
+    return result_node;
+}
+
+
 DB::ActionsDAG
 ExpressionParser::expressionsToActionsDAG(const 
std::vector<substrait::Expression> & expressions, const DB::Block & header) 
const
 {
diff --git a/cpp-ch/local-engine/Parser/ExpressionParser.h 
b/cpp-ch/local-engine/Parser/ExpressionParser.h
index 1e4a48282a..6e7ae8a924 100644
--- a/cpp-ch/local-engine/Parser/ExpressionParser.h
+++ b/cpp-ch/local-engine/Parser/ExpressionParser.h
@@ -89,5 +89,10 @@ private:
 
     static bool areEqualNodes(NodeRawConstPtr a, NodeRawConstPtr b);
     NodeRawConstPtr findFirstStructureEqualNode(NodeRawConstPtr target, const 
DB::ActionsDAG & actions_dag) const;
+
+    NodeRawConstPtr parseCast(DB::ActionsDAG & actions_dag, const 
substrait::Expression & cast_expr) const;
+    bool isValueFromNothingType(const substrait::Expression & expr) const;
+    NodeRawConstPtr parseNothingValuesCast(DB::ActionsDAG & actions_dag, const 
substrait::Expression & cast_expr) const;
+    NodeRawConstPtr parseNormalValuesCast(DB::ActionsDAG & actions_dag, const 
substrait::Expression & cast_expr) const;
 };
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to