This is an automated email from the ASF dual-hosted git repository. yiguolei pushed a commit to branch branch-2.1 in repository https://gitbox.apache.org/repos/asf/doris.git
commit b271cfa4e39c8bb9d68a1e388d65ca3dc6adecf2 Author: amory <[email protected]> AuthorDate: Tue Apr 23 16:18:21 2024 +0800 [FIX]fix cidr func with const param (#33968) --- be/src/vec/functions/function_ip.h | 127 ++++++++++++++++----- .../ip_functions/test_ipv4_cidr_to_range.out | 30 +++++ .../test_ipv6_cidr_to_range_function.out | 33 ++++++ .../ip_functions/test_ipv4_cidr_to_range.groovy | 3 + .../test_ipv6_cidr_to_range_function.groovy | 6 +- 5 files changed, 171 insertions(+), 28 deletions(-) diff --git a/be/src/vec/functions/function_ip.h b/be/src/vec/functions/function_ip.h index 4fd1c38282b..9f70d4b3afa 100644 --- a/be/src/vec/functions/function_ip.h +++ b/be/src/vec/functions/function_ip.h @@ -770,8 +770,8 @@ public: ColumnWithTypeAndName& ip_column = block.get_by_position(arguments[0]); ColumnWithTypeAndName& cidr_column = block.get_by_position(arguments[1]); - const ColumnPtr& ip_column_ptr = ip_column.column; - const ColumnPtr& cidr_column_ptr = cidr_column.column; + const auto& [ip_column_ptr, ip_col_const] = unpack_if_const(ip_column.column); + const auto& [cidr_column_ptr, cidr_col_const] = unpack_if_const(cidr_column.column); const auto* col_ip_column = check_and_get_column<ColumnVector<IPv4>>(ip_column_ptr.get()); const auto* col_cidr_column = @@ -787,15 +787,41 @@ public: static constexpr UInt8 max_cidr_mask = IPV4_BINARY_LENGTH * 8; - for (size_t i = 0; i < input_rows_count; ++i) { - auto ip = vec_ip_input[i]; - auto cidr = vec_cidr_input[i]; - if (0 <= cidr && cidr <= max_cidr_mask) { + if (ip_col_const) { + auto ip = vec_ip_input[0]; + for (size_t i = 0; i < input_rows_count; ++i) { + auto cidr = vec_cidr_input[i]; + if (cidr < 0 || cidr > max_cidr_mask) { + throw Exception(ErrorCode::INVALID_ARGUMENT, "Illegal cidr value '{}'", + std::to_string(cidr)); + } + auto range = apply_cidr_mask(ip, cidr); + vec_lower_range_output[i] = range.first; + vec_upper_range_output[i] = range.second; + } + } else if (cidr_col_const) { + auto cidr = vec_cidr_input[0]; + if (cidr < 0 || cidr > max_cidr_mask) { + throw Exception(ErrorCode::INVALID_ARGUMENT, "Illegal cidr value '{}'", + std::to_string(cidr)); + } + for (size_t i = 0; i < input_rows_count; ++i) { + auto ip = vec_ip_input[i]; + auto range = apply_cidr_mask(ip, cidr); + vec_lower_range_output[i] = range.first; + vec_upper_range_output[i] = range.second; + } + } else { + for (size_t i = 0; i < input_rows_count; ++i) { + auto ip = vec_ip_input[i]; + auto cidr = vec_cidr_input[i]; + if (cidr < 0 || cidr > max_cidr_mask) { + throw Exception(ErrorCode::INVALID_ARGUMENT, "Illegal cidr value '{}'", + std::to_string(cidr)); + } auto range = apply_cidr_mask(ip, cidr); vec_lower_range_output[i] = range.first; vec_upper_range_output[i] = range.second; - } else { - return Status::InvalidArgument("Invalid row {}, cidr is out of range", i); } } @@ -855,17 +881,22 @@ public: const auto& cidr_column_with_type_and_name = block.get_by_position(arguments[1]); WhichDataType addr_type(addr_column_with_type_and_name.type); WhichDataType cidr_type(cidr_column_with_type_and_name.type); - const ColumnPtr& addr_column = addr_column_with_type_and_name.column; - const ColumnPtr& cidr_column = cidr_column_with_type_and_name.column; + const auto& [addr_column, add_col_const] = + unpack_if_const(addr_column_with_type_and_name.column); + const auto& [cidr_column, col_const] = + unpack_if_const(cidr_column_with_type_and_name.column); + const auto* cidr_col = assert_cast<const ColumnInt16*>(cidr_column.get()); ColumnPtr col_res = nullptr; if (addr_type.is_ipv6()) { const auto* ipv6_addr_column = check_and_get_column<ColumnIPv6>(addr_column.get()); - col_res = execute_impl<ColumnIPv6>(*ipv6_addr_column, *cidr_col, input_rows_count); + col_res = execute_impl<ColumnIPv6>(*ipv6_addr_column, *cidr_col, input_rows_count, + add_col_const, col_const); } else if (addr_type.is_string()) { const auto* str_addr_column = check_and_get_column<ColumnString>(addr_column.get()); - col_res = execute_impl<ColumnString>(*str_addr_column, *cidr_col, input_rows_count); + col_res = execute_impl<ColumnString>(*str_addr_column, *cidr_col, input_rows_count, + add_col_const, col_const); } else { return Status::RuntimeError( "Illegal column {} of argument of function {}, Expected IPv6 or String", @@ -878,7 +909,8 @@ public: template <typename FromColumn> static ColumnPtr execute_impl(const FromColumn& from_column, const ColumnInt16& cidr_column, - size_t input_rows_count) { + size_t input_rows_count, bool is_addr_const = false, + bool is_cidr_const = false) { auto col_res_lower_range = ColumnIPv6::create(input_rows_count, 0); auto col_res_upper_range = ColumnIPv6::create(input_rows_count, 0); auto& vec_res_lower_range = col_res_lower_range->get_data(); @@ -886,26 +918,67 @@ public: static constexpr UInt8 max_cidr_mask = IPV6_BINARY_LENGTH * 8; - for (size_t i = 0; i < input_rows_count; ++i) { - auto cidr = cidr_column.get_int(i); + if (is_addr_const) { + for (size_t i = 0; i < input_rows_count; ++i) { + auto cidr = cidr_column.get_int(i); + if (cidr < 0 || cidr > max_cidr_mask) { + throw Exception(ErrorCode::INVALID_ARGUMENT, "Illegal cidr value '{}'", + std::to_string(cidr)); + } + if constexpr (std::is_same_v<FromColumn, ColumnString>) { + // 16 bytes ipv6 string is stored in big-endian byte order + // so transfer to little-endian firstly + auto* src_data = const_cast<char*>(from_column.get_data_at(0).data); + std::reverse(src_data, src_data + IPV6_BINARY_LENGTH); + apply_cidr_mask(src_data, reinterpret_cast<char*>(&vec_res_lower_range[i]), + reinterpret_cast<char*>(&vec_res_upper_range[i]), cidr); + } else { + apply_cidr_mask(from_column.get_data_at(0).data, + reinterpret_cast<char*>(&vec_res_lower_range[i]), + reinterpret_cast<char*>(&vec_res_upper_range[i]), cidr); + } + } + } else if (is_cidr_const) { + auto cidr = cidr_column.get_int(0); if (cidr < 0 || cidr > max_cidr_mask) { throw Exception(ErrorCode::INVALID_ARGUMENT, "Illegal cidr value '{}'", std::to_string(cidr)); } - if constexpr (std::is_same_v<FromColumn, ColumnString>) { - // 16 bytes ipv6 string is stored in big-endian byte order - // so transfer to little-endian firstly - auto* src_data = const_cast<char*>(from_column.get_data_at(i).data); - std::reverse(src_data, src_data + IPV6_BINARY_LENGTH); - apply_cidr_mask(src_data, reinterpret_cast<char*>(&vec_res_lower_range[i]), - reinterpret_cast<char*>(&vec_res_upper_range[i]), cidr); - } else { - apply_cidr_mask(from_column.get_data_at(i).data, - reinterpret_cast<char*>(&vec_res_lower_range[i]), - reinterpret_cast<char*>(&vec_res_upper_range[i]), cidr); + for (size_t i = 0; i < input_rows_count; ++i) { + if constexpr (std::is_same_v<FromColumn, ColumnString>) { + // 16 bytes ipv6 string is stored in big-endian byte order + // so transfer to little-endian firstly + auto* src_data = const_cast<char*>(from_column.get_data_at(i).data); + std::reverse(src_data, src_data + IPV6_BINARY_LENGTH); + apply_cidr_mask(src_data, reinterpret_cast<char*>(&vec_res_lower_range[i]), + reinterpret_cast<char*>(&vec_res_upper_range[i]), cidr); + } else { + apply_cidr_mask(from_column.get_data_at(i).data, + reinterpret_cast<char*>(&vec_res_lower_range[i]), + reinterpret_cast<char*>(&vec_res_upper_range[i]), cidr); + } + } + } else { + for (size_t i = 0; i < input_rows_count; ++i) { + auto cidr = cidr_column.get_int(i); + if (cidr < 0 || cidr > max_cidr_mask) { + throw Exception(ErrorCode::INVALID_ARGUMENT, "Illegal cidr value '{}'", + std::to_string(cidr)); + } + if constexpr (std::is_same_v<FromColumn, ColumnString>) { + // 16 bytes ipv6 string is stored in big-endian byte order + // so transfer to little-endian firstly + auto* src_data = const_cast<char*>(from_column.get_data_at(i).data); + std::reverse(src_data, src_data + IPV6_BINARY_LENGTH); + apply_cidr_mask(src_data, reinterpret_cast<char*>(&vec_res_lower_range[i]), + reinterpret_cast<char*>(&vec_res_upper_range[i]), cidr); + } else { + apply_cidr_mask(from_column.get_data_at(i).data, + reinterpret_cast<char*>(&vec_res_lower_range[i]), + reinterpret_cast<char*>(&vec_res_upper_range[i]), cidr); + } } } - return ColumnStruct::create( Columns {std::move(col_res_lower_range), std::move(col_res_upper_range)}); } diff --git a/regression-test/data/query_p0/sql_functions/ip_functions/test_ipv4_cidr_to_range.out b/regression-test/data/query_p0/sql_functions/ip_functions/test_ipv4_cidr_to_range.out index 035426afe54..acfd87789e4 100644 --- a/regression-test/data/query_p0/sql_functions/ip_functions/test_ipv4_cidr_to_range.out +++ b/regression-test/data/query_p0/sql_functions/ip_functions/test_ipv4_cidr_to_range.out @@ -8,6 +8,36 @@ 6 127.0.0.0 127.255.255.255 7 0.0.0.0 255.255.255.255 +-- !sql -- +1 \N +2 {"min": "127.0.0.0", "max": "127.0.255.255"} +3 {"min": "127.0.0.0", "max": "127.0.255.255"} +4 {"min": "127.0.0.0", "max": "127.0.255.255"} +5 {"min": "127.0.0.0", "max": "127.0.255.255"} +6 {"min": "127.0.0.0", "max": "127.0.255.255"} +7 {"min": "127.0.0.0", "max": "127.0.255.255"} + +-- !sql -- +1 {"min": "0.0.0.0", "max": "255.255.255.255"} +2 \N +3 {"min": "127.0.0.1", "max": "127.0.0.1"} +4 {"min": "127.0.0.0", "max": "127.0.0.255"} +5 {"min": "127.0.0.0", "max": "127.0.255.255"} +6 {"min": "127.0.0.0", "max": "127.255.255.255"} +7 {"min": "0.0.0.0", "max": "255.255.255.255"} + +-- !sql -- +0 {"min": "127.0.0.0", "max": "127.0.255.255"} +1 {"min": "127.0.0.0", "max": "127.0.255.255"} +2 {"min": "127.0.0.0", "max": "127.0.255.255"} +3 {"min": "127.0.0.0", "max": "127.0.255.255"} +4 {"min": "127.0.0.0", "max": "127.0.255.255"} +5 {"min": "127.0.0.0", "max": "127.0.255.255"} +6 {"min": "127.0.0.0", "max": "127.0.255.255"} +7 {"min": "127.0.0.0", "max": "127.0.255.255"} +8 {"min": "127.0.0.0", "max": "127.0.255.255"} +9 {"min": "127.0.0.0", "max": "127.0.255.255"} + -- !sql -- \N diff --git a/regression-test/data/query_p0/sql_functions/ip_functions/test_ipv6_cidr_to_range_function.out b/regression-test/data/query_p0/sql_functions/ip_functions/test_ipv6_cidr_to_range_function.out index 6ffd4ee56f7..201af35987a 100644 --- a/regression-test/data/query_p0/sql_functions/ip_functions/test_ipv6_cidr_to_range_function.out +++ b/regression-test/data/query_p0/sql_functions/ip_functions/test_ipv6_cidr_to_range_function.out @@ -1,3 +1,4 @@ +-- This file is automatically generated. You should know what you did if you want to edit this -- !sql -- 1 :: ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 2 2001:db8:: 2001:db8:ffff:ffff:ffff:ffff:ffff:ffff @@ -9,6 +10,37 @@ 8 \N \N 9 \N \N +-- !sql -- +1 {"min": "2001::", "max": "2001:ffff:ffff:ffff:ffff:ffff:ffff:ffff"} +2 {"min": "2001::", "max": "2001:ffff:ffff:ffff:ffff:ffff:ffff:ffff"} +3 {"min": "ffff::", "max": "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"} +4 {"min": "2001::", "max": "2001:ffff:ffff:ffff:ffff:ffff:ffff:ffff"} +5 {"min": "2001::", "max": "2001:ffff:ffff:ffff:ffff:ffff:ffff:ffff"} +6 {"min": "::", "max": "0:ffff:ffff:ffff:ffff:ffff:ffff:ffff"} +7 {"min": "ffff::", "max": "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"} +8 \N +9 {"min": "ffff::", "max": "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"} + +-- !sql -- +1 {"min": "::", "max": "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"} +4 {"min": "3132:372e::", "max": "3132:372e:ffff:ffff:ffff:ffff:ffff:ffff"} +5 {"min": "3132::", "max": "3132:ffff:ffff:ffff:ffff:ffff:ffff:ffff"} +7 {"min": "b000::", "max": "bfff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"} +8 {"min": "be00::", "max": "beff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"} +9 \N + +-- !sql -- +0 {"min": "3132::", "max": "3132:ffff:ffff:ffff:ffff:ffff:ffff:ffff"} +1 {"min": "3132::", "max": "3132:ffff:ffff:ffff:ffff:ffff:ffff:ffff"} +2 {"min": "3132::", "max": "3132:ffff:ffff:ffff:ffff:ffff:ffff:ffff"} +3 {"min": "3132::", "max": "3132:ffff:ffff:ffff:ffff:ffff:ffff:ffff"} +4 {"min": "3132::", "max": "3132:ffff:ffff:ffff:ffff:ffff:ffff:ffff"} +5 {"min": "3132::", "max": "3132:ffff:ffff:ffff:ffff:ffff:ffff:ffff"} +6 {"min": "3132::", "max": "3132:ffff:ffff:ffff:ffff:ffff:ffff:ffff"} +7 {"min": "3132::", "max": "3132:ffff:ffff:ffff:ffff:ffff:ffff:ffff"} +8 {"min": "3132::", "max": "3132:ffff:ffff:ffff:ffff:ffff:ffff:ffff"} +9 {"min": "3132::", "max": "3132:ffff:ffff:ffff:ffff:ffff:ffff:ffff"} + -- !sql -- 1 :: ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 2 2001:db8:: 2001:db8:ffff:ffff:ffff:ffff:ffff:ffff @@ -34,3 +66,4 @@ -- !sql -- {"min": "f000::", "max": "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"} + diff --git a/regression-test/suites/query_p0/sql_functions/ip_functions/test_ipv4_cidr_to_range.groovy b/regression-test/suites/query_p0/sql_functions/ip_functions/test_ipv4_cidr_to_range.groovy index eda4174de99..2c0333ecb56 100644 --- a/regression-test/suites/query_p0/sql_functions/ip_functions/test_ipv4_cidr_to_range.groovy +++ b/regression-test/suites/query_p0/sql_functions/ip_functions/test_ipv4_cidr_to_range.groovy @@ -46,6 +46,9 @@ suite("test_ipv4_cidr_to_range") { """ qt_sql "select id, struct_element(ipv4_cidr_to_range(addr, cidr), 'min') as min_range, struct_element(ipv4_cidr_to_range(addr, cidr), 'max') as max_range from test_ipv4_cidr_to_range order by id" + qt_sql "select id, ipv4_cidr_to_range(addr, 16) from test_ipv4_cidr_to_range order by id;" + qt_sql """ select id, ipv4_cidr_to_range("127.0.0.1", cidr) from test_ipv4_cidr_to_range order by id;""" + qt_sql """ select number, ipv4_cidr_to_range("127.0.0.1", 16) from numbers("number"="10") order by number;""" sql """ DROP TABLE IF EXISTS test_ipv4_cidr_to_range """ diff --git a/regression-test/suites/query_p0/sql_functions/ip_functions/test_ipv6_cidr_to_range_function.groovy b/regression-test/suites/query_p0/sql_functions/ip_functions/test_ipv6_cidr_to_range_function.groovy index 9b19c22bc2c..d1fe0e2761a 100644 --- a/regression-test/suites/query_p0/sql_functions/ip_functions/test_ipv6_cidr_to_range_function.groovy +++ b/regression-test/suites/query_p0/sql_functions/ip_functions/test_ipv6_cidr_to_range_function.groovy @@ -48,6 +48,10 @@ suite("test_ipv6_cidr_to_range_function") { """ qt_sql "select id, struct_element(ipv6_cidr_to_range(addr, cidr), 'min') as min_range, struct_element(ipv6_cidr_to_range(addr, cidr), 'max') as max_range from test_ipv6_cidr_to_range_function order by id" + qt_sql "select id, ipv6_cidr_to_range(addr, 16) from test_ipv6_cidr_to_range_function order by id;" + sql """ delete from test_ipv6_cidr_to_range_function where id in (2,3,6);""" + qt_sql """ select id, ipv6_cidr_to_range("127.0.0.1", cidr) from test_ipv6_cidr_to_range_function order by id;""" + qt_sql """ select number, ipv6_cidr_to_range("127.0.0.1", 16) from numbers("number"="10") order by number;""" sql """ DROP TABLE IF EXISTS test_ipv6_cidr_to_range_function """ sql """ DROP TABLE IF EXISTS test_str_cidr_to_range_function """ @@ -87,4 +91,4 @@ suite("test_ipv6_cidr_to_range_function") { qt_sql "select ipv6_cidr_to_range(ipv6_string_to_num('ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff'), 64)" qt_sql "select ipv6_cidr_to_range(ipv6_string_to_num('0000:0000:0000:0000:0000:0000:0000:0000'), 8)" qt_sql "select ipv6_cidr_to_range(ipv6_string_to_num('ffff:0000:0000:0000:0000:0000:0000:0000'), 4)" -} \ No newline at end of file +} --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
