This is an automated email from the ASF dual-hosted git repository.
panxiaolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new e1f6df7cd21 [refine](cast) Add the line processing function of cast
(#53654)
e1f6df7cd21 is described below
commit e1f6df7cd215214bcbab68bbb3922e8127ae5141
Author: Mryange <[email protected]>
AuthorDate: Tue Jul 22 15:26:16 2025 +0800
[refine](cast) Add the line processing function of cast (#53654)
At present, this PR only implements the cast to bool.
---
.../data_types/serde/data_type_number_serde.cpp | 17 ++-
be/src/vec/functions/cast/cast_base.h | 7 +-
be/src/vec/functions/cast/cast_to_boolean.h | 142 +++++++++++++++++++--
be/src/vec/io/io_helper.h | 2 +-
4 files changed, 150 insertions(+), 18 deletions(-)
diff --git a/be/src/vec/data_types/serde/data_type_number_serde.cpp
b/be/src/vec/data_types/serde/data_type_number_serde.cpp
index bd072a49b38..10566eafe29 100644
--- a/be/src/vec/data_types/serde/data_type_number_serde.cpp
+++ b/be/src/vec/data_types/serde/data_type_number_serde.cpp
@@ -27,6 +27,7 @@
#include "util/mysql_global.h"
#include "vec/columns/column_nullable.h"
#include "vec/core/types.h"
+#include "vec/functions/cast/cast_to_boolean.h"
#include "vec/io/io_helper.h"
namespace doris::vectorized {
@@ -608,11 +609,14 @@ Status
DataTypeNumberSerDe<T>::from_string_strict_mode(StringRef& str, IColumn&
}
template <PrimitiveType PT, bool enable_strict_cast>
-bool try_parse_impl(typename PrimitiveTypeTraits<PT>::ColumnItemType& x,
ReadBuffer& rb) {
+bool try_parse_impl(typename PrimitiveTypeTraits<PT>::ColumnItemType& x,
ReadBuffer& rb,
+ CastParameters& params) {
+ /// TODO: change the ReadBuffer arg to StringRef.
if constexpr (is_float_or_double(PT)) {
return try_read_float_text(x, rb);
} else if constexpr (PT == TYPE_BOOLEAN) {
- return try_read_bool_text(x, rb);
+ StringRef str {rb.position(), rb.count()};
+ return CastToBool::from_string(str, x, params);
} else if constexpr (is_int(PT)) {
return try_read_int_text<typename
PrimitiveTypeTraits<PT>::ColumnItemType,
enable_strict_cast>(x, rb);
@@ -636,12 +640,14 @@ Status DataTypeNumberSerDe<T>::from_string_batch(const
ColumnString& str, Column
const ColumnString::Chars* chars = &str.get_chars();
const IColumn::Offsets* offsets = &str.get_offsets();
+ CastParameters params;
+ params.is_strict = false;
for (size_t i = 0; i < size; ++i) {
size_t next_offset = (*offsets)[i];
size_t string_size = next_offset - current_offset;
ReadBuffer read_buffer(&(*chars)[current_offset], string_size);
- null_map[i] = !try_parse_impl<T, false>(vec_to[i], read_buffer);
+ null_map[i] = !try_parse_impl<T, false>(vec_to[i], read_buffer,
params);
current_offset = next_offset;
}
return Status::OK();
@@ -693,7 +699,8 @@ Status
DataTypeNumberSerDe<T>::from_string_strict_mode_batch(
auto& column_to = assert_cast<ColumnType&>(column);
auto& vec_to = column_to.get_data();
-
+ CastParameters params;
+ params.is_strict = true;
for (size_t i = 0; i < size; ++i) {
if (null_map && null_map[i]) {
continue;
@@ -702,7 +709,7 @@ Status
DataTypeNumberSerDe<T>::from_string_strict_mode_batch(
size_t string_size = next_offset - current_offset;
ReadBuffer read_buffer(&(*chars)[current_offset], string_size);
- if (!try_parse_impl<T, true>(vec_to[i], read_buffer)) {
+ if (!try_parse_impl<T, true>(vec_to[i], read_buffer, params)) {
return Status::InvalidArgument(
"parse number fail, string: '{}'",
std::string((char*)&(*chars)[current_offset],
string_size));
diff --git a/be/src/vec/functions/cast/cast_base.h
b/be/src/vec/functions/cast/cast_base.h
index d2ff94434fe..2b74f8dc922 100644
--- a/be/src/vec/functions/cast/cast_base.h
+++ b/be/src/vec/functions/cast/cast_base.h
@@ -40,7 +40,7 @@
#include "vec/data_types/data_type_time.h"
#include "vec/functions/function.h"
#include "vec/functions/function_helpers.h"
-
+#include "vec/io/io_helper.h"
namespace doris::vectorized {
struct NameCast {
@@ -160,6 +160,11 @@ public:
}
};
+struct CastParameters {
+ Status status = Status::OK();
+ bool is_strict;
+};
+
#ifdef BE_TEST
inline CastWrapper::WrapperType get_cast_wrapper(FunctionContext* context,
const DataTypePtr& from_type,
diff --git a/be/src/vec/functions/cast/cast_to_boolean.h
b/be/src/vec/functions/cast/cast_to_boolean.h
index f687a529370..e89d5ff1ab1 100644
--- a/be/src/vec/functions/cast/cast_to_boolean.h
+++ b/be/src/vec/functions/cast/cast_to_boolean.h
@@ -18,9 +18,104 @@
#pragma once
#include "cast_base.h"
+#include "vec/core/types.h"
+#include "vec/io/io_helper.h"
namespace doris::vectorized {
+struct CastToBool {
+ template <class SRC>
+ static inline bool from_number(const SRC& from, UInt8& to, CastParameters&
params);
+
+ template <class SRC>
+ static inline bool from_decimal(const SRC& from, UInt8& to, UInt32
precision, UInt32 scale,
+ CastParameters& params);
+
+ static inline bool from_string(const StringRef& from, UInt8& to,
CastParameters& params);
+};
+
+template <>
+inline bool CastToBool::from_number(const UInt8& from, UInt8& to,
CastParameters&) {
+ to = from;
+ return true;
+}
+
+template <>
+inline bool CastToBool::from_number(const Int8& from, UInt8& to,
CastParameters&) {
+ to = (from != 0);
+ return true;
+}
+template <>
+inline bool CastToBool::from_number(const Int16& from, UInt8& to,
CastParameters&) {
+ to = (from != 0);
+ return true;
+}
+template <>
+inline bool CastToBool::from_number(const Int32& from, UInt8& to,
CastParameters&) {
+ to = (from != 0);
+ return true;
+}
+template <>
+inline bool CastToBool::from_number(const Int64& from, UInt8& to,
CastParameters&) {
+ to = (from != 0);
+ return true;
+}
+template <>
+inline bool CastToBool::from_number(const Int128& from, UInt8& to,
CastParameters&) {
+ to = (from != 0);
+ return true;
+}
+
+template <>
+inline bool CastToBool::from_number(const Float32& from, UInt8& to,
CastParameters&) {
+ to = (from != 0);
+ return true;
+}
+template <>
+inline bool CastToBool::from_number(const Float64& from, UInt8& to,
CastParameters&) {
+ to = (from != 0);
+ return true;
+}
+
+template <>
+inline bool CastToBool::from_decimal(const Decimal32& from, UInt8& to, UInt32,
UInt32,
+ CastParameters&) {
+ to = (from.value != 0);
+ return true;
+}
+
+template <>
+inline bool CastToBool::from_decimal(const Decimal64& from, UInt8& to, UInt32,
UInt32,
+ CastParameters&) {
+ to = (from.value != 0);
+ return true;
+}
+
+template <>
+inline bool CastToBool::from_decimal(const Decimal128V2& from, UInt8& to,
UInt32, UInt32,
+ CastParameters&) {
+ to = (from.value != 0);
+ return true;
+}
+
+template <>
+inline bool CastToBool::from_decimal(const Decimal128V3& from, UInt8& to,
UInt32, UInt32,
+ CastParameters&) {
+ to = (from.value != 0);
+ return true;
+}
+
+template <>
+inline bool CastToBool::from_decimal(const Decimal256& from, UInt8& to,
UInt32, UInt32,
+ CastParameters&) {
+ to = (from.value != 0);
+ return true;
+}
+
+inline bool CastToBool::from_string(const StringRef& from, UInt8& to,
CastParameters&) {
+ return try_read_bool_text(to, from);
+}
+
template <CastModeType Mode>
class CastToImpl<Mode, DataTypeString, DataTypeBool> : public CastToBase {
public:
@@ -56,25 +151,50 @@ public:
return Status::OK();
}
};
-template <CastModeType AllMode, typename NumberOrDecimalType>
- requires(IsDataTypeNumber<NumberOrDecimalType> ||
IsDataTypeDecimal<NumberOrDecimalType>)
-class CastToImpl<AllMode, NumberOrDecimalType, DataTypeBool> : public
CastToBase {
+template <CastModeType AllMode, typename NumberType>
+ requires(IsDataTypeNumber<NumberType>)
+class CastToImpl<AllMode, NumberType, DataTypeBool> : public CastToBase {
public:
Status execute_impl(FunctionContext* context, Block& block, const
ColumnNumbers& arguments,
uint32_t result, size_t input_rows_count,
const NullMap::value_type* null_map = nullptr) const
override {
- const auto* col_from = check_and_get_column<typename
NumberOrDecimalType::ColumnType>(
+ const auto* col_from = check_and_get_column<typename
NumberType::ColumnType>(
block.get_by_position(arguments[0]).column.get());
DataTypeBool::ColumnType::MutablePtr col_to =
DataTypeBool::ColumnType::create(input_rows_count);
+ CastParameters params;
+ params.is_strict = (AllMode == CastModeType::StrictMode);
for (size_t i = 0; i < input_rows_count; ++i) {
- if constexpr (IsDataTypeDecimal<NumberOrDecimalType>) {
- using NativeType = typename
NumberOrDecimalType::FieldType::NativeType;
- col_to->get_element(i) =
((NativeType)col_from->get_element(i)) != 0;
- } else {
- col_to->get_element(i) = col_from->get_element(i) != 0;
- }
+ CastToBool::from_number(col_from->get_element(i),
col_to->get_element(i), params);
+ }
+
+ block.get_by_position(result).column = std::move(col_to);
+ return Status::OK();
+ }
+};
+
+template <CastModeType AllMode, typename DecimalType>
+ requires(IsDataTypeDecimal<DecimalType>)
+class CastToImpl<AllMode, DecimalType, DataTypeBool> : public CastToBase {
+public:
+ Status execute_impl(FunctionContext* context, Block& block, const
ColumnNumbers& arguments,
+ uint32_t result, size_t input_rows_count,
+ const NullMap::value_type* null_map = nullptr) const
override {
+ const auto* col_from = check_and_get_column<typename
DecimalType::ColumnType>(
+ block.get_by_position(arguments[0]).column.get());
+ const auto type_from = block.get_by_position(arguments[0]).type;
+ DataTypeBool::ColumnType::MutablePtr col_to =
+ DataTypeBool::ColumnType::create(input_rows_count);
+
+ CastParameters params;
+ params.is_strict = (AllMode == CastModeType::StrictMode);
+
+ auto precision = type_from->get_precision();
+ auto scale = type_from->get_scale();
+ for (size_t i = 0; i < input_rows_count; ++i) {
+ CastToBool::from_decimal(col_from->get_element(i),
col_to->get_element(i), precision,
+ scale, params);
}
block.get_by_position(result).column = std::move(col_to);
@@ -83,7 +203,7 @@ public:
};
namespace CastWrapper {
-WrapperType create_boolean_wrapper(FunctionContext* context, const
DataTypePtr& from_type) {
+inline WrapperType create_boolean_wrapper(FunctionContext* context, const
DataTypePtr& from_type) {
std::shared_ptr<CastToBase> cast_to_bool;
auto make_bool_wrapper = [&](const auto& types) -> bool {
diff --git a/be/src/vec/io/io_helper.h b/be/src/vec/io/io_helper.h
index 117cd597581..ca45a64aa3d 100644
--- a/be/src/vec/io/io_helper.h
+++ b/be/src/vec/io/io_helper.h
@@ -349,7 +349,7 @@ bool try_read_datetime_v2_text(T& x, ReadBuffer& in, const
cctz::time_zone& loca
#include "common/compile_check_begin.h"
-bool inline try_read_bool_text(UInt8& x, StringRef& buf) {
+bool inline try_read_bool_text(UInt8& x, const StringRef& buf) {
StringParser::ParseResult result;
x = StringParser::string_to_bool(buf.data, buf.size, &result);
return result == StringParser::PARSE_SUCCESS;
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]