This is an automated email from the ASF dual-hosted git repository.

changchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new dba4439a2 [GLUTEN-6561][CH] Fix incompatible type exception thrown in 
capture function while processing array literal with `transform` (#6601)
dba4439a2 is described below

commit dba4439a20ef9ffac78b14ff0f229460a0e84089
Author: 李扬 <[email protected]>
AuthorDate: Tue Jul 30 17:16:00 2024 +0800

    [GLUTEN-6561][CH] Fix incompatible type exception thrown in capture 
function while processing array literal with `transform` (#6601)
    
    * fix style
    
    * fix issue https://github.com/apache/incubator-gluten/issues/6561
    
    * add uts
    
    * add uts
    
    * fix uts
    
    * fix style
    
    * ignore some checks when spark 3.3
---
 .../GlutenClickHouseNativeWriteTableSuite.scala    | 28 +++++-----------------
 cpp-ch/local-engine/Common/CHUtil.cpp              | 15 +++++++++++-
 cpp-ch/local-engine/Common/CHUtil.h                |  7 ++++++
 .../CommonScalarFunctionParser.cpp                 |  2 +-
 .../arrayHighOrderFunctions.cpp                    | 24 +++++++++++++------
 5 files changed, 45 insertions(+), 31 deletions(-)

diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala
index 99f946cd7..578c43292 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala
@@ -903,33 +903,17 @@ class GlutenClickHouseNativeWriteTableSuite
                | ) partitioned by (day string)
                | stored as $format""".stripMargin
 
-          // FIXME:
-          // Spark analyzer(>=3.4) will resolve map type to
-          //   map_from_arrays(transform(map_keys(map('t1','a','t2','b')), 
v->v),
-          //                   transform(map_values(map('t1','a','t2','b')), 
v->v))
-          // which cause core dump. see 
https://github.com/apache/incubator-gluten/issues/6561
-          // for details.
           val insert_sql =
-            if (isSparkVersionLE("3.3")) {
-              s"""insert overwrite $table_name partition (day)
-                 |select id as a,
-                 |       str_to_map(concat('t1:','a','&t2:','b'),'&',':'),
-                 |       struct('1', null) as c,
-                 |       '2024-01-08' as day
-                 |from range(10)""".stripMargin
-            } else {
-              s"""insert overwrite $table_name partition (day)
-                 |select id as a,
-                 |       map('t1', 'a', 't2', 'b'),
-                 |       struct('1', null) as c,
-                 |       '2024-01-08' as day
-                 |from range(10)""".stripMargin
-            }
+            s"""insert overwrite $table_name partition (day)
+               |select id as a,
+               |       str_to_map(concat('t1:','a','&t2:','b'),'&',':'),
+               |       struct('1', null) as c,
+               |       '2024-01-08' as day
+               |from range(10)""".stripMargin
           (table_name, create_sql, insert_sql)
       },
       (table_name, _) =>
         if (isSparkVersionGE("3.4")) {
-          // FIXME: Don't Know Why Failed
           compareResultsAgainstVanillaSpark(
             s"select * from $table_name",
             compareResult = true,
diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp 
b/cpp-ch/local-engine/Common/CHUtil.cpp
index a606a06b5..003accf00 100644
--- a/cpp-ch/local-engine/Common/CHUtil.cpp
+++ b/cpp-ch/local-engine/Common/CHUtil.cpp
@@ -477,6 +477,19 @@ const DB::ActionsDAG::Node * 
ActionsDAGUtil::convertNodeType(
         DB::createInternalCastOverloadResolver(cast_type, 
std::move(diagnostic)), std::move(children), result_name);
 }
 
+const DB::ActionsDAG::Node * ActionsDAGUtil::convertNodeTypeIfNeeded(
+    DB::ActionsDAGPtr & actions_dag,
+    const DB::ActionsDAG::Node * node,
+    const DB::DataTypePtr & dst_type,
+    const std::string & result_name,
+    CastType cast_type)
+{
+    if (node->result_type->equals(*dst_type))
+        return node;
+
+    return convertNodeType(actions_dag, node, dst_type->getName(), 
result_name, cast_type);
+}
+
 String QueryPipelineUtil::explainPipeline(DB::QueryPipeline & pipeline)
 {
     DB::WriteBufferFromOwnString buf;
@@ -844,7 +857,7 @@ void 
BackendInitializerUtil::initContexts(DB::Context::ConfigurationPtr config)
         size_t index_uncompressed_cache_size = 
config->getUInt64("index_uncompressed_cache_size", 
DEFAULT_INDEX_UNCOMPRESSED_CACHE_MAX_SIZE);
         double index_uncompressed_cache_size_ratio = 
config->getDouble("index_uncompressed_cache_size_ratio", 
DEFAULT_INDEX_UNCOMPRESSED_CACHE_SIZE_RATIO);
         
global_context->setIndexUncompressedCache(index_uncompressed_cache_policy, 
index_uncompressed_cache_size, index_uncompressed_cache_size_ratio);
-        
+
         String index_mark_cache_policy = 
config->getString("index_mark_cache_policy", DEFAULT_INDEX_MARK_CACHE_POLICY);
         size_t index_mark_cache_size = 
config->getUInt64("index_mark_cache_size", DEFAULT_INDEX_MARK_CACHE_MAX_SIZE);
         double index_mark_cache_size_ratio = 
config->getDouble("index_mark_cache_size_ratio", 
DEFAULT_INDEX_MARK_CACHE_SIZE_RATIO);
diff --git a/cpp-ch/local-engine/Common/CHUtil.h 
b/cpp-ch/local-engine/Common/CHUtil.h
index 05b730552..8a2a32df3 100644
--- a/cpp-ch/local-engine/Common/CHUtil.h
+++ b/cpp-ch/local-engine/Common/CHUtil.h
@@ -132,6 +132,13 @@ public:
         const std::string & type_name,
         const std::string & result_name = "",
         DB::CastType cast_type = DB::CastType::nonAccurate);
+
+    static const DB::ActionsDAG::Node * convertNodeTypeIfNeeded(
+        DB::ActionsDAGPtr & actions_dag,
+        const DB::ActionsDAG::Node * node,
+        const DB::DataTypePtr & dst_type,
+        const std::string & result_name = "",
+        DB::CastType cast_type = DB::CastType::nonAccurate);
 };
 
 class QueryPipelineUtil
diff --git 
a/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp
 
b/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp
index e4855b507..726d1683d 100644
--- 
a/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp
+++ 
b/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp
@@ -64,7 +64,7 @@ REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ToUnixTimestamp, 
to_unix_timestamp, parse
 REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Position, positive, identity);
 REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Negative, negative, negate);
 REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Pmod, pmod, pmod);
-REGISTER_COMMON_SCALAR_FUNCTION_PARSER(abs, abs, abs);
+REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Abs, abs, abs);
 REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Ceil, ceil, ceil);
 REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Round, round, roundHalfUp);
 REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Bround, bround, roundBankers);
diff --git 
a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp 
b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp
index eacd72ed0..f9f093cba 100644
--- 
a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp
+++ 
b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp
@@ -15,16 +15,17 @@
  * limitations under the License.
  */
 
-#include <Parser/FunctionParser.h>
-#include <Common/Exception.h>
-#include <Poco/Logger.h>
-#include <Common/logger_useful.h>
-#include <Common/CHUtil.h>
+#include <Core/Types.h>
+#include <DataTypes/DataTypeArray.h>
 #include <DataTypes/DataTypeFunction.h>
 #include <DataTypes/DataTypeNullable.h>
-#include <Core/Types.h>
+#include <Parser/FunctionParser.h>
 #include <Parser/TypeParser.h>
 #include <Parser/scalar_function_parser/lambdaFunction.h>
+#include <Poco/Logger.h>
+#include <Common/CHUtil.h>
+#include <Common/Exception.h>
+#include <Common/logger_useful.h>
 
 namespace DB::ErrorCodes
 {
@@ -90,7 +91,16 @@ public:
         assert(parsed_args.size() == 2);
         if (lambda_args.size() == 1)
         {
-            return toFunctionNode(actions_dag, ch_func_name, {parsed_args[1], 
parsed_args[0]});
+            /// Convert Array(T) to Array(U) if needed, Array(T) is the type 
of the first argument of transform.
+            /// U is the argument type of lambda function. In some cases 
Array(T) is not equal to Array(U).
+            /// e.g. in the second query of 
https://github.com/apache/incubator-gluten/issues/6561, T is String, and U is 
Nullable(String)
+            /// The difference of both types will result in runtime exceptions 
in function capture.
+            const auto & src_array_type = parsed_args[0]->result_type;
+            DataTypePtr dst_array_type = 
std::make_shared<DataTypeArray>(lambda_args.front().type);
+            if (isNullableOrLowCardinalityNullable(src_array_type))
+                dst_array_type = 
std::make_shared<DataTypeNullable>(dst_array_type);
+            const auto * dst_array_arg = 
ActionsDAGUtil::convertNodeTypeIfNeeded(actions_dag, parsed_args[0], 
dst_array_type);
+            return toFunctionNode(actions_dag, ch_func_name, {parsed_args[1], 
dst_array_arg});
         }
 
         /// transform with index argument.


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to