This is an automated email from the ASF dual-hosted git repository.
changchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new dba4439a2 [GLUTEN-6561][CH] Fix incompatible type exception thrown in
capture function while processing array literal with `transform` (#6601)
dba4439a2 is described below
commit dba4439a20ef9ffac78b14ff0f229460a0e84089
Author: 李扬 <[email protected]>
AuthorDate: Tue Jul 30 17:16:00 2024 +0800
[GLUTEN-6561][CH] Fix incompatible type exception thrown in capture
function while processing array literal with `transform` (#6601)
* fix style
* fix issue https://github.com/apache/incubator-gluten/issues/6561
* add uts
* add uts
* fix uts
* fix style
* ignore some checks when spark 3.3
---
.../GlutenClickHouseNativeWriteTableSuite.scala | 28 +++++-----------------
cpp-ch/local-engine/Common/CHUtil.cpp | 15 +++++++++++-
cpp-ch/local-engine/Common/CHUtil.h | 7 ++++++
.../CommonScalarFunctionParser.cpp | 2 +-
.../arrayHighOrderFunctions.cpp | 24 +++++++++++++------
5 files changed, 45 insertions(+), 31 deletions(-)
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala
index 99f946cd7..578c43292 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala
@@ -903,33 +903,17 @@ class GlutenClickHouseNativeWriteTableSuite
| ) partitioned by (day string)
| stored as $format""".stripMargin
- // FIXME:
- // Spark analyzer(>=3.4) will resolve map type to
- // map_from_arrays(transform(map_keys(map('t1','a','t2','b')), v->v),
- // transform(map_values(map('t1','a','t2','b')), v->v))
- // which cause core dump. see https://github.com/apache/incubator-gluten/issues/6561
- // for details.
val insert_sql =
- if (isSparkVersionLE("3.3")) {
- s"""insert overwrite $table_name partition (day)
- |select id as a,
- | str_to_map(concat('t1:','a','&t2:','b'),'&',':'),
- | struct('1', null) as c,
- | '2024-01-08' as day
- |from range(10)""".stripMargin
- } else {
- s"""insert overwrite $table_name partition (day)
- |select id as a,
- | map('t1', 'a', 't2', 'b'),
- | struct('1', null) as c,
- | '2024-01-08' as day
- |from range(10)""".stripMargin
- }
+ s"""insert overwrite $table_name partition (day)
+ |select id as a,
+ | str_to_map(concat('t1:','a','&t2:','b'),'&',':'),
+ | struct('1', null) as c,
+ | '2024-01-08' as day
+ |from range(10)""".stripMargin
(table_name, create_sql, insert_sql)
},
(table_name, _) =>
if (isSparkVersionGE("3.4")) {
- // FIXME: Don't Know Why Failed
compareResultsAgainstVanillaSpark(
s"select * from $table_name",
compareResult = true,
diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp b/cpp-ch/local-engine/Common/CHUtil.cpp
index a606a06b5..003accf00 100644
--- a/cpp-ch/local-engine/Common/CHUtil.cpp
+++ b/cpp-ch/local-engine/Common/CHUtil.cpp
@@ -477,6 +477,19 @@ const DB::ActionsDAG::Node * ActionsDAGUtil::convertNodeType(
DB::createInternalCastOverloadResolver(cast_type,
std::move(diagnostic)), std::move(children), result_name);
}
+const DB::ActionsDAG::Node * ActionsDAGUtil::convertNodeTypeIfNeeded(
+ DB::ActionsDAGPtr & actions_dag,
+ const DB::ActionsDAG::Node * node,
+ const DB::DataTypePtr & dst_type,
+ const std::string & result_name,
+ CastType cast_type)
+{
+ if (node->result_type->equals(*dst_type))
+ return node;
+
+ return convertNodeType(actions_dag, node, dst_type->getName(), result_name, cast_type);
+}
+
String QueryPipelineUtil::explainPipeline(DB::QueryPipeline & pipeline)
{
DB::WriteBufferFromOwnString buf;
@@ -844,7 +857,7 @@ void BackendInitializerUtil::initContexts(DB::Context::ConfigurationPtr config)
size_t index_uncompressed_cache_size =
config->getUInt64("index_uncompressed_cache_size",
DEFAULT_INDEX_UNCOMPRESSED_CACHE_MAX_SIZE);
double index_uncompressed_cache_size_ratio =
config->getDouble("index_uncompressed_cache_size_ratio",
DEFAULT_INDEX_UNCOMPRESSED_CACHE_SIZE_RATIO);
global_context->setIndexUncompressedCache(index_uncompressed_cache_policy,
index_uncompressed_cache_size, index_uncompressed_cache_size_ratio);
-
+
String index_mark_cache_policy =
config->getString("index_mark_cache_policy", DEFAULT_INDEX_MARK_CACHE_POLICY);
size_t index_mark_cache_size =
config->getUInt64("index_mark_cache_size", DEFAULT_INDEX_MARK_CACHE_MAX_SIZE);
double index_mark_cache_size_ratio =
config->getDouble("index_mark_cache_size_ratio",
DEFAULT_INDEX_MARK_CACHE_SIZE_RATIO);
diff --git a/cpp-ch/local-engine/Common/CHUtil.h b/cpp-ch/local-engine/Common/CHUtil.h
index 05b730552..8a2a32df3 100644
--- a/cpp-ch/local-engine/Common/CHUtil.h
+++ b/cpp-ch/local-engine/Common/CHUtil.h
@@ -132,6 +132,13 @@ public:
const std::string & type_name,
const std::string & result_name = "",
DB::CastType cast_type = DB::CastType::nonAccurate);
+
+ static const DB::ActionsDAG::Node * convertNodeTypeIfNeeded(
+ DB::ActionsDAGPtr & actions_dag,
+ const DB::ActionsDAG::Node * node,
+ const DB::DataTypePtr & dst_type,
+ const std::string & result_name = "",
+ DB::CastType cast_type = DB::CastType::nonAccurate);
};
class QueryPipelineUtil
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp
index e4855b507..726d1683d 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp
@@ -64,7 +64,7 @@ REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ToUnixTimestamp, to_unix_timestamp, parse
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Position, positive, identity);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Negative, negative, negate);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Pmod, pmod, pmod);
-REGISTER_COMMON_SCALAR_FUNCTION_PARSER(abs, abs, abs);
+REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Abs, abs, abs);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Ceil, ceil, ceil);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Round, round, roundHalfUp);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Bround, bround, roundBankers);
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp
index eacd72ed0..f9f093cba 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp
@@ -15,16 +15,17 @@
* limitations under the License.
*/
-#include <Parser/FunctionParser.h>
-#include <Common/Exception.h>
-#include <Poco/Logger.h>
-#include <Common/logger_useful.h>
-#include <Common/CHUtil.h>
+#include <Core/Types.h>
+#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeFunction.h>
#include <DataTypes/DataTypeNullable.h>
-#include <Core/Types.h>
+#include <Parser/FunctionParser.h>
#include <Parser/TypeParser.h>
#include <Parser/scalar_function_parser/lambdaFunction.h>
+#include <Poco/Logger.h>
+#include <Common/CHUtil.h>
+#include <Common/Exception.h>
+#include <Common/logger_useful.h>
namespace DB::ErrorCodes
{
@@ -90,7 +91,16 @@ public:
assert(parsed_args.size() == 2);
if (lambda_args.size() == 1)
{
- return toFunctionNode(actions_dag, ch_func_name, {parsed_args[1], parsed_args[0]});
+ /// Convert Array(T) to Array(U) if needed, Array(T) is the type of the first argument of transform.
+ /// U is the argument type of lambda function. In some cases Array(T) is not equal to Array(U).
+ /// e.g. in the second query of https://github.com/apache/incubator-gluten/issues/6561, T is String, and U is Nullable(String)
+ /// The difference of both types will result in runtime exceptions in function capture.
+ const auto & src_array_type = parsed_args[0]->result_type;
+ DataTypePtr dst_array_type = std::make_shared<DataTypeArray>(lambda_args.front().type);
+ if (isNullableOrLowCardinalityNullable(src_array_type))
+ dst_array_type = std::make_shared<DataTypeNullable>(dst_array_type);
+ const auto * dst_array_arg = ActionsDAGUtil::convertNodeTypeIfNeeded(actions_dag, parsed_args[0], dst_array_type);
+ return toFunctionNode(actions_dag, ch_func_name, {parsed_args[1], dst_array_arg});
}
/// transform with index argument.
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]