This is an automated email from the ASF dual-hosted git repository.
changchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 2ad689b8d [GLUTEN-1632][CH]Daily Update Clickhouse Version (20240815)
(#6848)
2ad689b8d is described below
commit 2ad689b8dae64b00a9d6d54a2317e5a9e6e3b48a
Author: Kyligence Git <[email protected]>
AuthorDate: Thu Aug 15 08:07:11 2024 -0500
[GLUTEN-1632][CH]Daily Update Clickhouse Version (20240815) (#6848)
* [GLUTEN-1632][CH]Daily Update Clickhouse Version (20240815)
* Fix Build due to https://github.com/ClickHouse/ClickHouse/pull/68107
* Fix Build due to https://github.com/ClickHouse/ClickHouse/pull/68135
* Fix UT due to https://github.com/ClickHouse/ClickHouse/pull/68135
* Add UT for https://github.com/apache/incubator-gluten/issues/2584
- Rebase failed with https://github.com/ClickHouse/ClickHouse/pull/67879,
and hence we can remove https://github.com/Kyligence/ClickHouse/pull/454
(cherry picked from commit 583aa8d6566a9e1c0924c1a3ab1d315fcc229fa6)
* Fix CH BUG due to https://github.com/ClickHouse/ClickHouse/pull/68135
see
https://github.com/Kyligence/ClickHouse/commit/d87dbba64fcbafa7dfbbe41647bcca8357fdd6cd
* Resolve conflict
---------
Co-authored-by: kyligence-git <[email protected]>
Co-authored-by: Chang Chen <[email protected]>
---
.../GlutenClickHouseNativeWriteTableSuite.scala | 33 +++++
.../apache/spark/gluten/NativeWriteChecker.scala | 6 +
cpp-ch/clickhouse.version | 5 +-
.../AggregateFunctionGroupBloomFilter.cpp | 4 +-
.../AggregateFunctionSparkAvg.cpp | 2 +-
cpp-ch/local-engine/Common/CHUtil.cpp | 15 +-
cpp-ch/local-engine/Common/CHUtil.h | 4 +-
.../Functions/SparkFunctionArraySort.cpp | 2 +-
.../Functions/SparkFunctionCheckDecimalOverflow.h | 14 +-
cpp-ch/local-engine/Functions/SparkFunctionFloor.h | 2 +-
.../Functions/SparkFunctionHashingExtended.h | 40 +++---
.../Functions/SparkFunctionMakeDecimal.cpp | 2 +-
.../Functions/SparkFunctionRoundHalfUp.h | 2 +-
.../Functions/SparkFunctionToDateTime.h | 2 +-
cpp-ch/local-engine/Operator/ExpandTransform.cpp | 2 +-
.../Parser/AggregateFunctionParser.cpp | 2 +-
cpp-ch/local-engine/Parser/CHColumnToSparkRow.cpp | 93 ++++--------
cpp-ch/local-engine/Parser/FunctionParser.cpp | 8 +-
.../local-engine/Parser/SerializedPlanParser.cpp | 10 +-
cpp-ch/local-engine/Parser/WriteRelParser.cpp | 4 +-
.../ApproxPercentileParser.cpp | 2 +-
.../BloomFilterAggParser.cpp | 4 +-
.../aggregate_function_parser/LeadLagParser.cpp | 9 +-
.../aggregate_function_parser/NtileParser.cpp | 2 +-
.../arrayHighOrderFunctions.cpp | 11 +-
.../scalar_function_parser/arrayPosition.cpp | 2 +-
.../Parser/scalar_function_parser/elt.cpp | 2 +-
.../Parser/scalar_function_parser/findInset.cpp | 5 +-
.../scalar_function_parser/lambdaFunction.cpp | 4 +-
.../Parser/scalar_function_parser/locate.cpp | 3 +-
.../Parser/scalar_function_parser/repeat.cpp | 4 +-
.../Parser/scalar_function_parser/slice.cpp | 2 +-
.../Parser/scalar_function_parser/tupleElement.cpp | 2 +-
.../Storages/Mergetree/SparkMergeTreeWriter.cpp | 8 +-
.../Storages/Parquet/ParquetConverter.h | 8 +-
.../local-engine/tests/data/68135.snappy.parquet | Bin 0 -> 461 bytes
.../tests/gtest_clickhouse_pr_verify.cpp | 24 +++-
.../tests/json/clickhouse_pr_68135.json | 160 +++++++++++++++++++++
38 files changed, 346 insertions(+), 158 deletions(-)
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala
index 9f5dc4d3c..652b15fc2 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala
@@ -937,4 +937,37 @@ class GlutenClickHouseNativeWriteTableSuite
_ => {})
)
}
+
+ test("GLUTEN-2584: fix native write and read mismatch about complex types") {
+ def table(format: String): String = s"t_$format"
+ def create(format: String, table_name: Option[String] = None): String =
+ s"""CREATE TABLE ${table_name.getOrElse(table(format))}(
+ | id INT,
+ | info STRUCT<name:STRING, age:INT>,
+ | data MAP<STRING, INT>,
+ | values ARRAY<INT>
+ |) stored as $format""".stripMargin
+ def insert(format: String, table_name: Option[String] = None): String =
+ s"""INSERT overwrite ${table_name.getOrElse(table(format))} VALUES
+ | (6, null, null, null);
+ """.stripMargin
+
+ nativeWrite2(
+ format => (table(format), create(format), insert(format)),
+ (table_name, format) => {
+ val vanilla_table = s"${table_name}_v"
+ val vanilla_create = create(format, Some(vanilla_table))
+ vanillaWrite {
+ withDestinationTable(vanilla_table, Option(vanilla_create)) {
+ checkInsertQuery(insert(format, Some(vanilla_table)), checkNative = false)
+ }
+ }
+ val rowsFromOriginTable =
+ spark.sql(s"select * from $vanilla_table").collect()
+ val dfFromWriteTable =
+ spark.sql(s"select * from $table_name")
+ checkAnswer(dfFromWriteTable, rowsFromOriginTable)
+ }
+ )
+ }
}
diff --git
a/backends-clickhouse/src/test/scala/org/apache/spark/gluten/NativeWriteChecker.scala
b/backends-clickhouse/src/test/scala/org/apache/spark/gluten/NativeWriteChecker.scala
index 4bee3f177..49e368c88 100644
---
a/backends-clickhouse/src/test/scala/org/apache/spark/gluten/NativeWriteChecker.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/spark/gluten/NativeWriteChecker.scala
@@ -75,6 +75,12 @@ trait NativeWriteChecker
}
}
+ def vanillaWrite(block: => Unit): Unit = {
+ withSQLConf(("spark.gluten.sql.native.writer.enabled", "false")) {
+ block
+ }
+ }
+
def withSource(df: Dataset[Row], viewName: String, pairs: (String, String)*)(
block: => Unit): Unit = {
withSQLConf(pairs: _*) {
diff --git a/cpp-ch/clickhouse.version b/cpp-ch/clickhouse.version
index 8068b57f2..7c93bc124 100644
--- a/cpp-ch/clickhouse.version
+++ b/cpp-ch/clickhouse.version
@@ -1,4 +1,3 @@
CH_ORG=Kyligence
-CH_BRANCH=rebase_ch/20240809
-CH_COMMIT=01e780d46d9
-
+CH_BRANCH=rebase_ch/20240815
+CH_COMMIT=d87dbba64fc
diff --git
a/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionGroupBloomFilter.cpp
b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionGroupBloomFilter.cpp
index 5555302a5..1b853cc67 100644
---
a/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionGroupBloomFilter.cpp
+++
b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionGroupBloomFilter.cpp
@@ -62,10 +62,10 @@ createAggregateFunctionBloomFilter(const std::string &
name, const DataTypes & a
if (type != Field::Types::Int64 && type != Field::Types::UInt64)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter for
aggregate function {} should be Int64 or UInt64", name);
- if ((type == Field::Types::Int64 && parameters[i].get<Int64>() <
0))
+ if ((type == Field::Types::Int64 && parameters[i].safeGet<Int64>()
< 0))
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter for
aggregate function {} should be non-negative number", name);
- return parameters[i].get<UInt64>();
+ return parameters[i].safeGet<UInt64>();
};
filter_size = get_parameter(0);
diff --git
a/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp
b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp
index 5eb3a0b36..0aa233145 100644
--- a/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp
+++ b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp
@@ -140,7 +140,7 @@ createAggregateFunctionSparkAvg(const std::string & name,
const DataTypes & argu
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument
for aggregate function {}", data_type->getName(), name);
- bool allowPrecisionLoss =
settings->get(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS).get<bool>();
+ bool allowPrecisionLoss =
settings->get(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS).safeGet<bool>();
const UInt32 p1 = DB::getDecimalPrecision(*data_type);
const UInt32 s1 = DB::getDecimalScale(*data_type);
auto [p2, s2] = GlutenDecimalUtils::LONG_DECIMAL;
diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp
b/cpp-ch/local-engine/Common/CHUtil.cpp
index 35b4f0c97..0409b66bd 100644
--- a/cpp-ch/local-engine/Common/CHUtil.cpp
+++ b/cpp-ch/local-engine/Common/CHUtil.cpp
@@ -51,11 +51,13 @@
#include <Interpreters/JIT/CompiledExpressionCache.h>
#include <Parser/RelParser.h>
#include <Parser/SerializedPlanParser.h>
+#include <Planner/PlannerActionsVisitor.h>
#include <Processors/Chunk.h>
#include <Processors/QueryPlan/ExpressionStep.h>
#include <Processors/QueryPlan/QueryPlan.h>
#include <QueryPipeline/QueryPipelineBuilder.h>
#include <QueryPipeline/printPipeline.h>
+#include <Storages/Cache/CacheManager.h>
#include <Storages/Output/WriteBufferBuilder.h>
#include <Storages/StorageMergeTreeFactory.h>
#include <Storages/SubstraitSource/ReadBufferBuilder.h>
@@ -72,7 +74,6 @@
#include <Common/LoggerExtend.h>
#include <Common/logger_useful.h>
#include <Common/typeid_cast.h>
-#include <Storages/Cache/CacheManager.h>
namespace DB
{
@@ -463,20 +464,22 @@ const DB::ColumnWithTypeAndName *
NestedColumnExtractHelper::findColumn(const DB
const DB::ActionsDAG::Node * ActionsDAGUtil::convertNodeType(
DB::ActionsDAG & actions_dag,
const DB::ActionsDAG::Node * node,
- const std::string & type_name,
+ const DataTypePtr & cast_to_type,
const std::string & result_name,
CastType cast_type)
{
DB::ColumnWithTypeAndName type_name_col;
- type_name_col.name = type_name;
+ type_name_col.name = cast_to_type->getName();
type_name_col.column = DB::DataTypeString().createColumnConst(0,
type_name_col.name);
type_name_col.type = std::make_shared<DB::DataTypeString>();
const auto * right_arg = &actions_dag.addColumn(std::move(type_name_col));
const auto * left_arg = node;
DB::CastDiagnostic diagnostic = {node->result_name, node->result_name};
+ ColumnWithTypeAndName left_column{nullptr, node->result_type, {}};
DB::ActionsDAG::NodeRawConstPtrs children = {left_arg, right_arg};
- return &actions_dag.addFunction(
- DB::createInternalCastOverloadResolver(cast_type,
std::move(diagnostic)), std::move(children), result_name);
+ auto func_base_cast = createInternalCast(std::move(left_column),
cast_to_type, cast_type, diagnostic);
+
+ return &actions_dag.addFunction(func_base_cast, std::move(children),
result_name);
}
const DB::ActionsDAG::Node * ActionsDAGUtil::convertNodeTypeIfNeeded(
@@ -489,7 +492,7 @@ const DB::ActionsDAG::Node *
ActionsDAGUtil::convertNodeTypeIfNeeded(
if (node->result_type->equals(*dst_type))
return node;
- return convertNodeType(actions_dag, node, dst_type->getName(),
result_name, cast_type);
+ return convertNodeType(actions_dag, node, dst_type, result_name,
cast_type);
}
String QueryPipelineUtil::explainPipeline(DB::QueryPipeline & pipeline)
diff --git a/cpp-ch/local-engine/Common/CHUtil.h
b/cpp-ch/local-engine/Common/CHUtil.h
index 785d5d6c0..c91b7264d 100644
--- a/cpp-ch/local-engine/Common/CHUtil.h
+++ b/cpp-ch/local-engine/Common/CHUtil.h
@@ -128,8 +128,8 @@ class ActionsDAGUtil
public:
static const DB::ActionsDAG::Node * convertNodeType(
DB::ActionsDAG & actions_dag,
- const DB::ActionsDAG::Node * node,
- const std::string & type_name,
+ const DB::ActionsDAG::Node * node_to_cast,
+ const DB::DataTypePtr & cast_to_type,
const std::string & result_name = "",
DB::CastType cast_type = DB::CastType::nonAccurate);
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionArraySort.cpp
b/cpp-ch/local-engine/Functions/SparkFunctionArraySort.cpp
index 1371ec60e..cf9d67f16 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionArraySort.cpp
+++ b/cpp-ch/local-engine/Functions/SparkFunctionArraySort.cpp
@@ -60,7 +60,7 @@ struct LambdaLess
auto compare_res_col = lambda_->reduce();
DB::Field field;
compare_res_col.column->get(0, field);
- return field.get<Int32>() < 0;
+ return field.safeGet<Int32>() < 0;
}
private:
ALWAYS_INLINE DB::ColumnPtr oneRowColumn(size_t i) const
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionCheckDecimalOverflow.h
b/cpp-ch/local-engine/Functions/SparkFunctionCheckDecimalOverflow.h
index 32bf79a56..e501c7fc5 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionCheckDecimalOverflow.h
+++ b/cpp-ch/local-engine/Functions/SparkFunctionCheckDecimalOverflow.h
@@ -50,17 +50,17 @@ template <typename To>
Field convertNumericType(const Field & from)
{
if (from.getType() == Field::Types::UInt64)
- return convertNumericTypeImpl<UInt64, To>(from.get<UInt64>());
+ return convertNumericTypeImpl<UInt64, To>(from.safeGet<UInt64>());
if (from.getType() == Field::Types::Int64)
- return convertNumericTypeImpl<Int64, To>(from.get<Int64>());
+ return convertNumericTypeImpl<Int64, To>(from.safeGet<Int64>());
if (from.getType() == Field::Types::UInt128)
- return convertNumericTypeImpl<UInt128, To>(from.get<UInt128>());
+ return convertNumericTypeImpl<UInt128, To>(from.safeGet<UInt128>());
if (from.getType() == Field::Types::Int128)
- return convertNumericTypeImpl<Int128, To>(from.get<Int128>());
+ return convertNumericTypeImpl<Int128, To>(from.safeGet<Int128>());
if (from.getType() == Field::Types::UInt256)
- return convertNumericTypeImpl<UInt256, To>(from.get<UInt256>());
+ return convertNumericTypeImpl<UInt256, To>(from.safeGet<UInt256>());
if (from.getType() == Field::Types::Int256)
- return convertNumericTypeImpl<Int256, To>(from.get<Int256>());
+ return convertNumericTypeImpl<Int256, To>(from.safeGet<Int256>());
throw Exception(ErrorCodes::TYPE_MISMATCH, "Type mismatch. Expected:
Integer. Got: {}", from.getType());
}
@@ -81,7 +81,7 @@ inline UInt32 extractArgument(const ColumnWithTypeAndName &
named_column)
throw Exception(
ErrorCodes::DECIMAL_OVERFLOW, "{} convert overflow,
precision/scale value must in UInt32", named_column.type->getName());
}
- return static_cast<UInt32>(to.get<UInt32>());
+ return static_cast<UInt32>(to.safeGet<UInt32>());
}
}
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionFloor.h
b/cpp-ch/local-engine/Functions/SparkFunctionFloor.h
index ce33d11db..4a3f99a9a 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionFloor.h
+++ b/cpp-ch/local-engine/Functions/SparkFunctionFloor.h
@@ -197,7 +197,7 @@ class SparkFunctionFloor : public DB::FunctionFloor
if (scale_field.getType() != Field::Types::UInt64 &&
scale_field.getType() != Field::Types::Int64)
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Scale argument
for rounding functions must have integer type");
- Int64 scale64 = scale_field.get<Int64>();
+ Int64 scale64 = scale_field.safeGet<Int64>();
if (scale64 > std::numeric_limits<Scale>::max() || scale64 <
std::numeric_limits<Scale>::min())
throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "Scale
argument for rounding function is too large");
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionHashingExtended.h
b/cpp-ch/local-engine/Functions/SparkFunctionHashingExtended.h
index 57bf00ba9..c64990314 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionHashingExtended.h
+++ b/cpp-ch/local-engine/Functions/SparkFunctionHashingExtended.h
@@ -101,42 +101,42 @@ private:
if (which.isNothing())
return seed;
else if (which.isUInt8())
- return applyNumber<UInt8>(field.get<UInt8>(), seed);
+ return applyNumber<UInt8>(field.safeGet<UInt8>(), seed);
else if (which.isUInt16())
- return applyNumber<UInt16>(field.get<UInt16>(), seed);
+ return applyNumber<UInt16>(field.safeGet<UInt16>(), seed);
else if (which.isUInt32())
- return applyNumber<UInt32>(field.get<UInt32>(), seed);
+ return applyNumber<UInt32>(field.safeGet<UInt32>(), seed);
else if (which.isUInt64())
- return applyNumber<UInt64>(field.get<UInt64>(), seed);
+ return applyNumber<UInt64>(field.safeGet<UInt64>(), seed);
else if (which.isInt8())
- return applyNumber<Int8>(field.get<Int8>(), seed);
+ return applyNumber<Int8>(field.safeGet<Int8>(), seed);
else if (which.isInt16())
- return applyNumber<Int16>(field.get<Int16>(), seed);
+ return applyNumber<Int16>(field.safeGet<Int16>(), seed);
else if (which.isInt32())
- return applyNumber<Int32>(field.get<Int32>(), seed);
+ return applyNumber<Int32>(field.safeGet<Int32>(), seed);
else if (which.isInt64())
- return applyNumber<Int64>(field.get<Int64>(), seed);
+ return applyNumber<Int64>(field.safeGet<Int64>(), seed);
else if (which.isFloat32())
- return applyNumber<Float32>(field.get<Float32>(), seed);
+ return applyNumber<Float32>(field.safeGet<Float32>(), seed);
else if (which.isFloat64())
- return applyNumber<Float64>(field.get<Float64>(), seed);
+ return applyNumber<Float64>(field.safeGet<Float64>(), seed);
else if (which.isDate())
- return applyNumber<UInt16>(field.get<UInt16>(), seed);
+ return applyNumber<UInt16>(field.safeGet<UInt16>(), seed);
else if (which.isDate32())
- return applyNumber<Int32>(field.get<Int32>(), seed);
+ return applyNumber<Int32>(field.safeGet<Int32>(), seed);
else if (which.isDateTime())
- return applyNumber<UInt32>(field.get<UInt32>(), seed);
+ return applyNumber<UInt32>(field.safeGet<UInt32>(), seed);
else if (which.isDateTime64())
- return applyDecimal<DateTime64>(field.get<DateTime64>(), seed);
+ return applyDecimal<DateTime64>(field.safeGet<DateTime64>(), seed);
else if (which.isDecimal32())
- return applyDecimal<Decimal32>(field.get<Decimal32>(), seed);
+ return applyDecimal<Decimal32>(field.safeGet<Decimal32>(), seed);
else if (which.isDecimal64())
- return applyDecimal<Decimal64>(field.get<Decimal64>(), seed);
+ return applyDecimal<Decimal64>(field.safeGet<Decimal64>(), seed);
else if (which.isDecimal128())
- return applyDecimal<Decimal128>(field.get<Decimal128>(), seed);
+ return applyDecimal<Decimal128>(field.safeGet<Decimal128>(), seed);
else if (which.isStringOrFixedString())
{
- const String & str = field.get<String>();
+ const String & str = field.safeGet<String>();
return applyUnsafeBytes(str.data(), str.size(), seed);
}
else if (which.isTuple())
@@ -145,7 +145,7 @@ private:
assert(tuple_type);
const auto & elements = tuple_type->getElements();
- const Tuple & tuple = field.get<Tuple>();
+ const Tuple & tuple = field.safeGet<Tuple>();
assert(tuple.size() == elements.size());
for (size_t i = 0; i < elements.size(); ++i)
@@ -160,7 +160,7 @@ private:
assert(array_type);
const auto & nested_type = array_type->getNestedType();
- const Array & array = field.get<Array>();
+ const Array & array = field.safeGet<Array>();
for (size_t i=0; i < array.size(); ++i)
{
seed = applyGeneric(array[i], seed, nested_type);
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionMakeDecimal.cpp
b/cpp-ch/local-engine/Functions/SparkFunctionMakeDecimal.cpp
index 231856b02..795e2b0be 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionMakeDecimal.cpp
+++ b/cpp-ch/local-engine/Functions/SparkFunctionMakeDecimal.cpp
@@ -205,7 +205,7 @@ namespace
else
return false;
}
- result = static_cast<ToNativeType>(convert_to.get<ToNativeType>());
+ result =
static_cast<ToNativeType>(convert_to.safeGet<ToNativeType>());
ToNativeType pow10 = intExp10OfSize<ToNativeType>(precision_value);
if ((result < 0 && result <= -pow10) || result >= pow10)
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionRoundHalfUp.h
b/cpp-ch/local-engine/Functions/SparkFunctionRoundHalfUp.h
index 441842d4e..0bd28b116 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionRoundHalfUp.h
+++ b/cpp-ch/local-engine/Functions/SparkFunctionRoundHalfUp.h
@@ -271,7 +271,7 @@ public:
if (scale_field.getType() != Field::Types::UInt64 &&
scale_field.getType() != Field::Types::Int64)
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Scale argument
for rounding functions must have integer type");
- Int64 scale64 = scale_field.get<Int64>();
+ Int64 scale64 = scale_field.safeGet<Int64>();
if (scale64 > std::numeric_limits<Scale>::max() || scale64 <
std::numeric_limits<Scale>::min())
throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "Scale
argument for rounding function is too large");
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionToDateTime.h
b/cpp-ch/local-engine/Functions/SparkFunctionToDateTime.h
index 980af85bd..aab8aabc3 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionToDateTime.h
+++ b/cpp-ch/local-engine/Functions/SparkFunctionToDateTime.h
@@ -128,7 +128,7 @@ public:
Field field;
named_column.column->get(0, field);
- return static_cast<UInt32>(field.get<UInt32>());
+ return static_cast<UInt32>(field.safeGet<UInt32>());
}
DB::DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName &
arguments) const override
diff --git a/cpp-ch/local-engine/Operator/ExpandTransform.cpp
b/cpp-ch/local-engine/Operator/ExpandTransform.cpp
index 5100ad070..29e254bc0 100644
--- a/cpp-ch/local-engine/Operator/ExpandTransform.cpp
+++ b/cpp-ch/local-engine/Operator/ExpandTransform.cpp
@@ -110,7 +110,7 @@ void ExpandTransform::work()
if (kind == EXPAND_FIELD_KIND_SELECTION)
{
- auto index = field.get<Int32>();
+ auto index = field.safeGet<Int32>();
const auto & input_column = input_columns[index];
DB::ColumnWithTypeAndName input_arg;
diff --git a/cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp
b/cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp
index f976d50ad..b843d1565 100644
--- a/cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp
+++ b/cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp
@@ -155,7 +155,7 @@ const DB::ActionsDAG::Node *
AggregateFunctionParser::convertNodeTypeIfNeeded(
if (need_convert_type)
{
func_node = ActionsDAGUtil::convertNodeType(
- actions_dag, func_node,
TypeParser::parseType(output_type)->getName(), func_node->result_name);
+ actions_dag, func_node, TypeParser::parseType(output_type),
func_node->result_name);
actions_dag.addOrReplaceInOutputs(*func_node);
}
diff --git a/cpp-ch/local-engine/Parser/CHColumnToSparkRow.cpp
b/cpp-ch/local-engine/Parser/CHColumnToSparkRow.cpp
index 3d5a7731b..602cd3d68 100644
--- a/cpp-ch/local-engine/Parser/CHColumnToSparkRow.cpp
+++ b/cpp-ch/local-engine/Parser/CHColumnToSparkRow.cpp
@@ -501,7 +501,7 @@ int64_t BackingDataLengthCalculator::calculate(const Field
& field) const
if (which.isStringOrFixedString())
{
- const auto & str = field.get<String>();
+ const auto & str = field.safeGet<String>();
return roundNumberOfBytesToNearestWord(str.size());
}
@@ -511,7 +511,7 @@ int64_t BackingDataLengthCalculator::calculate(const Field
& field) const
if (which.isArray())
{
/// 内存布局:numElements(8B) | null_bitmap(与numElements成正比) |
values(每个值长度与类型有关) | backing buffer
- const auto & array = field.get<Array>(); /// Array can not be wrapped
with Nullable
+ const auto & array = field.safeGet<Array>(); /// Array can not be
wrapped with Nullable
const auto num_elems = array.size();
int64_t res = 8 + calculateBitSetWidthInBytes(num_elems);
@@ -531,7 +531,7 @@ int64_t BackingDataLengthCalculator::calculate(const Field
& field) const
int64_t res = 8;
/// Construct Array of keys and values from Map
- const auto & map = field.get<Map>(); /// Map can not be wrapped with
Nullable
+ const auto & map = field.safeGet<Map>(); /// Map can not be wrapped
with Nullable
const auto num_keys = map.size();
auto array_key = Array();
auto array_val = Array();
@@ -539,7 +539,7 @@ int64_t BackingDataLengthCalculator::calculate(const Field
& field) const
array_val.reserve(num_keys);
for (size_t i = 0; i < num_keys; ++i)
{
- const auto & pair = map[i].get<DB::Tuple>();
+ const auto & pair = map[i].safeGet<DB::Tuple>();
array_key.push_back(pair[0]);
array_val.push_back(pair[1]);
}
@@ -561,7 +561,7 @@ int64_t BackingDataLengthCalculator::calculate(const Field
& field) const
if (which.isTuple())
{
/// 内存布局:null_bitmap(字节数与字段数成正比) | field1 value(8B) | field2 value(8B)
| ... | fieldn value(8B) | backing buffer
- const auto & tuple = field.get<Tuple>(); /// Tuple can not be wrapped
with Nullable
+ const auto & tuple = field.safeGet<Tuple>(); /// Tuple can not be
wrapped with Nullable
const auto * type_tuple = typeid_cast<const DataTypeTuple
*>(type_without_nullable.get());
const auto & type_fields = type_tuple->getElements();
const auto num_fields = type_fields.size();
@@ -687,30 +687,7 @@ int64_t VariableLengthDataWriter::writeArray(size_t
row_idx, const DB::Array & a
bitSet(buffer_address + offset + start + 8, i);
else
{
- if (writer.getWhichDataType().isFloat32())
- {
- // We can not use get<char>() directly here to process
Float32 field,
- // because it will get 8 byte data, but Float32 is 4 byte,
which will cause error conversion.
- auto v = static_cast<Float32>(elem.get<Float32>());
- writer.unsafeWrite(
- reinterpret_cast<const char *>(&v), buffer_address +
offset + start + 8 + len_null_bitmap + i * elem_size);
- }
- else if (writer.getWhichDataType().isFloat64())
- {
- // Fix 'Invalid Field get from type Float64 to type Int64'
in debug build.
- auto v = elem.get<Float64>();
- writer.unsafeWrite(reinterpret_cast<const char *>(&v),
buffer_address + offset + start + 8 + len_null_bitmap + i * elem_size);
- }
- else if (writer.getWhichDataType().isDecimal32())
- {
- // We can not use get<char>() directly here to process
Decimal32 field,
- // because it will get 4 byte data, but Decimal32 is 8 byte
in Spark, which will cause error conversion.
- writer.write(elem, buffer_address + offset + start + 8 +
len_null_bitmap + i * elem_size);
- }
- else
- writer.unsafeWrite(
- reinterpret_cast<const char *>(&elem.get<char>()),
- buffer_address + offset + start + 8 + len_null_bitmap
+ i * elem_size);
+ writer.write(elem, buffer_address + offset + start + 8 +
len_null_bitmap + i * elem_size);
}
}
}
@@ -754,7 +731,7 @@ int64_t VariableLengthDataWriter::writeMap(size_t row_idx,
const DB::Map & map,
val_array.reserve(num_pairs);
for (size_t i = 0; i < num_pairs; ++i)
{
- const auto & pair = map[i].get<DB::Tuple>();
+ const auto & pair = map[i].safeGet<DB::Tuple>();
key_array.push_back(pair[0]);
val_array.push_back(pair[1]);
}
@@ -812,27 +789,7 @@ int64_t VariableLengthDataWriter::writeStruct(size_t
row_idx, const DB::Tuple &
if
(BackingDataLengthCalculator::isFixedLengthDataType(removeNullable(field_type)))
{
FixedLengthDataWriter writer(field_type);
- if (writer.getWhichDataType().isFloat32())
- {
- // We can not use get<char>() directly here to process Float32
field,
- // because it will get 8 byte data, but Float32 is 4 byte,
which will cause error conversion.
- auto v = static_cast<Float32>(field_value.get<Float32>());
- writer.unsafeWrite(reinterpret_cast<const char *>(&v),
buffer_address + offset + start + len_null_bitmap + i * 8);
- }
- else if (writer.getWhichDataType().isFloat64())
- {
- // Fix 'Invalid Field get from type Float64 to type Int64' in
debug build.
- auto v = field_value.get<Float64>();
- writer.unsafeWrite(reinterpret_cast<const char *>(&v),
buffer_address + offset + start + len_null_bitmap + i * 8);
- }
- else if (writer.getWhichDataType().isDecimal64() ||
writer.getWhichDataType().isDateTime64())
- {
- auto v = field_value.get<Decimal64>();
- writer.unsafeWrite(reinterpret_cast<const char *>(&v),
buffer_address + offset + start + len_null_bitmap + i * 8);
- }
- else
- writer.unsafeWrite(
- reinterpret_cast<const char *>(&field_value.get<char>()),
buffer_address + offset + start + len_null_bitmap + i * 8);
+ writer.write(field_value, buffer_address + offset + start +
len_null_bitmap + i * 8);
}
else
{
@@ -853,7 +810,7 @@ int64_t VariableLengthDataWriter::write(size_t row_idx,
const DB::Field & field,
if (which.isStringOrFixedString())
{
- const auto & str = field.get<String>();
+ const auto & str = field.safeGet<String>();
return writeUnalignedBytes(row_idx, str.data(), str.size(),
parent_offset);
}
@@ -868,19 +825,19 @@ int64_t VariableLengthDataWriter::write(size_t row_idx,
const DB::Field & field,
if (which.isArray())
{
- const auto & array = field.get<Array>();
+ const auto & array = field.safeGet<Array>();
return writeArray(row_idx, array, parent_offset);
}
if (which.isMap())
{
- const auto & map = field.get<Map>();
+ const auto & map = field.safeGet<Map>();
return writeMap(row_idx, map, parent_offset);
}
if (which.isTuple())
{
- const auto & tuple = field.get<Tuple>();
+ const auto & tuple = field.safeGet<Tuple>();
return writeStruct(row_idx, tuple, parent_offset);
}
@@ -926,64 +883,64 @@ void FixedLengthDataWriter::write(const DB::Field &
field, char * buffer)
if (which.isUInt8())
{
- const auto value = UInt8(field.get<UInt8>());
+ const auto value = static_cast<UInt8>(field.safeGet<UInt8>());
memcpy(buffer, &value, 1);
}
else if (which.isUInt16() || which.isDate())
{
- const auto value = UInt16(field.get<UInt16>());
+ const auto value = static_cast<UInt16>(field.safeGet<UInt16>());
memcpy(buffer, &value, 2);
}
else if (which.isUInt32() || which.isDate32())
{
- const auto value = UInt32(field.get<UInt32>());
+ const auto value = static_cast<UInt32>(field.safeGet<UInt32>());
memcpy(buffer, &value, 4);
}
else if (which.isUInt64())
{
- const auto & value = field.get<UInt64>();
+ const auto & value = field.safeGet<UInt64>();
memcpy(buffer, &value, 8);
}
else if (which.isInt8())
{
- const auto value = Int8(field.get<Int8>());
+ const auto value = static_cast<Int8>(field.safeGet<Int8>());
memcpy(buffer, &value, 1);
}
else if (which.isInt16())
{
- const auto value = Int16(field.get<Int16>());
+ const auto value = static_cast<Int16>(field.safeGet<Int16>());
memcpy(buffer, &value, 2);
}
else if (which.isInt32())
{
- const auto value = Int32(field.get<Int32>());
+ const auto value = static_cast<Int32>(field.safeGet<Int32>());
memcpy(buffer, &value, 4);
}
else if (which.isInt64())
{
- const auto & value = field.get<Int64>();
+ const auto & value = field.safeGet<Int64>();
memcpy(buffer, &value, 8);
}
else if (which.isFloat32())
{
- const auto value = Float32(field.get<Float32>());
+ const auto value = static_cast<Float32>(field.safeGet<Float32>());
memcpy(buffer, &value, 4);
}
else if (which.isFloat64())
{
- const auto & value = field.get<Float64>();
+ const auto & value = field.safeGet<Float64>();
memcpy(buffer, &value, 8);
}
else if (which.isDecimal32())
{
- const auto & value = field.get<Decimal32>();
+ const auto & value = field.safeGet<Decimal32>();
const Int64 decimal = static_cast<Int64>(value.getValue());
memcpy(buffer, &decimal, 8);
}
else if (which.isDecimal64() || which.isDateTime64())
{
- const auto & value = field.get<Decimal64>();
- auto decimal = value.getValue();
+ const auto & value = field.safeGet<Decimal64>();
+ const auto decimal = value.getValue();
memcpy(buffer, &decimal, 8);
}
else
diff --git a/cpp-ch/local-engine/Parser/FunctionParser.cpp
b/cpp-ch/local-engine/Parser/FunctionParser.cpp
index d46110431..a875da275 100644
--- a/cpp-ch/local-engine/Parser/FunctionParser.cpp
+++ b/cpp-ch/local-engine/Parser/FunctionParser.cpp
@@ -80,8 +80,8 @@ const ActionsDAG::Node *
FunctionParser::convertNodeTypeIfNeeded(
actions_dag,
func_node,
// as stated in isTypeMatched, currently we don't change
nullability of the result type
- func_node->result_type->isNullable() ?
local_engine::wrapNullableType(true, result_type)->getName()
- :
local_engine::removeNullable(result_type)->getName(),
+ func_node->result_type->isNullable() ?
local_engine::wrapNullableType(true, result_type)
+ :
local_engine::removeNullable(result_type),
func_node->result_name,
CastType::accurateOrNull);
}
@@ -91,8 +91,8 @@ const ActionsDAG::Node *
FunctionParser::convertNodeTypeIfNeeded(
actions_dag,
func_node,
// as stated in isTypeMatched, currently we don't change
nullability of the result type
- func_node->result_type->isNullable() ?
local_engine::wrapNullableType(true,
TypeParser::parseType(output_type))->getName()
- :
DB::removeNullable(TypeParser::parseType(output_type))->getName(),
+ func_node->result_type->isNullable() ?
local_engine::wrapNullableType(true, TypeParser::parseType(output_type))
+ :
DB::removeNullable(TypeParser::parseType(output_type)),
func_node->result_name);
}
}
diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
index d0924a745..297551bcc 100644
--- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
+++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
@@ -285,7 +285,10 @@ QueryPlanStepPtr
SerializedPlanParser::parseReadRealWithLocalFile(const substrai
if (rel.has_local_files())
local_files = rel.local_files();
else
+ {
local_files =
BinaryToMessage<substrait::ReadRel::LocalFiles>(split_infos.at(nextSplitInfoIndex()));
+ logDebugMessage(local_files, "local_files");
+ }
auto source = std::make_shared<SubstraitFileSource>(context, header,
local_files);
auto source_pipe = Pipe(source);
auto source_step = std::make_unique<SubstraitFileSourceStep>(context,
std::move(source_pipe), "substrait local files");
@@ -496,7 +499,10 @@ QueryPlanPtr SerializedPlanParser::parseOp(const
substrait::Rel & rel, std::list
if (read.has_extension_table())
extension_table = read.extension_table();
else
+ {
extension_table =
BinaryToMessage<substrait::ReadRel::ExtensionTable>(split_infos.at(nextSplitInfoIndex()));
+ logDebugMessage(extension_table, "extension_table");
+ }
MergeTreeRelParser mergeTreeParser(this, context);
query_plan =
mergeTreeParser.parseReadRel(std::make_unique<QueryPlan>(), read,
extension_table);
@@ -689,7 +695,7 @@ ActionsDAG::NodeRawConstPtrs
SerializedPlanParser::parseArrayJoinWithDAG(
/// pos = cast(arrayJoin(arg_not_null).1, "Int32")
const auto * pos_node = add_tuple_element(array_join_node, 1);
- pos_node = ActionsDAGUtil::convertNodeType(actions_dag, pos_node,
"Int32");
+ pos_node = ActionsDAGUtil::convertNodeType(actions_dag, pos_node,
INT());
/// if is_map is false, output col = arrayJoin(arg_not_null).2
/// if is_map is true, output (key, value) = arrayJoin(arg_not_null).2
@@ -772,7 +778,7 @@ std::pair<DataTypePtr, Field>
SerializedPlanParser::convertStructFieldType(const
#define UINT_CONVERT(type_ptr, field, type_name) \
if ((type_ptr)->getTypeId() == TypeIndex::type_name) \
{ \
- return {std::make_shared<DataTypeU##type_name>(),
static_cast<U##type_name>((field).get<type_name>()) + 1}; \
+ return {std::make_shared<DataTypeU##type_name>(),
static_cast<U##type_name>((field).safeGet<type_name>()) + 1}; \
}
auto type_id = type->getTypeId();
diff --git a/cpp-ch/local-engine/Parser/WriteRelParser.cpp
b/cpp-ch/local-engine/Parser/WriteRelParser.cpp
index 9b6226adb..1a468a41e 100644
--- a/cpp-ch/local-engine/Parser/WriteRelParser.cpp
+++ b/cpp-ch/local-engine/Parser/WriteRelParser.cpp
@@ -137,12 +137,12 @@ void addSinkTransfrom(const DB::ContextPtr & context,
const substrait::WriteRel
DB::Field field_tmp_dir;
if (!settings.tryGet(SPARK_TASK_WRITE_TMEP_DIR, field_tmp_dir))
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Write Pipeline
need inject temp directory.");
- const auto & tmp_dir = field_tmp_dir.get<std::string>();
+ const auto & tmp_dir = field_tmp_dir.safeGet<std::string>();
DB::Field field_filename;
if (!settings.tryGet(SPARK_TASK_WRITE_FILENAME, field_filename))
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Write Pipeline
need inject file name.");
- const auto & filename = field_filename.get<std::string>();
+ const auto & filename = field_filename.safeGet<std::string>();
assert(write_rel.has_named_table());
const substrait::NamedObjectWrite & named_table = write_rel.named_table();
diff --git
a/cpp-ch/local-engine/Parser/aggregate_function_parser/ApproxPercentileParser.cpp
b/cpp-ch/local-engine/Parser/aggregate_function_parser/ApproxPercentileParser.cpp
index 237da650c..ceddbd2ae 100644
---
a/cpp-ch/local-engine/Parser/aggregate_function_parser/ApproxPercentileParser.cpp
+++
b/cpp-ch/local-engine/Parser/aggregate_function_parser/ApproxPercentileParser.cpp
@@ -98,7 +98,7 @@ DB::Array ApproxPercentileParser::parseFunctionParameters(
if (isArray(type2))
{
/// Multiple percentages for quantilesGK
- const Array & percentags = field2.get<Array>();
+ const Array & percentags = field2.safeGet<Array>();
for (const auto & percentage : percentags)
params.emplace_back(percentage);
}
diff --git
a/cpp-ch/local-engine/Parser/aggregate_function_parser/BloomFilterAggParser.cpp
b/cpp-ch/local-engine/Parser/aggregate_function_parser/BloomFilterAggParser.cpp
index 8788abb6d..10bf0b094 100644
---
a/cpp-ch/local-engine/Parser/aggregate_function_parser/BloomFilterAggParser.cpp
+++
b/cpp-ch/local-engine/Parser/aggregate_function_parser/BloomFilterAggParser.cpp
@@ -63,8 +63,8 @@ DB::Array
AggregateFunctionParserBloomFilterAgg::parseFunctionParameters(
node->column->get(0, ret);
return ret;
};
- Int64 insert_num = get_parameter_field(arg_nodes[1], 1).get<Int64>();
- Int64 bits_num = get_parameter_field(arg_nodes[2], 2).get<Int64>();
+ Int64 insert_num = get_parameter_field(arg_nodes[1],
1).safeGet<Int64>();
+ Int64 bits_num = get_parameter_field(arg_nodes[2], 2).safeGet<Int64>();
// Delete all args except the first arg.
arg_nodes.resize(1);
diff --git
a/cpp-ch/local-engine/Parser/aggregate_function_parser/LeadLagParser.cpp
b/cpp-ch/local-engine/Parser/aggregate_function_parser/LeadLagParser.cpp
index 6d0075705..536aec1b6 100644
--- a/cpp-ch/local-engine/Parser/aggregate_function_parser/LeadLagParser.cpp
+++ b/cpp-ch/local-engine/Parser/aggregate_function_parser/LeadLagParser.cpp
@@ -19,6 +19,7 @@
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypesNumber.h>
#include <Interpreters/ActionsDAG.h>
+#include <Common/BlockTypeUtils.h>
#include <Common/CHUtil.h>
namespace local_engine
@@ -41,7 +42,7 @@ LeadParser::parseFunctionArguments(const CommonFunctionInfo &
func_info, DB::Act
node = ActionsDAGUtil::convertNodeType(
actions_dag,
&actions_dag.findInOutputs(arg0_col_name),
- DB::makeNullable(arg0_col_type)->getName(),
+ DB::makeNullable(arg0_col_type),
arg0_col_name);
actions_dag.addOrReplaceInOutputs(*node);
args.push_back(node);
@@ -52,7 +53,7 @@ LeadParser::parseFunctionArguments(const CommonFunctionInfo &
func_info, DB::Act
}
node = parseExpression(actions_dag, arg1);
- node = ActionsDAGUtil::convertNodeType(actions_dag, node,
DB::DataTypeInt64().getName());
+ node = ActionsDAGUtil::convertNodeType(actions_dag, node, BIGINT());
actions_dag.addOrReplaceInOutputs(*node);
args.push_back(node);
@@ -84,7 +85,7 @@ LagParser::parseFunctionArguments(const CommonFunctionInfo &
func_info, DB::Acti
node = ActionsDAGUtil::convertNodeType(
actions_dag,
&actions_dag.findInOutputs(arg0_col_name),
- DB::makeNullable(arg0_col_type)->getName(),
+ makeNullable(arg0_col_type),
arg0_col_name);
actions_dag.addOrReplaceInOutputs(*node);
args.push_back(node);
@@ -100,7 +101,7 @@ LagParser::parseFunctionArguments(const CommonFunctionInfo
& func_info, DB::Acti
auto real_field = 0 - literal_result.second.safeGet<Int32>();
node = &actions_dag.addColumn(ColumnWithTypeAndName(
literal_result.first->createColumnConst(1, real_field),
literal_result.first, getUniqueName(toString(real_field))));
- node = ActionsDAGUtil::convertNodeType(actions_dag, node,
DB::DataTypeInt64().getName());
+ node = ActionsDAGUtil::convertNodeType(actions_dag, node, BIGINT());
actions_dag.addOrReplaceInOutputs(*node);
args.push_back(node);
diff --git
a/cpp-ch/local-engine/Parser/aggregate_function_parser/NtileParser.cpp
b/cpp-ch/local-engine/Parser/aggregate_function_parser/NtileParser.cpp
index 62f83223c..1a24e3206 100644
--- a/cpp-ch/local-engine/Parser/aggregate_function_parser/NtileParser.cpp
+++ b/cpp-ch/local-engine/Parser/aggregate_function_parser/NtileParser.cpp
@@ -32,7 +32,7 @@ NtileParser::parseFunctionArguments(const CommonFunctionInfo
& func_info, DB::Ac
auto [data_type, field] = parseLiteral(arg0.literal());
if (!(DB::WhichDataType(data_type).isInt32()))
throw Exception(ErrorCodes::BAD_ARGUMENTS, "ntile's argument must be
i32");
- Int32 field_index = static_cast<Int32>(field.get<Int32>());
+ Int32 field_index = static_cast<Int32>(field.safeGet<Int32>());
// For CH, the data type of the args[0] must be the UInt32
const auto * index_node = addColumnToActionsDAG(actions_dag,
std::make_shared<DataTypeUInt32>(), field_index);
args.emplace_back(index_node);
diff --git
a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp
b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp
index a475a1efb..aa82b33a7 100644
---
a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp
+++
b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp
@@ -23,6 +23,7 @@
#include <Parser/TypeParser.h>
#include <Parser/scalar_function_parser/lambdaFunction.h>
#include <Poco/Logger.h>
+#include <Common/BlockTypeUtils.h>
#include <Common/CHUtil.h>
#include <Common/Exception.h>
#include <Common/logger_useful.h>
@@ -60,7 +61,7 @@ public:
/// filter with index argument.
const auto * range_end_node = toFunctionNode(actions_dag, "length",
{toFunctionNode(actions_dag, "assumeNotNull", {parsed_args[0]})});
range_end_node = ActionsDAGUtil::convertNodeType(
- actions_dag, range_end_node, "Nullable(Int32)",
range_end_node->result_name);
+ actions_dag, range_end_node, makeNullable(INT()),
range_end_node->result_name);
const auto * index_array_node = toFunctionNode(
actions_dag,
"range",
@@ -106,7 +107,7 @@ public:
/// transform with index argument.
const auto * range_end_node = toFunctionNode(actions_dag, "length",
{toFunctionNode(actions_dag, "assumeNotNull", {parsed_args[0]})});
range_end_node = ActionsDAGUtil::convertNodeType(
- actions_dag, range_end_node, "Nullable(Int32)",
range_end_node->result_name);
+ actions_dag, range_end_node, makeNullable(INT()),
range_end_node->result_name);
const auto * index_array_node = toFunctionNode(
actions_dag,
"range",
@@ -141,7 +142,7 @@ public:
parsed_args[1] = ActionsDAGUtil::convertNodeType(
actions_dag,
parsed_args[1],
- function_type->getReturnType()->getName(),
+ function_type->getReturnType(),
parsed_args[1]->result_name);
}
@@ -215,14 +216,14 @@ private:
if (!var_expr.has_literal())
return false;
auto [_, name] = plan_parser->parseLiteral(var_expr.literal());
- return var == name.get<String>();
+ return var == name.safeGet<String>();
};
auto is_int_value = [&](const substrait::Expression & expr, Int32 val)
{
if (!expr.has_literal())
return false;
auto [_, x] = plan_parser->parseLiteral(expr.literal());
- return val == x.get<Int32>();
+ return val == x.safeGet<Int32>();
};
auto is_variable_null = [&](const substrait::Expression & expr, const
String & var) {
diff --git
a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayPosition.cpp
b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayPosition.cpp
index 1fda3d8fa..b0ade35a3 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayPosition.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayPosition.cpp
@@ -86,7 +86,7 @@ public:
DataTypePtr wrap_arr_nullable_type = wrapNullableType(true,
ch_function_node->result_type);
const auto * wrap_index_of_node = ActionsDAGUtil::convertNodeType(
- actions_dag, ch_function_node, wrap_arr_nullable_type->getName(),
ch_function_node->result_name);
+ actions_dag, ch_function_node, wrap_arr_nullable_type,
ch_function_node->result_name);
const auto * null_const_node = addColumnToActionsDAG(actions_dag,
wrap_arr_nullable_type, Field{});
const auto * or_condition_node = toFunctionNode(actions_dag, "or",
{arr_is_null_node, val_is_null_node});
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/elt.cpp
b/cpp-ch/local-engine/Parser/scalar_function_parser/elt.cpp
index 992235cd9..accc6d418 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/elt.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/elt.cpp
@@ -74,7 +74,7 @@ public:
auto nullable_result_type = makeNullable(result_type);
const auto * nullable_array_element_node =
ActionsDAGUtil::convertNodeType(
- actions_dag, array_element_node, nullable_result_type->getName(),
array_element_node->result_name);
+ actions_dag, array_element_node, nullable_result_type,
array_element_node->result_name);
const auto * null_const_node = addColumnToActionsDAG(actions_dag,
nullable_result_type, Field());
const auto * is_null_node = toFunctionNode(actions_dag, "isNull",
{index_arg});
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/findInset.cpp
b/cpp-ch/local-engine/Parser/scalar_function_parser/findInset.cpp
index ca9fb372c..96fedc6fe 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/findInset.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/findInset.cpp
@@ -19,6 +19,7 @@
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/IDataType.h>
#include <Parser/FunctionParser.h>
+#include <Common/BlockTypeUtils.h>
#include <Common/CHUtil.h>
namespace DB
@@ -73,9 +74,9 @@ public:
if (!str_is_nullable && !str_array_is_nullable)
return convertNodeTypeIfNeeded(substrait_func, index_of_node,
actions_dag);
- auto nullable_result_type =
makeNullable(std::make_shared<DataTypeInt32>());
+ auto nullable_result_type = makeNullable(INT());
const auto * nullable_index_of_node = ActionsDAGUtil::convertNodeType(
- actions_dag, index_of_node, nullable_result_type->getName(),
index_of_node->result_name);
+ actions_dag, index_of_node, nullable_result_type,
index_of_node->result_name);
const auto * null_const_node = addColumnToActionsDAG(actions_dag,
nullable_result_type, Field());
const auto * str_is_null_node = toFunctionNode(actions_dag, "isNull",
{str_arg});
diff --git
a/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp
b/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp
index 547ffd971..c2841564e 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp
@@ -43,7 +43,7 @@ DB::NamesAndTypesList collectLambdaArguments(const
SerializedPlanParser & plan_p
&&
plan_parser_.getFunctionSignatureName(arg.value().scalar_function().function_reference())
== "namedlambdavariable")
{
auto [_, col_name_field] =
plan_parser_.parseLiteral(arg.value().scalar_function().arguments()[0].value().literal());
- String col_name = col_name_field.get<String>();
+ String col_name = col_name_field.safeGet<String>();
if (collected_names.contains(col_name))
{
continue;
@@ -187,7 +187,7 @@ public:
const DB::ActionsDAG::Node * parse(const
substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAG &
actions_dag) const override
{
auto [_, col_name_field] =
parseLiteral(substrait_func.arguments()[0].value().literal());
- String col_name = col_name_field.get<String>();
+ String col_name = col_name_field.safeGet<String>();
auto type = TypeParser::parseType(substrait_func.output_type());
const auto & inputs = actions_dag.getInputs();
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/locate.cpp
b/cpp-ch/local-engine/Parser/scalar_function_parser/locate.cpp
index b948daeda..17115895e 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/locate.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/locate.cpp
@@ -17,6 +17,7 @@
#include <DataTypes/IDataType.h>
#include <Parser/FunctionParser.h>
+#include <Common/BlockTypeUtils.h>
#include <Common/CHUtil.h>
namespace DB
@@ -50,7 +51,7 @@ public:
const auto * substr_arg = parsed_args[0];
const auto * str_arg = parsed_args[1];
- const auto * start_pos_arg =
ActionsDAGUtil::convertNodeType(actions_dag, parsed_args[2],
"Nullable(UInt32)");
+ const auto * start_pos_arg =
ActionsDAGUtil::convertNodeType(actions_dag, parsed_args[2],
makeNullable(UINT()));
const auto * is_start_pos_null_node = toFunctionNode(actions_dag,
"isNull", {start_pos_arg});
const auto * const_1_node = addColumnToActionsDAG(actions_dag,
std::make_shared<DataTypeUInt64>(), 0);
const auto * position_node = toFunctionNode(actions_dag,
"positionUTF8Spark", {str_arg, substr_arg, start_pos_arg});
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/repeat.cpp
b/cpp-ch/local-engine/Parser/scalar_function_parser/repeat.cpp
index ada91f853..74254911a 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/repeat.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/repeat.cpp
@@ -18,6 +18,7 @@
#include <DataTypes/DataTypeNumberBase.h>
#include <Parser/FunctionParser.h>
#include <Parser/TypeParser.h>
+#include <Common/BlockTypeUtils.h>
#include <Common/CHUtil.h>
#include <Common/Exception.h>
@@ -42,8 +43,7 @@ public:
const auto & args = substrait_func.arguments();
parsed_args.emplace_back(parseExpression(actions_dag,
args[0].value()));
const auto * repeat_times_node = parseExpression(actions_dag,
args[1].value());
- DB::DataTypeNullable
target_type(std::make_shared<DB::DataTypeUInt32>());
- repeat_times_node = ActionsDAGUtil::convertNodeType(actions_dag,
repeat_times_node, target_type.getName());
+ repeat_times_node = ActionsDAGUtil::convertNodeType(actions_dag,
repeat_times_node, makeNullable(UINT()));
parsed_args.emplace_back(repeat_times_node);
const auto * func_node = toFunctionNode(actions_dag, ch_function_name,
parsed_args);
return convertNodeTypeIfNeeded(substrait_func, func_node, actions_dag);
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/slice.cpp
b/cpp-ch/local-engine/Parser/scalar_function_parser/slice.cpp
index 264320735..a96dca8ef 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/slice.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/slice.cpp
@@ -89,7 +89,7 @@ public:
DataTypePtr wrap_arr_nullable_type = wrapNullableType(true,
slice_node->result_type);
const auto * wrap_slice_node = ActionsDAGUtil::convertNodeType(
- actions_dag, slice_node, wrap_arr_nullable_type->getName(),
slice_node->result_name);
+ actions_dag, slice_node, wrap_arr_nullable_type,
slice_node->result_name);
const auto * null_const_node = addColumnToActionsDAG(actions_dag,
wrap_arr_nullable_type, Field{});
const auto * arr_is_null_node = toFunctionNode(actions_dag, "isNull",
{arr_arg});
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/tupleElement.cpp
b/cpp-ch/local-engine/Parser/scalar_function_parser/tupleElement.cpp
index 179aa7860..4809cc887 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/tupleElement.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/tupleElement.cpp
@@ -45,7 +45,7 @@ namespace local_engine
auto [data_type, field] = parseLiteral(args[1].value().literal());
\
if (!DB::WhichDataType(data_type).isInt32()) \
throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "{}'s
second argument must be i32", #substrait_name); \
- Int32 field_index = static_cast<Int32>(field.get<Int32>() + 1); \
+ Int32 field_index = static_cast<Int32>(field.safeGet<Int32>() +
1); \
const auto * index_node = addColumnToActionsDAG(actions_dag,
std::make_shared<DataTypeUInt32>(), field_index); \
parsed_args.emplace_back(index_node); \
const auto * func_node = toFunctionNode(actions_dag,
ch_function_name, parsed_args); \
diff --git a/cpp-ch/local-engine/Storages/Mergetree/SparkMergeTreeWriter.cpp
b/cpp-ch/local-engine/Storages/Mergetree/SparkMergeTreeWriter.cpp
index 6fee65efe..93f4374d4 100644
--- a/cpp-ch/local-engine/Storages/Mergetree/SparkMergeTreeWriter.cpp
+++ b/cpp-ch/local-engine/Storages/Mergetree/SparkMergeTreeWriter.cpp
@@ -71,16 +71,16 @@ SparkMergeTreeWriter::SparkMergeTreeWriter(
, thread_pool(CurrentMetrics::LocalThread,
CurrentMetrics::LocalThreadActive, CurrentMetrics::LocalThreadScheduled, 1, 1,
100000)
{
const DB::Settings & settings = context->getSettingsRef();
- merge_after_insert =
settings.get(MERGETREE_MERGE_AFTER_INSERT).get<bool>();
- insert_without_local_storage =
settings.get(MERGETREE_INSERT_WITHOUT_LOCAL_STORAGE).get<bool>();
+ merge_after_insert =
settings.get(MERGETREE_MERGE_AFTER_INSERT).safeGet<bool>();
+ insert_without_local_storage =
settings.get(MERGETREE_INSERT_WITHOUT_LOCAL_STORAGE).safeGet<bool>();
Field limit_size_field;
if (settings.tryGet("optimize.minFileSize", limit_size_field))
- merge_min_size = limit_size_field.get<Int64>() <= 0 ? merge_min_size :
limit_size_field.get<Int64>();
+ merge_min_size = limit_size_field.safeGet<Int64>() <= 0 ?
merge_min_size : limit_size_field.safeGet<Int64>();
Field limit_cnt_field;
if (settings.tryGet("mergetree.max_num_part_per_merge_task",
limit_cnt_field))
- merge_limit_parts = limit_cnt_field.get<Int64>() <= 0 ?
merge_limit_parts : limit_cnt_field.get<Int64>();
+ merge_limit_parts = limit_cnt_field.safeGet<Int64>() <= 0 ?
merge_limit_parts : limit_cnt_field.safeGet<Int64>();
dest_storage = MergeTreeRelParser::parseStorage(merge_tree_table,
SerializedPlanParser::global_context);
isRemoteStorage =
dest_storage->getStoragePolicy()->getAnyDisk()->isRemote();
diff --git a/cpp-ch/local-engine/Storages/Parquet/ParquetConverter.h
b/cpp-ch/local-engine/Storages/Parquet/ParquetConverter.h
index 89e83e668..312cea7ef 100644
--- a/cpp-ch/local-engine/Storages/Parquet/ParquetConverter.h
+++ b/cpp-ch/local-engine/Storages/Parquet/ParquetConverter.h
@@ -38,9 +38,9 @@ struct ToParquet
T as(const DB::Field & value, const parquet::ColumnDescriptor &)
{
if constexpr (std::is_same_v<PhysicalType, parquet::Int32Type>)
- return static_cast<T>(value.get<Int64>());
+ return static_cast<T>(value.safeGet<Int64>());
// parquet::BooleanType, parquet::Int64Type, parquet::FloatType,
parquet::DoubleType
- return value.get<T>(); // FLOAT, DOUBLE, INT64
+ return value.safeGet<T>(); // FLOAT, DOUBLE, INT64
}
};
@@ -51,7 +51,7 @@ struct ToParquet<parquet::ByteArrayType>
T as(const DB::Field & value, const parquet::ColumnDescriptor &)
{
assert(value.getType() == DB::Field::Types::String);
- const std::string & s = value.get<std::string>();
+ const std::string & s = value.safeGet<std::string>();
const auto * const ptr = reinterpret_cast<const uint8_t *>(s.data());
return parquet::ByteArray(static_cast<uint32_t>(s.size()), ptr);
}
@@ -74,7 +74,7 @@ struct ToParquet<parquet::FLBAType>
"descriptor.type_length() = {} , which is > {}, e.g.
sizeof(Int128)",
descriptor.type_length(),
sizeof(Int128));
- Int128 val = value.get<DB::DecimalField<DB::Decimal128>>().getValue();
+ Int128 val =
value.safeGet<DB::DecimalField<DB::Decimal128>>().getValue();
std::reverse(reinterpret_cast<char *>(&val), reinterpret_cast<char
*>(&val) + sizeof(val));
const int offset = sizeof(Int128) - descriptor.type_length();
memcpy(buf, reinterpret_cast<char *>(&val) + offset,
descriptor.type_length());
diff --git a/cpp-ch/local-engine/tests/data/68135.snappy.parquet
b/cpp-ch/local-engine/tests/data/68135.snappy.parquet
new file mode 100644
index 000000000..ddd627790
Binary files /dev/null and
b/cpp-ch/local-engine/tests/data/68135.snappy.parquet differ
diff --git a/cpp-ch/local-engine/tests/gtest_clickhouse_pr_verify.cpp
b/cpp-ch/local-engine/tests/gtest_clickhouse_pr_verify.cpp
index 6352a8199..9e4165d90 100644
--- a/cpp-ch/local-engine/tests/gtest_clickhouse_pr_verify.cpp
+++ b/cpp-ch/local-engine/tests/gtest_clickhouse_pr_verify.cpp
@@ -63,7 +63,7 @@ TEST(Clickhouse, PR54881)
Field field;
const auto & col_1 = *(block.getColumns()[1]);
col_1.get(0, field);
- const Tuple & row_0 = field.get<DB::Tuple>();
+ const Tuple & row_0 = field.safeGet<DB::Tuple>();
EXPECT_EQ(2, row_0.size());
Int64 actual{-1};
@@ -74,7 +74,7 @@ TEST(Clickhouse, PR54881)
EXPECT_EQ(10, actual);
col_1.get(1, field);
- const Tuple & row_1 = field.get<DB::Tuple>();
+ const Tuple & row_1 = field.safeGet<DB::Tuple>();
EXPECT_EQ(2, row_1.size());
EXPECT_TRUE(row_1[0].tryGet<Int64>(actual));
EXPECT_EQ(10, actual);
@@ -96,4 +96,24 @@ TEST(Clickhouse, PR65234)
const auto plan = local_engine::JsonStringToMessage<substrait::Plan>(
{reinterpret_cast<const char *>(gresource_embedded_pr_65234_jsonData),
gresource_embedded_pr_65234_jsonSize});
auto query_plan = parser.parse(plan);
+}
+
+INCBIN(resource_embedded_pr_68135_json, SOURCE_DIR
"/utils/extern-local-engine/tests/json/clickhouse_pr_68135.json");
+TEST(Clickhouse, PR68135)
+{
+ const std::string split_template
+ =
R"({"items":[{"uriFile":"{replace_local_files}","partitionIndex":"0","length":"461","parquet":{},"schema":{},"metadataColumns":[{}]}]})";
+ const std::string split
+ = replaceLocalFilesWildcards(split_template,
GLUTEN_DATA_DIR("/utils/extern-local-engine/tests/data/68135.snappy.parquet"));
+
+ SerializedPlanParser parser(SerializedPlanParser::global_context);
+
parser.addSplitInfo(local_engine::JsonStringToBinary<substrait::ReadRel::LocalFiles>(split));
+
+ const auto plan = local_engine::JsonStringToMessage<substrait::Plan>(
+ {reinterpret_cast<const char *>(gresource_embedded_pr_68135_jsonData),
gresource_embedded_pr_68135_jsonSize});
+
+ auto local_executor = parser.createExecutor(plan);
+ EXPECT_TRUE(local_executor->hasNext());
+ const Block & x = *local_executor->nextColumnar();
+ debug::headBlock(x);
}
\ No newline at end of file
diff --git a/cpp-ch/local-engine/tests/json/clickhouse_pr_68135.json
b/cpp-ch/local-engine/tests/json/clickhouse_pr_68135.json
new file mode 100644
index 000000000..c8b49857c
--- /dev/null
+++ b/cpp-ch/local-engine/tests/json/clickhouse_pr_68135.json
@@ -0,0 +1,160 @@
+{
+ "relations": [
+ {
+ "root": {
+ "input": {
+ "filter": {
+ "common": {
+ "direct": {}
+ },
+ "input": {
+ "read": {
+ "common": {
+ "direct": {}
+ },
+ "baseSchema": {
+ "names": [
+ "a"
+ ],
+ "struct": {
+ "types": [
+ {
+ "decimal": {
+ "scale": 2,
+ "precision": 9,
+ "nullability": "NULLABILITY_NULLABLE"
+ }
+ }
+ ]
+ },
+ "columnTypes": [
+ "NORMAL_COL"
+ ]
+ },
+ "filter": {
+ "singularOrList": {
+ "value": {
+ "selection": {
+ "directReference": {
+ "structField": {}
+ }
+ }
+ },
+ "options": [
+ {
+ "literal": {
+ "decimal": {
+ "value": "yAAAAAAAAAAAAAAAAAAAAA==",
+ "precision": 9,
+ "scale": 2
+ }
+ }
+ },
+ {
+ "literal": {
+ "decimal": {
+ "value": "LAEAAAAAAAAAAAAAAAAAAA==",
+ "precision": 9,
+ "scale": 2
+ }
+ }
+ },
+ {
+ "literal": {
+ "decimal": {
+ "value": "kAEAAAAAAAAAAAAAAAAAAA==",
+ "precision": 9,
+ "scale": 2
+ }
+ }
+ },
+ {
+ "literal": {
+ "decimal": {
+ "value": "9AEAAAAAAAAAAAAAAAAAAA==",
+ "precision": 9,
+ "scale": 2
+ }
+ }
+ }
+ ]
+ }
+ },
+ "advancedExtension": {
+ "optimization": {
+ "@type": "type.googleapis.com/google.protobuf.StringValue",
+ "value": "isMergeTree=0\n"
+ }
+ }
+ }
+ },
+ "condition": {
+ "singularOrList": {
+ "value": {
+ "selection": {
+ "directReference": {
+ "structField": {}
+ }
+ }
+ },
+ "options": [
+ {
+ "literal": {
+ "decimal": {
+ "value": "yAAAAAAAAAAAAAAAAAAAAA==",
+ "precision": 9,
+ "scale": 2
+ }
+ }
+ },
+ {
+ "literal": {
+ "decimal": {
+ "value": "LAEAAAAAAAAAAAAAAAAAAA==",
+ "precision": 9,
+ "scale": 2
+ }
+ }
+ },
+ {
+ "literal": {
+ "decimal": {
+ "value": "kAEAAAAAAAAAAAAAAAAAAA==",
+ "precision": 9,
+ "scale": 2
+ }
+ }
+ },
+ {
+ "literal": {
+ "decimal": {
+ "value": "9AEAAAAAAAAAAAAAAAAAAA==",
+ "precision": 9,
+ "scale": 2
+ }
+ }
+ }
+ ]
+ }
+ }
+ }
+ },
+ "names": [
+ "a#26"
+ ],
+ "outputSchema": {
+ "types": [
+ {
+ "decimal": {
+ "scale": 2,
+ "precision": 9,
+ "nullability": "NULLABILITY_NULLABLE"
+ }
+ }
+ ],
+ "nullability": "NULLABILITY_REQUIRED"
+ }
+ }
+ }
+ ]
+}
\ No newline at end of file
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]