This is an automated email from the ASF dual-hosted git repository.

zclllyybb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 90ec8ade9ba [opt](function) speed up md5 with AVX2 batch path (#63484)
90ec8ade9ba is described below

commit 90ec8ade9bafa2d0b9e7e18832cec22c84e03039
Author: zclllyybb <[email protected]>
AuthorDate: Fri May 22 10:40:39 2026 +0800

    [opt](function) speed up md5 with AVX2 batch path (#63484)
    
    Root cause: md5/md5sum evaluated every row through Md5Digest and
    OpenSSL, which leaves the vectorized string function path dominated by
    per-row scalar digest setup and hex materialization.
    
    Fix: add an AVX2 multi-buffer MD5 helper with scalar fallback, expose a
    batch hex API, and route single-argument md5/md5sum over
    ColumnString/ColumnVarbinary through the batch path while keeping
    multi-argument md5sum and sm3 on the existing digest implementation.
    
    Tested with the following SQL:
    ```sql
    SET parallel_pipeline_task_num=1;
    SET enable_query_cache=false;
    SELECT SUM(ASCII(SUBSTRING(MD5(CAST(number AS STRING)), 1, 1)))
    FROM numbers("number" = "50000000");
    ```
    
    result:
    
    | version | times | avg | median |
    |---|---:|---:|---:|
    | upstream/master baseline | 8.59, 10.21, 9.52, 9.93, 8.85s | 9.42s | 9.52s |
    | after AVX2 batch | 2.83, 2.84, 2.82, 2.79, 2.82s | 2.82s | 2.82s |
---
 be/src/exprs/function/function_string_digest.cpp |  47 +++
 be/src/util/md5.cpp                              | 380 ++++++++++++++++++++++-
 be/src/util/md5.h                                |   5 +
 be/test/exprs/function/function_string_test.cpp  | 104 +++++++
 be/test/util/md5_test.cpp                        | 148 +++++++++
 5 files changed, 677 insertions(+), 7 deletions(-)

diff --git a/be/src/exprs/function/function_string_digest.cpp 
b/be/src/exprs/function/function_string_digest.cpp
index e81300479b3..af4820ba770 100644
--- a/be/src/exprs/function/function_string_digest.cpp
+++ b/be/src/exprs/function/function_string_digest.cpp
@@ -16,7 +16,10 @@
 // under the License.
 
 #include <cstddef>
+#include <cstring>
 #include <string_view>
+#include <type_traits>
+#include <vector>
 
 #include "common/status.h"
 #include "core/assert_cast.h"
@@ -98,6 +101,14 @@ private:
                         const std::vector<ColumnPtr>& argument_columns,
                         const std::vector<uint8_t>& is_const, 
ColumnString::Chars& res_data,
                         ColumnString::Offsets& res_offset) const {
+        if constexpr (std::is_same_v<Impl, MD5Sum>) {
+            if (argument_columns.size() == 1) {
+                const auto* col = assert_cast<const 
ColumnType*>(argument_columns[0].get());
+                vector_execute_single_md5(col, input_rows_count, is_const[0], 
res_data, res_offset);
+                return;
+            }
+        }
+
         using ObjectData = typename Impl::ObjectData;
         for (size_t i = 0; i < input_rows_count; ++i) {
             ObjectData digest;
@@ -114,6 +125,42 @@ private:
                                         i, res_data, res_offset);
         }
     }
+
+    template <typename ColumnType>
+    void vector_execute_single_md5(const ColumnType* col, size_t 
input_rows_count, bool is_const,
+                                   ColumnString::Chars& res_data,
+                                   ColumnString::Offsets& res_offset) const {
+        ColumnString::check_chars_length(input_rows_count * MD5_HEX_LENGTH, 
input_rows_count);
+        res_data.resize(input_rows_count * MD5_HEX_LENGTH);
+        for (size_t i = 0; i < input_rows_count; ++i) {
+            res_offset[i] = (i + 1) * MD5_HEX_LENGTH;
+        }
+        if (input_rows_count == 0) {
+            return;
+        }
+
+        if (is_const) {
+            StringRef data_ref = col->get_data_at(0);
+            const unsigned char* input = reinterpret_cast<const unsigned 
char*>(data_ref.data);
+            size_t length = data_ref.size;
+            char digest[MD5_HEX_LENGTH];
+            md5_hex_batch(&input, &length, digest, 1);
+            for (size_t i = 0; i < input_rows_count; ++i) {
+                std::memcpy(res_data.data() + i * MD5_HEX_LENGTH, digest, 
MD5_HEX_LENGTH);
+            }
+            return;
+        }
+
+        std::vector<const unsigned char*> inputs(input_rows_count);
+        std::vector<size_t> lengths(input_rows_count);
+        for (size_t i = 0; i < input_rows_count; ++i) {
+            StringRef data_ref = col->get_data_at(i);
+            inputs[i] = reinterpret_cast<const unsigned char*>(data_ref.data);
+            lengths[i] = data_ref.size;
+        }
+        md5_hex_batch(inputs.data(), lengths.data(), 
reinterpret_cast<char*>(res_data.data()),
+                      input_rows_count);
+    }
 };
 
 class FunctionStringDigestSHA1 : public IFunction {
diff --git a/be/src/util/md5.cpp b/be/src/util/md5.cpp
index bd3dda59dc8..b54e4eed8de 100644
--- a/be/src/util/md5.cpp
+++ b/be/src/util/md5.cpp
@@ -17,8 +17,359 @@
 
 #include "util/md5.h"
 
+#include <algorithm>
+#include <cstring>
+#include <vector>
+
+#ifdef __AVX2__
+#include <immintrin.h>
+#endif
+
+#include "exec/common/endian.h"
+
 namespace doris {
 
+namespace {
+
+constexpr uint32_t MD5_A0 = 0x67452301;
+constexpr uint32_t MD5_B0 = 0xefcdab89;
+constexpr uint32_t MD5_C0 = 0x98badcfe;
+constexpr uint32_t MD5_D0 = 0x10325476;
+constexpr unsigned char MD5_DUMMY_INPUT = 0;
+
+void md5_to_hex(const unsigned char* digest, char* out) {
+    static constexpr char DIGITS[] = "0123456789abcdef";
+    for (int i = 0; i < MD5_DIGEST_LENGTH; ++i) {
+        *out++ = DIGITS[digest[i] >> 4];
+        *out++ = DIGITS[digest[i] & 0x0F];
+    }
+}
+
+size_t md5_num_blocks(size_t len) {
+    return (len + 9 + 63) / 64;
+}
+
+size_t md5_pad_final_blocks(const unsigned char* data, size_t len, unsigned 
char* out) {
+    size_t full_blocks = len / 64;
+    size_t tail = len % 64;
+    size_t num_blocks = md5_num_blocks(len);
+    size_t final_count = num_blocks - full_blocks;
+
+    std::memset(out, 0, final_count * 64);
+    std::memcpy(out, data + full_blocks * 64, tail);
+    out[tail] = 0x80;
+    LittleEndian::Store64(out + final_count * 64 - 8, 
static_cast<uint64_t>(len) * 8);
+
+    return final_count;
+}
+
+#ifdef __AVX2__
+
+struct AVX2MD5Ops {
+    using Vec = __m256i;
+    static constexpr size_t LANES = 8;
+
+    static Vec add(Vec a, Vec b) { return _mm256_add_epi32(a, b); }
+
+    static Vec set1(uint32_t v) { return 
_mm256_set1_epi32(static_cast<int>(v)); }
+
+    static Vec loadu(const void* p) {
+        return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(p));
+    }
+
+    static void storeu(void* p, Vec v) { 
_mm256_storeu_si256(reinterpret_cast<__m256i*>(p), v); }
+
+    template <int N>
+    static Vec rotl(Vec x) {
+        return _mm256_or_si256(_mm256_slli_epi32(x, N), _mm256_srli_epi32(x, 
32 - N));
+    }
+
+    static Vec F(Vec b, Vec c, Vec d) {
+        return _mm256_xor_si256(d, _mm256_and_si256(b, _mm256_xor_si256(c, 
d)));
+    }
+
+    static Vec G(Vec b, Vec c, Vec d) {
+        return _mm256_xor_si256(c, _mm256_and_si256(d, _mm256_xor_si256(b, 
c)));
+    }
+
+    static Vec H(Vec b, Vec c, Vec d) { return _mm256_xor_si256(b, 
_mm256_xor_si256(c, d)); }
+
+    static Vec I(Vec b, Vec c, Vec d) {
+        return _mm256_xor_si256(c, _mm256_or_si256(b, _mm256_xor_si256(d, 
_mm256_set1_epi32(-1))));
+    }
+
+    static void gather_all_message_words(const unsigned char* const 
block_ptrs[], Vec msg[16]) {
+        for (int half = 0; half < 2; ++half) {
+            size_t off = half * 32;
+            Vec r0 = loadu(block_ptrs[0] + off);
+            Vec r1 = loadu(block_ptrs[1] + off);
+            Vec r2 = loadu(block_ptrs[2] + off);
+            Vec r3 = loadu(block_ptrs[3] + off);
+            Vec r4 = loadu(block_ptrs[4] + off);
+            Vec r5 = loadu(block_ptrs[5] + off);
+            Vec r6 = loadu(block_ptrs[6] + off);
+            Vec r7 = loadu(block_ptrs[7] + off);
+
+            Vec t0 = _mm256_unpacklo_epi32(r0, r1);
+            Vec t1 = _mm256_unpackhi_epi32(r0, r1);
+            Vec t2 = _mm256_unpacklo_epi32(r2, r3);
+            Vec t3 = _mm256_unpackhi_epi32(r2, r3);
+            Vec t4 = _mm256_unpacklo_epi32(r4, r5);
+            Vec t5 = _mm256_unpackhi_epi32(r4, r5);
+            Vec t6 = _mm256_unpacklo_epi32(r6, r7);
+            Vec t7 = _mm256_unpackhi_epi32(r6, r7);
+
+            Vec u0 = _mm256_unpacklo_epi64(t0, t2);
+            Vec u1 = _mm256_unpackhi_epi64(t0, t2);
+            Vec u2 = _mm256_unpacklo_epi64(t1, t3);
+            Vec u3 = _mm256_unpackhi_epi64(t1, t3);
+            Vec u4 = _mm256_unpacklo_epi64(t4, t6);
+            Vec u5 = _mm256_unpackhi_epi64(t4, t6);
+            Vec u6 = _mm256_unpacklo_epi64(t5, t7);
+            Vec u7 = _mm256_unpackhi_epi64(t5, t7);
+
+            size_t base = half * 8;
+            msg[base + 0] = _mm256_permute2x128_si256(u0, u4, 0x20);
+            msg[base + 4] = _mm256_permute2x128_si256(u0, u4, 0x31);
+            msg[base + 1] = _mm256_permute2x128_si256(u1, u5, 0x20);
+            msg[base + 5] = _mm256_permute2x128_si256(u1, u5, 0x31);
+            msg[base + 2] = _mm256_permute2x128_si256(u2, u6, 0x20);
+            msg[base + 6] = _mm256_permute2x128_si256(u2, u6, 0x31);
+            msg[base + 3] = _mm256_permute2x128_si256(u3, u7, 0x20);
+            msg[base + 7] = _mm256_permute2x128_si256(u3, u7, 0x31);
+        }
+    }
+};
+
+#define MD5_STEP_X2(func, w1, x1, y1, z1, w2, x2, y2, z2, g, s, ti) \
+    {                                                               \
+        Vec t1 = Ops::func(x1, y1, z1);                             \
+        Vec t2 = Ops::func(x2, y2, z2);                             \
+        t1 = Ops::add(t1, w1);                                      \
+        t2 = Ops::add(t2, w2);                                      \
+        Vec k = Ops::set1(ti);                                      \
+        t1 = Ops::add(t1, k);                                       \
+        t2 = Ops::add(t2, k);                                       \
+        t1 = Ops::add(t1, msg1[g]);                                 \
+        t2 = Ops::add(t2, msg2[g]);                                 \
+        (w1) = Ops::add(x1, Ops::template rotl<s>(t1));             \
+        (w2) = Ops::add(x2, Ops::template rotl<s>(t2));             \
+    }
+
+template <typename Ops>
+struct MD5X2State {
+    typename Ops::Vec a1, b1, c1, d1, a2, b2, c2, d2;
+};
+
+template <typename Ops>
+MD5X2State<Ops> md5_multi_buffer_block_x2(typename Ops::Vec a1, typename 
Ops::Vec b1,
+                                          typename Ops::Vec c1, typename 
Ops::Vec d1,
+                                          typename Ops::Vec a2, typename 
Ops::Vec b2,
+                                          typename Ops::Vec c2, typename 
Ops::Vec d2,
+                                          const typename Ops::Vec msg1[16],
+                                          const typename Ops::Vec msg2[16]) {
+    using Vec = typename Ops::Vec;
+    Vec aa1 = a1;
+    Vec bb1 = b1;
+    Vec cc1 = c1;
+    Vec dd1 = d1;
+    Vec aa2 = a2;
+    Vec bb2 = b2;
+    Vec cc2 = c2;
+    Vec dd2 = d2;
+
+    MD5_STEP_X2(F, a1, b1, c1, d1, a2, b2, c2, d2, 0, 7, 0xd76aa478)
+    MD5_STEP_X2(F, d1, a1, b1, c1, d2, a2, b2, c2, 1, 12, 0xe8c7b756)
+    MD5_STEP_X2(F, c1, d1, a1, b1, c2, d2, a2, b2, 2, 17, 0x242070db)
+    MD5_STEP_X2(F, b1, c1, d1, a1, b2, c2, d2, a2, 3, 22, 0xc1bdceee)
+    MD5_STEP_X2(F, a1, b1, c1, d1, a2, b2, c2, d2, 4, 7, 0xf57c0faf)
+    MD5_STEP_X2(F, d1, a1, b1, c1, d2, a2, b2, c2, 5, 12, 0x4787c62a)
+    MD5_STEP_X2(F, c1, d1, a1, b1, c2, d2, a2, b2, 6, 17, 0xa8304613)
+    MD5_STEP_X2(F, b1, c1, d1, a1, b2, c2, d2, a2, 7, 22, 0xfd469501)
+    MD5_STEP_X2(F, a1, b1, c1, d1, a2, b2, c2, d2, 8, 7, 0x698098d8)
+    MD5_STEP_X2(F, d1, a1, b1, c1, d2, a2, b2, c2, 9, 12, 0x8b44f7af)
+    MD5_STEP_X2(F, c1, d1, a1, b1, c2, d2, a2, b2, 10, 17, 0xffff5bb1)
+    MD5_STEP_X2(F, b1, c1, d1, a1, b2, c2, d2, a2, 11, 22, 0x895cd7be)
+    MD5_STEP_X2(F, a1, b1, c1, d1, a2, b2, c2, d2, 12, 7, 0x6b901122)
+    MD5_STEP_X2(F, d1, a1, b1, c1, d2, a2, b2, c2, 13, 12, 0xfd987193)
+    MD5_STEP_X2(F, c1, d1, a1, b1, c2, d2, a2, b2, 14, 17, 0xa679438e)
+    MD5_STEP_X2(F, b1, c1, d1, a1, b2, c2, d2, a2, 15, 22, 0x49b40821)
+
+    MD5_STEP_X2(G, a1, b1, c1, d1, a2, b2, c2, d2, 1, 5, 0xf61e2562)
+    MD5_STEP_X2(G, d1, a1, b1, c1, d2, a2, b2, c2, 6, 9, 0xc040b340)
+    MD5_STEP_X2(G, c1, d1, a1, b1, c2, d2, a2, b2, 11, 14, 0x265e5a51)
+    MD5_STEP_X2(G, b1, c1, d1, a1, b2, c2, d2, a2, 0, 20, 0xe9b6c7aa)
+    MD5_STEP_X2(G, a1, b1, c1, d1, a2, b2, c2, d2, 5, 5, 0xd62f105d)
+    MD5_STEP_X2(G, d1, a1, b1, c1, d2, a2, b2, c2, 10, 9, 0x02441453)
+    MD5_STEP_X2(G, c1, d1, a1, b1, c2, d2, a2, b2, 15, 14, 0xd8a1e681)
+    MD5_STEP_X2(G, b1, c1, d1, a1, b2, c2, d2, a2, 4, 20, 0xe7d3fbc8)
+    MD5_STEP_X2(G, a1, b1, c1, d1, a2, b2, c2, d2, 9, 5, 0x21e1cde6)
+    MD5_STEP_X2(G, d1, a1, b1, c1, d2, a2, b2, c2, 14, 9, 0xc33707d6)
+    MD5_STEP_X2(G, c1, d1, a1, b1, c2, d2, a2, b2, 3, 14, 0xf4d50d87)
+    MD5_STEP_X2(G, b1, c1, d1, a1, b2, c2, d2, a2, 8, 20, 0x455a14ed)
+    MD5_STEP_X2(G, a1, b1, c1, d1, a2, b2, c2, d2, 13, 5, 0xa9e3e905)
+    MD5_STEP_X2(G, d1, a1, b1, c1, d2, a2, b2, c2, 2, 9, 0xfcefa3f8)
+    MD5_STEP_X2(G, c1, d1, a1, b1, c2, d2, a2, b2, 7, 14, 0x676f02d9)
+    MD5_STEP_X2(G, b1, c1, d1, a1, b2, c2, d2, a2, 12, 20, 0x8d2a4c8a)
+
+    MD5_STEP_X2(H, a1, b1, c1, d1, a2, b2, c2, d2, 5, 4, 0xfffa3942)
+    MD5_STEP_X2(H, d1, a1, b1, c1, d2, a2, b2, c2, 8, 11, 0x8771f681)
+    MD5_STEP_X2(H, c1, d1, a1, b1, c2, d2, a2, b2, 11, 16, 0x6d9d6122)
+    MD5_STEP_X2(H, b1, c1, d1, a1, b2, c2, d2, a2, 14, 23, 0xfde5380c)
+    MD5_STEP_X2(H, a1, b1, c1, d1, a2, b2, c2, d2, 1, 4, 0xa4beea44)
+    MD5_STEP_X2(H, d1, a1, b1, c1, d2, a2, b2, c2, 4, 11, 0x4bdecfa9)
+    MD5_STEP_X2(H, c1, d1, a1, b1, c2, d2, a2, b2, 7, 16, 0xf6bb4b60)
+    MD5_STEP_X2(H, b1, c1, d1, a1, b2, c2, d2, a2, 10, 23, 0xbebfbc70)
+    MD5_STEP_X2(H, a1, b1, c1, d1, a2, b2, c2, d2, 13, 4, 0x289b7ec6)
+    MD5_STEP_X2(H, d1, a1, b1, c1, d2, a2, b2, c2, 0, 11, 0xeaa127fa)
+    MD5_STEP_X2(H, c1, d1, a1, b1, c2, d2, a2, b2, 3, 16, 0xd4ef3085)
+    MD5_STEP_X2(H, b1, c1, d1, a1, b2, c2, d2, a2, 6, 23, 0x04881d05)
+    MD5_STEP_X2(H, a1, b1, c1, d1, a2, b2, c2, d2, 9, 4, 0xd9d4d039)
+    MD5_STEP_X2(H, d1, a1, b1, c1, d2, a2, b2, c2, 12, 11, 0xe6db99e5)
+    MD5_STEP_X2(H, c1, d1, a1, b1, c2, d2, a2, b2, 15, 16, 0x1fa27cf8)
+    MD5_STEP_X2(H, b1, c1, d1, a1, b2, c2, d2, a2, 2, 23, 0xc4ac5665)
+
+    MD5_STEP_X2(I, a1, b1, c1, d1, a2, b2, c2, d2, 0, 6, 0xf4292244)
+    MD5_STEP_X2(I, d1, a1, b1, c1, d2, a2, b2, c2, 7, 10, 0x432aff97)
+    MD5_STEP_X2(I, c1, d1, a1, b1, c2, d2, a2, b2, 14, 15, 0xab9423a7)
+    MD5_STEP_X2(I, b1, c1, d1, a1, b2, c2, d2, a2, 5, 21, 0xfc93a039)
+    MD5_STEP_X2(I, a1, b1, c1, d1, a2, b2, c2, d2, 12, 6, 0x655b59c3)
+    MD5_STEP_X2(I, d1, a1, b1, c1, d2, a2, b2, c2, 3, 10, 0x8f0ccc92)
+    MD5_STEP_X2(I, c1, d1, a1, b1, c2, d2, a2, b2, 10, 15, 0xffeff47d)
+    MD5_STEP_X2(I, b1, c1, d1, a1, b2, c2, d2, a2, 1, 21, 0x85845dd1)
+    MD5_STEP_X2(I, a1, b1, c1, d1, a2, b2, c2, d2, 8, 6, 0x6fa87e4f)
+    MD5_STEP_X2(I, d1, a1, b1, c1, d2, a2, b2, c2, 15, 10, 0xfe2ce6e0)
+    MD5_STEP_X2(I, c1, d1, a1, b1, c2, d2, a2, b2, 6, 15, 0xa3014314)
+    MD5_STEP_X2(I, b1, c1, d1, a1, b2, c2, d2, a2, 13, 21, 0x4e0811a1)
+    MD5_STEP_X2(I, a1, b1, c1, d1, a2, b2, c2, d2, 4, 6, 0xf7537e82)
+    MD5_STEP_X2(I, d1, a1, b1, c1, d2, a2, b2, c2, 11, 10, 0xbd3af235)
+    MD5_STEP_X2(I, c1, d1, a1, b1, c2, d2, a2, b2, 2, 15, 0x2ad7d2bb)
+    MD5_STEP_X2(I, b1, c1, d1, a1, b2, c2, d2, a2, 9, 21, 0xeb86d391)
+
+    return {Ops::add(a1, aa1), Ops::add(b1, bb1), Ops::add(c1, cc1), 
Ops::add(d1, dd1),
+            Ops::add(a2, aa2), Ops::add(b2, bb2), Ops::add(c2, cc2), 
Ops::add(d2, dd2)};
+}
+
+#undef MD5_STEP_X2
+
+template <typename Ops>
+uint32_t extract_lane(typename Ops::Vec v, size_t lane) {
+    alignas(32) uint32_t values[Ops::LANES];
+    Ops::storeu(values, v);
+    return values[lane];
+}
+
+template <typename Ops>
+void md5_multi_buffer_compute(const unsigned char* const inputs[], const 
size_t lengths[],
+                              unsigned char* outputs, size_t count) {
+    constexpr size_t N = Ops::LANES;
+    using Vec = typename Ops::Vec;
+    size_t count1 = std::min(count, N);
+    size_t count2 = count > N ? count - N : 0;
+
+    size_t num_blocks[2 * N];
+    size_t max_blocks = 0;
+    for (size_t i = 0; i < count; ++i) {
+        num_blocks[i] = md5_num_blocks(lengths[i]);
+        max_blocks = std::max(max_blocks, num_blocks[i]);
+    }
+    for (size_t i = count; i < 2 * N; ++i) {
+        num_blocks[i] = 1;
+    }
+
+    alignas(32) unsigned char final_buf[2 * N][128];
+    size_t final_block_start[2 * N];
+    size_t final_block_count[2 * N];
+    for (size_t i = 0; i < count; ++i) {
+        final_block_start[i] = lengths[i] / 64;
+        final_block_count[i] = md5_pad_final_blocks(inputs[i], lengths[i], 
final_buf[i]);
+    }
+    for (size_t i = count; i < 2 * N; ++i) {
+        final_block_start[i] = 0;
+        final_block_count[i] = md5_pad_final_blocks(&MD5_DUMMY_INPUT, 0, 
final_buf[i]);
+    }
+
+    Vec a1 = Ops::set1(MD5_A0);
+    Vec b1 = Ops::set1(MD5_B0);
+    Vec c1 = Ops::set1(MD5_C0);
+    Vec d1 = Ops::set1(MD5_D0);
+    Vec a2 = Ops::set1(MD5_A0);
+    Vec b2 = Ops::set1(MD5_B0);
+    Vec c2 = Ops::set1(MD5_C0);
+    Vec d2 = Ops::set1(MD5_D0);
+
+    for (size_t block = 0; block < max_blocks; ++block) {
+        const unsigned char* block_ptrs[2 * N];
+        for (size_t i = 0; i < 2 * N; ++i) {
+            if (block < final_block_start[i]) {
+                block_ptrs[i] = inputs[i] + block * 64;
+            } else {
+                size_t final_index = block - final_block_start[i];
+                block_ptrs[i] = final_index < final_block_count[i] ? 
final_buf[i] + final_index * 64
+                                                                   : 
final_buf[i];
+            }
+        }
+
+        Vec msg1[16];
+        Vec msg2[16];
+        Ops::gather_all_message_words(block_ptrs, msg1);
+        Ops::gather_all_message_words(block_ptrs + N, msg2);
+
+        auto st = md5_multi_buffer_block_x2<Ops>(a1, b1, c1, d1, a2, b2, c2, 
d2, msg1, msg2);
+        a1 = st.a1;
+        b1 = st.b1;
+        c1 = st.c1;
+        d1 = st.d1;
+        a2 = st.a2;
+        b2 = st.b2;
+        c2 = st.c2;
+        d2 = st.d2;
+
+        for (size_t lane = 0; lane < count1; ++lane) {
+            if (block + 1 == num_blocks[lane]) {
+                unsigned char* out = outputs + lane * MD5_DIGEST_LENGTH;
+                LittleEndian::Store32(out, extract_lane<Ops>(a1, lane));
+                LittleEndian::Store32(out + 4, extract_lane<Ops>(b1, lane));
+                LittleEndian::Store32(out + 8, extract_lane<Ops>(c1, lane));
+                LittleEndian::Store32(out + 12, extract_lane<Ops>(d1, lane));
+            }
+        }
+        for (size_t lane = 0; lane < count2; ++lane) {
+            if (block + 1 == num_blocks[N + lane]) {
+                unsigned char* out = outputs + (N + lane) * MD5_DIGEST_LENGTH;
+                LittleEndian::Store32(out, extract_lane<Ops>(a2, lane));
+                LittleEndian::Store32(out + 4, extract_lane<Ops>(b2, lane));
+                LittleEndian::Store32(out + 8, extract_lane<Ops>(c2, lane));
+                LittleEndian::Store32(out + 12, extract_lane<Ops>(d2, lane));
+            }
+        }
+    }
+}
+
+void md5_binary_batch_avx2(const unsigned char* const inputs[], const size_t 
lengths[],
+                           unsigned char* outputs, size_t count) {
+    constexpr size_t BATCH = 2 * AVX2MD5Ops::LANES;
+    for (size_t base = 0; base < count; base += BATCH) {
+        size_t batch = std::min(BATCH, count - base);
+        const unsigned char* batch_inputs[BATCH];
+        size_t batch_lengths[BATCH];
+        for (size_t i = 0; i < batch; ++i) {
+            batch_inputs[i] = lengths[base + i] == 0 ? &MD5_DUMMY_INPUT : 
inputs[base + i];
+            batch_lengths[i] = lengths[base + i];
+        }
+        for (size_t i = batch; i < BATCH; ++i) {
+            batch_inputs[i] = &MD5_DUMMY_INPUT;
+            batch_lengths[i] = 0;
+        }
+        md5_multi_buffer_compute<AVX2MD5Ops>(batch_inputs, batch_lengths,
+                                             outputs + base * 
MD5_DIGEST_LENGTH, batch);
+    }
+}
+
+#endif
+
+} // namespace
+
 Md5Digest::Md5Digest() {
     MD5_Init(&_md5_ctx);
 }
@@ -31,15 +382,30 @@ void Md5Digest::digest() {
     unsigned char buf[MD5_DIGEST_LENGTH];
     MD5_Final(buf, &_md5_ctx);
 
-    char hex_buf[2 * MD5_DIGEST_LENGTH];
+    char hex_buf[MD5_HEX_LENGTH];
+    md5_to_hex(buf, hex_buf);
+    _hex.assign(hex_buf, MD5_HEX_LENGTH);
+}
 
-    static char dig_vec_lower[] = "0123456789abcdef";
-    char* to = hex_buf;
-    for (int i = 0; i < MD5_DIGEST_LENGTH; ++i) {
-        *to++ = dig_vec_lower[buf[i] >> 4];
-        *to++ = dig_vec_lower[buf[i] & 0x0F];
+void md5_hex_batch(const unsigned char* const inputs[], const size_t 
lengths[], char* outputs,
+                   size_t count) {
+    if (count == 0) {
+        return;
+    }
+
+#ifdef __AVX2__
+    std::vector<unsigned char> digests(count * MD5_DIGEST_LENGTH);
+    md5_binary_batch_avx2(inputs, lengths, digests.data(), count);
+    for (size_t i = 0; i < count; ++i) {
+        md5_to_hex(digests.data() + i * MD5_DIGEST_LENGTH, outputs + i * 
MD5_HEX_LENGTH);
+    }
+#else
+    for (size_t i = 0; i < count; ++i) {
+        unsigned char digest[MD5_DIGEST_LENGTH];
+        MD5(lengths[i] == 0 ? &MD5_DUMMY_INPUT : inputs[i], lengths[i], 
digest);
+        md5_to_hex(digest, outputs + i * MD5_HEX_LENGTH);
     }
-    _hex.assign(hex_buf, 2 * MD5_DIGEST_LENGTH);
+#endif
 }
 
 } // namespace doris
diff --git a/be/src/util/md5.h b/be/src/util/md5.h
index b6dca97c2b9..ac82a68dceb 100644
--- a/be/src/util/md5.h
+++ b/be/src/util/md5.h
@@ -24,6 +24,8 @@
 
 namespace doris {
 
+static constexpr size_t MD5_HEX_LENGTH = 2 * MD5_DIGEST_LENGTH;
+
 class Md5Digest {
 public:
     Md5Digest();
@@ -38,4 +40,7 @@ private:
     std::string _hex;
 };
 
+void md5_hex_batch(const unsigned char* const inputs[], const size_t 
lengths[], char* outputs,
+                   size_t count);
+
 } // namespace doris
diff --git a/be/test/exprs/function/function_string_test.cpp 
b/be/test/exprs/function/function_string_test.cpp
index 37231fd354c..53df245904c 100644
--- a/be/test/exprs/function/function_string_test.cpp
+++ b/be/test/exprs/function/function_string_test.cpp
@@ -27,10 +27,56 @@
 #include "core/types.h"
 #include "exprs/function/function_test_util.h"
 #include "util/encryption_util.h"
+#include "util/md5.h"
 
 namespace doris {
 using namespace ut_type;
 
+namespace {
+
+std::string md5_hex_for_test(const std::string& input) {
+    Md5Digest digest;
+    digest.update(input.data(), input.size());
+    digest.digest();
+    return digest.hex();
+}
+
+std::string make_md5_test_input(size_t length, size_t seed) {
+    std::string input(length, '\0');
+    for (size_t i = 0; i < length; ++i) {
+        input[i] = static_cast<char>('!' + ((i * 17 + seed * 31) % 94));
+    }
+    return input;
+}
+
+std::string make_md5_binary_input(size_t length, size_t seed) {
+    std::string input(length, '\0');
+    for (size_t i = 0; i < length; ++i) {
+        input[i] = static_cast<char>((i * 37 + seed * 13) & 0xff);
+    }
+    return input;
+}
+
+DataSet make_md5_string_dataset(const std::vector<std::string>& inputs) {
+    DataSet data_set;
+    data_set.reserve(inputs.size());
+    for (const auto& input : inputs) {
+        data_set.push_back({{input}, {md5_hex_for_test(input)}});
+    }
+    return data_set;
+}
+
+DataSet make_md5_varbinary_dataset(const std::vector<std::string>& inputs) {
+    DataSet data_set;
+    data_set.reserve(inputs.size());
+    for (const auto& input : inputs) {
+        data_set.push_back({{VARBINARY(input)}, {md5_hex_for_test(input)}});
+    }
+    return data_set;
+}
+
+} // namespace
+
 TEST(function_string_test, function_string_substr_test) {
     std::string func_name = "substr";
 
@@ -1970,6 +2016,22 @@ TEST(function_string_test, function_md5sum_test) {
         check_function_all_arg_comb<DataTypeString, true>(func_name, 
input_types, data_set);
     }
 
+    {
+        InputTypeSet input_types = {PrimitiveType::TYPE_VARCHAR};
+        std::vector<std::string> inputs;
+        const std::vector<size_t> lengths = {0,  1,   15,  16,  31,  32,  55,  
56,  57,  63,  64,
+                                             65, 119, 120, 121, 127, 128, 129, 
255, 256, 1024};
+        inputs.reserve(lengths.size());
+        for (size_t i = 0; i < lengths.size(); ++i) {
+            inputs.push_back(make_md5_test_input(lengths[i], i));
+        }
+
+        DataSet data_set = make_md5_string_dataset(inputs);
+        check_function_all_arg_comb<DataTypeString, true>(func_name, 
input_types, data_set);
+        check_function_all_arg_comb<DataTypeString, true>(std::string("md5"), 
input_types,
+                                                          data_set);
+    }
+
     {
         InputTypeSet input_types = {PrimitiveType::TYPE_VARCHAR, 
PrimitiveType::TYPE_VARCHAR};
         DataSet data_set = {{{std::string("asd"), std::string("你好")},
@@ -1988,6 +2050,10 @@ TEST(function_string_test, function_md5sum_test) {
                                     PrimitiveType::TYPE_VARCHAR};
         DataSet data_set = {{{std::string("a"), std::string("sd"), 
std::string("你好")},
                              
{std::string("a38c15675555017e6b8ea042f2eb24f5")}},
+                            {{std::string("a"), std::string(""), 
std::string("b")},
+                             
{std::string("187ef4436122d1cc2f40dc2b92f0eba0")}},
+                            {{std::string(""), std::string("abc"), 
std::string("")},
+                             
{std::string("900150983cd24fb0d6963f7d28e17f72")}},
                             {{std::string(""), std::string(""), 
std::string("")},
                              
{std::string("d41d8cd98f00b204e9800998ecf8427e")}},
                             {{std::string("HEL"), std::string("LO,!"), 
std::string("^%")},
@@ -2012,6 +2078,28 @@ TEST(function_string_test, function_md5sum_test) {
         check_function_all_arg_comb<DataTypeString, true>(func_name, 
input_types, data_set);
     }
 
+    {
+        InputTypeSet input_types = {PrimitiveType::TYPE_VARBINARY};
+        std::string all_bytes;
+        all_bytes.reserve(256);
+        for (int i = 0; i < 256; ++i) {
+            all_bytes.push_back(static_cast<char>(i));
+        }
+
+        std::vector<std::string> inputs = {
+                std::string(1, '\0'),          std::string("a\0b", 3),
+                std::string("\0\0\0\0", 4),    all_bytes,
+                all_bytes + all_bytes,         make_md5_binary_input(55, 1),
+                make_md5_binary_input(56, 2),  make_md5_binary_input(57, 3),
+                make_md5_binary_input(64, 4),  make_md5_binary_input(65, 5),
+                make_md5_binary_input(128, 6), make_md5_binary_input(129, 7)};
+
+        DataSet data_set = make_md5_varbinary_dataset(inputs);
+        check_function_all_arg_comb<DataTypeString, true>(func_name, 
input_types, data_set);
+        check_function_all_arg_comb<DataTypeString, true>(std::string("md5"), 
input_types,
+                                                          data_set);
+    }
+
     {
         InputTypeSet input_types = {PrimitiveType::TYPE_VARBINARY, 
PrimitiveType::TYPE_VARBINARY};
         DataSet data_set = {{{VARBINARY("asd"), VARBINARY("你好")},
@@ -2025,11 +2113,27 @@ TEST(function_string_test, function_md5sum_test) {
         check_function_all_arg_comb<DataTypeString, true>(func_name, 
input_types, data_set);
     }
 
+    {
+        InputTypeSet input_types = {PrimitiveType::TYPE_VARBINARY, 
PrimitiveType::TYPE_VARBINARY};
+        std::string left = std::string("a\0b", 3);
+        std::string right = std::string("\0c", 2);
+        std::string empty;
+        std::string bytes = make_md5_binary_input(64, 8);
+        DataSet data_set = {{{VARBINARY(left), VARBINARY(right)}, 
{md5_hex_for_test(left + right)}},
+                            {{VARBINARY(empty), VARBINARY(bytes)}, 
{md5_hex_for_test(bytes)}}};
+
+        check_function_all_arg_comb<DataTypeString, true>(func_name, 
input_types, data_set);
+    }
+
     {
         InputTypeSet input_types = {PrimitiveType::TYPE_VARBINARY, 
PrimitiveType::TYPE_VARBINARY,
                                     PrimitiveType::TYPE_VARBINARY};
         DataSet data_set = {{{VARBINARY("a"), VARBINARY("sd"), 
VARBINARY("你好")},
                              
{std::string("a38c15675555017e6b8ea042f2eb24f5")}},
+                            {{VARBINARY("a"), VARBINARY(""), VARBINARY("b")},
+                             
{std::string("187ef4436122d1cc2f40dc2b92f0eba0")}},
+                            {{VARBINARY(""), VARBINARY("abc"), VARBINARY("")},
+                             
{std::string("900150983cd24fb0d6963f7d28e17f72")}},
                             {{VARBINARY(""), VARBINARY(""), VARBINARY("")},
                              
{std::string("d41d8cd98f00b204e9800998ecf8427e")}},
                             {{VARBINARY("HEL"), VARBINARY("LO,!"), 
VARBINARY("^%")},
diff --git a/be/test/util/md5_test.cpp b/be/test/util/md5_test.cpp
index f73db6efbf1..15591626c1c 100644
--- a/be/test/util/md5_test.cpp
+++ b/be/test/util/md5_test.cpp
@@ -20,9 +20,76 @@
 #include <gtest/gtest-message.h>
 #include <gtest/gtest-test-part.h>
 
+#include <algorithm>
+#include <array>
+#include <cstdint>
+#include <string>
+#include <utility>
+#include <vector>
+
 #include "gtest/gtest_pred_impl.h"
 
 namespace doris {
+namespace {
+
+std::string scalar_md5_hex(const std::string& input) {
+    Md5Digest digest;
+    digest.update(input.data(), input.size());
+    digest.digest();
+    return digest.hex();
+}
+
+std::string make_patterned_input(size_t length, size_t seed) {
+    std::string input(length, '\0');
+    for (size_t i = 0; i < length; ++i) {
+        input[i] = static_cast<char>('!' + ((i * 17 + seed * 31) % 94));
+    }
+    return input;
+}
+
+std::string make_binary_input(size_t length, size_t seed) {
+    std::string input(length, '\0');
+    for (size_t i = 0; i < length; ++i) {
+        input[i] = static_cast<char>((i * 37 + seed * 13) & 0xff);
+    }
+    return input;
+}
+
+bool is_lower_hex_digest(const std::string& value) {
+    return value.size() == MD5_HEX_LENGTH && std::all_of(value.begin(), 
value.end(), [](char c) {
+               return ('0' <= c && c <= '9') || ('a' <= c && c <= 'f');
+           });
+}
+
+void expect_batch_matches_scalar(const std::vector<std::string>& inputs) {
+    if (inputs.empty()) {
+        std::array<char, MD5_HEX_LENGTH> sentinel;
+        sentinel.fill('x');
+        md5_hex_batch(nullptr, nullptr, sentinel.data(), 0);
+        EXPECT_TRUE(std::all_of(sentinel.begin(), sentinel.end(), [](char c) { 
return c == 'x'; }));
+        return;
+    }
+
+    std::vector<const unsigned char*> data(inputs.size());
+    std::vector<size_t> lengths(inputs.size());
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        data[i] = reinterpret_cast<const unsigned char*>(inputs[i].data());
+        lengths[i] = inputs[i].size();
+    }
+
+    std::vector<char> output(inputs.size() * MD5_HEX_LENGTH);
+    md5_hex_batch(data.data(), lengths.data(), output.data(), inputs.size());
+
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        const std::string actual(output.data() + i * MD5_HEX_LENGTH, 
MD5_HEX_LENGTH);
+        EXPECT_TRUE(is_lower_hex_digest(actual))
+                << "input index " << i << ", length " << inputs[i].size();
+        EXPECT_EQ(scalar_md5_hex(inputs[i]), actual)
+                << "input index " << i << ", length " << inputs[i].size();
+    }
+}
+
+} // namespace
 
 class Md5Test : public testing::Test {
 public:
@@ -43,4 +110,85 @@ TEST_F(Md5Test, normal) {
     EXPECT_STREQ("7ac66c0f148de9519b8bd264312c4d64", digest.hex().c_str());
 }
 
+TEST_F(Md5Test, batch_known_vectors) {  // Checks md5_hex_batch against hard-coded expected digests instead of another implementation.
+    const std::vector<std::pair<std::string, std::string>> cases = {  // {input, expected lowercase hex digest}; the inputs are the RFC 1321 appendix A.5 test-suite strings.
+            {"", "d41d8cd98f00b204e9800998ecf8427e"},
+            {"a", "0cc175b9c0f1b6a831c399e269772661"},
+            {"abc", "900150983cd24fb0d6963f7d28e17f72"},
+            {"message digest", "f96b697d7cb7938d525a2f31aaf161d0"},
+            {"abcdefghijklmnopqrstuvwxyz", "c3fcd3d76192e4007dfb496cca67e13b"},
+            {"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789",
+             "d174ab98d277d9f5a5611c2c9f419d9f"},
+            
{"1234567890123456789012345678901234567890123456789012345678901234567890"
+             "1234567890",
+             "57edf4a22be3c955ac49da2e2107b67a"}};
+
+    std::vector<std::string> inputs;  // Owning copies of the case inputs; data[] below points into these strings.
+    std::vector<const unsigned char*> data(cases.size());
+    std::vector<size_t> lengths(cases.size());
+    inputs.reserve(cases.size());
+    for (const auto& entry : cases) {
+        inputs.push_back(entry.first);
+    }
+    for (size_t i = 0; i < inputs.size(); ++i) {  // Pointers are taken only after `inputs` is fully populated.
+        data[i] = reinterpret_cast<const unsigned char*>(inputs[i].data());
+        lengths[i] = inputs[i].size();
+    }
+
+    std::vector<char> output(inputs.size() * MD5_HEX_LENGTH);  // One MD5_HEX_LENGTH-sized slot per case.
+    md5_hex_batch(data.data(), lengths.data(), output.data(), inputs.size());
+
+    for (size_t i = 0; i < cases.size(); ++i) {
+        const std::string actual(output.data() + i * MD5_HEX_LENGTH, 
MD5_HEX_LENGTH);
+        EXPECT_TRUE(is_lower_hex_digest(actual)) << "input index " << i;
+        EXPECT_EQ(cases[i].second, actual) << "input index " << i;  // Slot i must equal the expected digest of case i.
+    }
+}
+
+TEST_F(Md5Test, batch_block_padding_boundaries) {  // One batch of inputs with the lengths below; NOTE(review): they appear chosen around MD5's 64-byte block size and the 56-byte point where padding needs an extra block (RFC 1321) — confirm intent.
+    const std::vector<size_t> lengths = {0,  1,   15,  16,  31,  32,  55,  56, 
 57,  63,  64,
+                                         65, 119, 120, 121, 127, 128, 129, 
255, 256, 4096};
+    std::vector<std::string> inputs;
+    inputs.reserve(lengths.size());
+    for (size_t i = 0; i < lengths.size(); ++i) {
+        inputs.push_back(make_patterned_input(lengths[i], i));  // The position in `lengths` is used as the pattern seed.
+    }
+
+    expect_batch_matches_scalar(inputs);  // All lengths are submitted together in a single batch call.
+}
+
+TEST_F(Md5Test, batch_size_boundaries) {  // Varies the number of inputs per batch; NOTE(review): counts straddle multiples of 8 and 16, presumably the AVX2 lane grouping in md5.cpp — confirm.
+    const std::vector<size_t> counts = {0, 1, 7, 8, 9, 15, 16, 17, 31, 32, 33, 
47, 48, 49};
+    const std::vector<size_t> boundary_lengths = {0, 1, 55, 56, 57, 63, 64, 
65, 119, 120, 121};
+
+    for (size_t count : counts) {  // count == 0 exercises the empty-input branch of expect_batch_matches_scalar.
+        SCOPED_TRACE("count=" + std::to_string(count));  // Labels any failure reported inside the helper with the batch size.
+        std::vector<std::string> inputs;
+        inputs.reserve(count);
+        for (size_t i = 0; i < count; ++i) {  // Every third input cycles through boundary_lengths; the others get a computed length in [0, 190).
+            const size_t length = (i % 3 == 0) ? boundary_lengths[(i / 3) % 
boundary_lengths.size()]
+                                               : ((i * 37 + count * 11) % 190);
+            inputs.push_back(make_patterned_input(length, i + count));  // Seed depends on both index and batch size.
+        }
+        expect_batch_matches_scalar(inputs);
+    }
+}
+
+TEST_F(Md5Test, batch_binary_payloads) {  // Batch/scalar comparison on non-text inputs: embedded NULs, every byte value, and make_binary_input patterns.
+    std::string all_bytes;  // Filled below with the 256 byte values 0x00..0xff in ascending order.
+    all_bytes.reserve(256);
+    for (int i = 0; i < 256; ++i) {
+        all_bytes.push_back(static_cast<char>(i));
+    }
+
+    std::vector<std::string> inputs = {std::string(1, '\0'),       
std::string("a\0b", 3),
+                                       std::string("\0\0\0\0", 4), all_bytes,
+                                       all_bytes + all_bytes,      
make_binary_input(55, 1),
+                                       make_binary_input(56, 2),   
make_binary_input(57, 3),
+                                       make_binary_input(64, 4),   
make_binary_input(65, 5),
+                                       make_binary_input(128, 6),  
make_binary_input(129, 7)};
+
+    expect_batch_matches_scalar(inputs);  // The explicit-length std::string constructors above keep the embedded NUL bytes in the inputs.
+}
+
 } // namespace doris


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to