This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git

commit b271cfa4e39c8bb9d68a1e388d65ca3dc6adecf2
Author: amory <[email protected]>
AuthorDate: Tue Apr 23 16:18:21 2024 +0800

    [FIX] Fix CIDR functions with constant parameters (#33968)
---
 be/src/vec/functions/function_ip.h                 | 127 ++++++++++++++++-----
 .../ip_functions/test_ipv4_cidr_to_range.out       |  30 +++++
 .../test_ipv6_cidr_to_range_function.out           |  33 ++++++
 .../ip_functions/test_ipv4_cidr_to_range.groovy    |   3 +
 .../test_ipv6_cidr_to_range_function.groovy        |   6 +-
 5 files changed, 171 insertions(+), 28 deletions(-)

diff --git a/be/src/vec/functions/function_ip.h 
b/be/src/vec/functions/function_ip.h
index 4fd1c38282b..9f70d4b3afa 100644
--- a/be/src/vec/functions/function_ip.h
+++ b/be/src/vec/functions/function_ip.h
@@ -770,8 +770,8 @@ public:
         ColumnWithTypeAndName& ip_column = block.get_by_position(arguments[0]);
         ColumnWithTypeAndName& cidr_column = 
block.get_by_position(arguments[1]);
 
-        const ColumnPtr& ip_column_ptr = ip_column.column;
-        const ColumnPtr& cidr_column_ptr = cidr_column.column;
+        const auto& [ip_column_ptr, ip_col_const] = 
unpack_if_const(ip_column.column);
+        const auto& [cidr_column_ptr, cidr_col_const] = 
unpack_if_const(cidr_column.column);
 
         const auto* col_ip_column = 
check_and_get_column<ColumnVector<IPv4>>(ip_column_ptr.get());
         const auto* col_cidr_column =
@@ -787,15 +787,41 @@ public:
 
         static constexpr UInt8 max_cidr_mask = IPV4_BINARY_LENGTH * 8;
 
-        for (size_t i = 0; i < input_rows_count; ++i) {
-            auto ip = vec_ip_input[i];
-            auto cidr = vec_cidr_input[i];
-            if (0 <= cidr && cidr <= max_cidr_mask) {
+        if (ip_col_const) {
+            auto ip = vec_ip_input[0];
+            for (size_t i = 0; i < input_rows_count; ++i) {
+                auto cidr = vec_cidr_input[i];
+                if (cidr < 0 || cidr > max_cidr_mask) {
+                    throw Exception(ErrorCode::INVALID_ARGUMENT, "Illegal cidr 
value '{}'",
+                                    std::to_string(cidr));
+                }
+                auto range = apply_cidr_mask(ip, cidr);
+                vec_lower_range_output[i] = range.first;
+                vec_upper_range_output[i] = range.second;
+            }
+        } else if (cidr_col_const) {
+            auto cidr = vec_cidr_input[0];
+            if (cidr < 0 || cidr > max_cidr_mask) {
+                throw Exception(ErrorCode::INVALID_ARGUMENT, "Illegal cidr 
value '{}'",
+                                std::to_string(cidr));
+            }
+            for (size_t i = 0; i < input_rows_count; ++i) {
+                auto ip = vec_ip_input[i];
+                auto range = apply_cidr_mask(ip, cidr);
+                vec_lower_range_output[i] = range.first;
+                vec_upper_range_output[i] = range.second;
+            }
+        } else {
+            for (size_t i = 0; i < input_rows_count; ++i) {
+                auto ip = vec_ip_input[i];
+                auto cidr = vec_cidr_input[i];
+                if (cidr < 0 || cidr > max_cidr_mask) {
+                    throw Exception(ErrorCode::INVALID_ARGUMENT, "Illegal cidr 
value '{}'",
+                                    std::to_string(cidr));
+                }
                 auto range = apply_cidr_mask(ip, cidr);
                 vec_lower_range_output[i] = range.first;
                 vec_upper_range_output[i] = range.second;
-            } else {
-                return Status::InvalidArgument("Invalid row {}, cidr is out of 
range", i);
             }
         }
 
@@ -855,17 +881,22 @@ public:
         const auto& cidr_column_with_type_and_name = 
block.get_by_position(arguments[1]);
         WhichDataType addr_type(addr_column_with_type_and_name.type);
         WhichDataType cidr_type(cidr_column_with_type_and_name.type);
-        const ColumnPtr& addr_column = addr_column_with_type_and_name.column;
-        const ColumnPtr& cidr_column = cidr_column_with_type_and_name.column;
+        const auto& [addr_column, add_col_const] =
+                unpack_if_const(addr_column_with_type_and_name.column);
+        const auto& [cidr_column, col_const] =
+                unpack_if_const(cidr_column_with_type_and_name.column);
+
         const auto* cidr_col = assert_cast<const 
ColumnInt16*>(cidr_column.get());
         ColumnPtr col_res = nullptr;
 
         if (addr_type.is_ipv6()) {
             const auto* ipv6_addr_column = 
check_and_get_column<ColumnIPv6>(addr_column.get());
-            col_res = execute_impl<ColumnIPv6>(*ipv6_addr_column, *cidr_col, 
input_rows_count);
+            col_res = execute_impl<ColumnIPv6>(*ipv6_addr_column, *cidr_col, 
input_rows_count,
+                                               add_col_const, col_const);
         } else if (addr_type.is_string()) {
             const auto* str_addr_column = 
check_and_get_column<ColumnString>(addr_column.get());
-            col_res = execute_impl<ColumnString>(*str_addr_column, *cidr_col, 
input_rows_count);
+            col_res = execute_impl<ColumnString>(*str_addr_column, *cidr_col, 
input_rows_count,
+                                                 add_col_const, col_const);
         } else {
             return Status::RuntimeError(
                     "Illegal column {} of argument of function {}, Expected 
IPv6 or String",
@@ -878,7 +909,8 @@ public:
 
     template <typename FromColumn>
     static ColumnPtr execute_impl(const FromColumn& from_column, const 
ColumnInt16& cidr_column,
-                                  size_t input_rows_count) {
+                                  size_t input_rows_count, bool is_addr_const 
= false,
+                                  bool is_cidr_const = false) {
         auto col_res_lower_range = ColumnIPv6::create(input_rows_count, 0);
         auto col_res_upper_range = ColumnIPv6::create(input_rows_count, 0);
         auto& vec_res_lower_range = col_res_lower_range->get_data();
@@ -886,26 +918,67 @@ public:
 
         static constexpr UInt8 max_cidr_mask = IPV6_BINARY_LENGTH * 8;
 
-        for (size_t i = 0; i < input_rows_count; ++i) {
-            auto cidr = cidr_column.get_int(i);
+        if (is_addr_const) {
+            for (size_t i = 0; i < input_rows_count; ++i) {
+                auto cidr = cidr_column.get_int(i);
+                if (cidr < 0 || cidr > max_cidr_mask) {
+                    throw Exception(ErrorCode::INVALID_ARGUMENT, "Illegal cidr 
value '{}'",
+                                    std::to_string(cidr));
+                }
+                if constexpr (std::is_same_v<FromColumn, ColumnString>) {
+                    // 16 bytes ipv6 string is stored in big-endian byte order
+                    // so transfer to little-endian firstly
+                    auto* src_data = 
const_cast<char*>(from_column.get_data_at(0).data);
+                    std::reverse(src_data, src_data + IPV6_BINARY_LENGTH);
+                    apply_cidr_mask(src_data, 
reinterpret_cast<char*>(&vec_res_lower_range[i]),
+                                    
reinterpret_cast<char*>(&vec_res_upper_range[i]), cidr);
+                } else {
+                    apply_cidr_mask(from_column.get_data_at(0).data,
+                                    
reinterpret_cast<char*>(&vec_res_lower_range[i]),
+                                    
reinterpret_cast<char*>(&vec_res_upper_range[i]), cidr);
+                }
+            }
+        } else if (is_cidr_const) {
+            auto cidr = cidr_column.get_int(0);
             if (cidr < 0 || cidr > max_cidr_mask) {
                 throw Exception(ErrorCode::INVALID_ARGUMENT, "Illegal cidr 
value '{}'",
                                 std::to_string(cidr));
             }
-            if constexpr (std::is_same_v<FromColumn, ColumnString>) {
-                // 16 bytes ipv6 string is stored in big-endian byte order
-                // so transfer to little-endian firstly
-                auto* src_data = 
const_cast<char*>(from_column.get_data_at(i).data);
-                std::reverse(src_data, src_data + IPV6_BINARY_LENGTH);
-                apply_cidr_mask(src_data, 
reinterpret_cast<char*>(&vec_res_lower_range[i]),
-                                
reinterpret_cast<char*>(&vec_res_upper_range[i]), cidr);
-            } else {
-                apply_cidr_mask(from_column.get_data_at(i).data,
-                                
reinterpret_cast<char*>(&vec_res_lower_range[i]),
-                                
reinterpret_cast<char*>(&vec_res_upper_range[i]), cidr);
+            for (size_t i = 0; i < input_rows_count; ++i) {
+                if constexpr (std::is_same_v<FromColumn, ColumnString>) {
+                    // 16 bytes ipv6 string is stored in big-endian byte order
+                    // so transfer to little-endian firstly
+                    auto* src_data = 
const_cast<char*>(from_column.get_data_at(i).data);
+                    std::reverse(src_data, src_data + IPV6_BINARY_LENGTH);
+                    apply_cidr_mask(src_data, 
reinterpret_cast<char*>(&vec_res_lower_range[i]),
+                                    
reinterpret_cast<char*>(&vec_res_upper_range[i]), cidr);
+                } else {
+                    apply_cidr_mask(from_column.get_data_at(i).data,
+                                    
reinterpret_cast<char*>(&vec_res_lower_range[i]),
+                                    
reinterpret_cast<char*>(&vec_res_upper_range[i]), cidr);
+                }
+            }
+        } else {
+            for (size_t i = 0; i < input_rows_count; ++i) {
+                auto cidr = cidr_column.get_int(i);
+                if (cidr < 0 || cidr > max_cidr_mask) {
+                    throw Exception(ErrorCode::INVALID_ARGUMENT, "Illegal cidr 
value '{}'",
+                                    std::to_string(cidr));
+                }
+                if constexpr (std::is_same_v<FromColumn, ColumnString>) {
+                    // 16 bytes ipv6 string is stored in big-endian byte order
+                    // so transfer to little-endian firstly
+                    auto* src_data = 
const_cast<char*>(from_column.get_data_at(i).data);
+                    std::reverse(src_data, src_data + IPV6_BINARY_LENGTH);
+                    apply_cidr_mask(src_data, 
reinterpret_cast<char*>(&vec_res_lower_range[i]),
+                                    
reinterpret_cast<char*>(&vec_res_upper_range[i]), cidr);
+                } else {
+                    apply_cidr_mask(from_column.get_data_at(i).data,
+                                    
reinterpret_cast<char*>(&vec_res_lower_range[i]),
+                                    
reinterpret_cast<char*>(&vec_res_upper_range[i]), cidr);
+                }
             }
         }
-
         return ColumnStruct::create(
                 Columns {std::move(col_res_lower_range), 
std::move(col_res_upper_range)});
     }
diff --git 
a/regression-test/data/query_p0/sql_functions/ip_functions/test_ipv4_cidr_to_range.out
 
b/regression-test/data/query_p0/sql_functions/ip_functions/test_ipv4_cidr_to_range.out
index 035426afe54..acfd87789e4 100644
--- 
a/regression-test/data/query_p0/sql_functions/ip_functions/test_ipv4_cidr_to_range.out
+++ 
b/regression-test/data/query_p0/sql_functions/ip_functions/test_ipv4_cidr_to_range.out
@@ -8,6 +8,36 @@
 6      127.0.0.0       127.255.255.255
 7      0.0.0.0 255.255.255.255
 
+-- !sql --
+1      \N
+2      {"min": "127.0.0.0", "max": "127.0.255.255"}
+3      {"min": "127.0.0.0", "max": "127.0.255.255"}
+4      {"min": "127.0.0.0", "max": "127.0.255.255"}
+5      {"min": "127.0.0.0", "max": "127.0.255.255"}
+6      {"min": "127.0.0.0", "max": "127.0.255.255"}
+7      {"min": "127.0.0.0", "max": "127.0.255.255"}
+
+-- !sql --
+1      {"min": "0.0.0.0", "max": "255.255.255.255"}
+2      \N
+3      {"min": "127.0.0.1", "max": "127.0.0.1"}
+4      {"min": "127.0.0.0", "max": "127.0.0.255"}
+5      {"min": "127.0.0.0", "max": "127.0.255.255"}
+6      {"min": "127.0.0.0", "max": "127.255.255.255"}
+7      {"min": "0.0.0.0", "max": "255.255.255.255"}
+
+-- !sql --
+0      {"min": "127.0.0.0", "max": "127.0.255.255"}
+1      {"min": "127.0.0.0", "max": "127.0.255.255"}
+2      {"min": "127.0.0.0", "max": "127.0.255.255"}
+3      {"min": "127.0.0.0", "max": "127.0.255.255"}
+4      {"min": "127.0.0.0", "max": "127.0.255.255"}
+5      {"min": "127.0.0.0", "max": "127.0.255.255"}
+6      {"min": "127.0.0.0", "max": "127.0.255.255"}
+7      {"min": "127.0.0.0", "max": "127.0.255.255"}
+8      {"min": "127.0.0.0", "max": "127.0.255.255"}
+9      {"min": "127.0.0.0", "max": "127.0.255.255"}
+
 -- !sql --
 \N
 
diff --git 
a/regression-test/data/query_p0/sql_functions/ip_functions/test_ipv6_cidr_to_range_function.out
 
b/regression-test/data/query_p0/sql_functions/ip_functions/test_ipv6_cidr_to_range_function.out
index 6ffd4ee56f7..201af35987a 100644
--- 
a/regression-test/data/query_p0/sql_functions/ip_functions/test_ipv6_cidr_to_range_function.out
+++ 
b/regression-test/data/query_p0/sql_functions/ip_functions/test_ipv6_cidr_to_range_function.out
@@ -1,3 +1,4 @@
+-- This file is automatically generated. You should know what you did if you 
want to edit this
 -- !sql --
 1      ::      ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff
 2      2001:db8::      2001:db8:ffff:ffff:ffff:ffff:ffff:ffff
@@ -9,6 +10,37 @@
 8      \N      \N
 9      \N      \N
 
+-- !sql --
+1      {"min": "2001::", "max": "2001:ffff:ffff:ffff:ffff:ffff:ffff:ffff"}
+2      {"min": "2001::", "max": "2001:ffff:ffff:ffff:ffff:ffff:ffff:ffff"}
+3      {"min": "ffff::", "max": "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"}
+4      {"min": "2001::", "max": "2001:ffff:ffff:ffff:ffff:ffff:ffff:ffff"}
+5      {"min": "2001::", "max": "2001:ffff:ffff:ffff:ffff:ffff:ffff:ffff"}
+6      {"min": "::", "max": "0:ffff:ffff:ffff:ffff:ffff:ffff:ffff"}
+7      {"min": "ffff::", "max": "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"}
+8      \N
+9      {"min": "ffff::", "max": "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"}
+
+-- !sql --
+1      {"min": "::", "max": "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"}
+4      {"min": "3132:372e::", "max": "3132:372e:ffff:ffff:ffff:ffff:ffff:ffff"}
+5      {"min": "3132::", "max": "3132:ffff:ffff:ffff:ffff:ffff:ffff:ffff"}
+7      {"min": "b000::", "max": "bfff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"}
+8      {"min": "be00::", "max": "beff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"}
+9      \N
+
+-- !sql --
+0      {"min": "3132::", "max": "3132:ffff:ffff:ffff:ffff:ffff:ffff:ffff"}
+1      {"min": "3132::", "max": "3132:ffff:ffff:ffff:ffff:ffff:ffff:ffff"}
+2      {"min": "3132::", "max": "3132:ffff:ffff:ffff:ffff:ffff:ffff:ffff"}
+3      {"min": "3132::", "max": "3132:ffff:ffff:ffff:ffff:ffff:ffff:ffff"}
+4      {"min": "3132::", "max": "3132:ffff:ffff:ffff:ffff:ffff:ffff:ffff"}
+5      {"min": "3132::", "max": "3132:ffff:ffff:ffff:ffff:ffff:ffff:ffff"}
+6      {"min": "3132::", "max": "3132:ffff:ffff:ffff:ffff:ffff:ffff:ffff"}
+7      {"min": "3132::", "max": "3132:ffff:ffff:ffff:ffff:ffff:ffff:ffff"}
+8      {"min": "3132::", "max": "3132:ffff:ffff:ffff:ffff:ffff:ffff:ffff"}
+9      {"min": "3132::", "max": "3132:ffff:ffff:ffff:ffff:ffff:ffff:ffff"}
+
 -- !sql --
 1      ::      ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff
 2      2001:db8::      2001:db8:ffff:ffff:ffff:ffff:ffff:ffff
@@ -34,3 +66,4 @@
 
 -- !sql --
 {"min": "f000::", "max": "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"}
+
diff --git 
a/regression-test/suites/query_p0/sql_functions/ip_functions/test_ipv4_cidr_to_range.groovy
 
b/regression-test/suites/query_p0/sql_functions/ip_functions/test_ipv4_cidr_to_range.groovy
index eda4174de99..2c0333ecb56 100644
--- 
a/regression-test/suites/query_p0/sql_functions/ip_functions/test_ipv4_cidr_to_range.groovy
+++ 
b/regression-test/suites/query_p0/sql_functions/ip_functions/test_ipv4_cidr_to_range.groovy
@@ -46,6 +46,9 @@ suite("test_ipv4_cidr_to_range") {
         """
 
     qt_sql "select id, struct_element(ipv4_cidr_to_range(addr, cidr), 'min') 
as min_range, struct_element(ipv4_cidr_to_range(addr, cidr), 'max') as 
max_range from test_ipv4_cidr_to_range order by id"
+    qt_sql "select id, ipv4_cidr_to_range(addr, 16) from 
test_ipv4_cidr_to_range order by id;"
+    qt_sql """ select id, ipv4_cidr_to_range("127.0.0.1", cidr) from 
test_ipv4_cidr_to_range order by id;"""
+    qt_sql """ select number, ipv4_cidr_to_range("127.0.0.1", 16) from 
numbers("number"="10") order by number;"""
 
     sql """ DROP TABLE IF EXISTS test_ipv4_cidr_to_range """
 
diff --git 
a/regression-test/suites/query_p0/sql_functions/ip_functions/test_ipv6_cidr_to_range_function.groovy
 
b/regression-test/suites/query_p0/sql_functions/ip_functions/test_ipv6_cidr_to_range_function.groovy
index 9b19c22bc2c..d1fe0e2761a 100644
--- 
a/regression-test/suites/query_p0/sql_functions/ip_functions/test_ipv6_cidr_to_range_function.groovy
+++ 
b/regression-test/suites/query_p0/sql_functions/ip_functions/test_ipv6_cidr_to_range_function.groovy
@@ -48,6 +48,10 @@ suite("test_ipv6_cidr_to_range_function") {
         """
 
     qt_sql "select id, struct_element(ipv6_cidr_to_range(addr, cidr), 'min') 
as min_range, struct_element(ipv6_cidr_to_range(addr, cidr), 'max') as 
max_range from test_ipv6_cidr_to_range_function order by id"
+    qt_sql "select id, ipv6_cidr_to_range(addr, 16) from 
test_ipv6_cidr_to_range_function order by id;"
+    sql """ delete from test_ipv6_cidr_to_range_function where id in 
(2,3,6);"""
+    qt_sql """ select id, ipv6_cidr_to_range("127.0.0.1", cidr) from 
test_ipv6_cidr_to_range_function order by id;"""
+    qt_sql """ select number, ipv6_cidr_to_range("127.0.0.1", 16) from 
numbers("number"="10") order by number;"""
 
     sql """ DROP TABLE IF EXISTS test_ipv6_cidr_to_range_function """
     sql """ DROP TABLE IF EXISTS test_str_cidr_to_range_function """
@@ -87,4 +91,4 @@ suite("test_ipv6_cidr_to_range_function") {
     qt_sql "select 
ipv6_cidr_to_range(ipv6_string_to_num('ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff'),
 64)"
     qt_sql "select 
ipv6_cidr_to_range(ipv6_string_to_num('0000:0000:0000:0000:0000:0000:0000:0000'),
 8)"
     qt_sql "select 
ipv6_cidr_to_range(ipv6_string_to_num('ffff:0000:0000:0000:0000:0000:0000:0000'),
 4)"
-}
\ No newline at end of file
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to