Copilot commented on code in PR #52967:
URL: https://github.com/apache/doris/pull/52967#discussion_r2209518641


##########
be/test/util/frame_of_reference_coding_test.cpp:
##########
@@ -300,4 +356,27 @@ TEST_F(TestForCoding, accuracy_test) {
     }
 }
 
+TEST_F(TestForCoding, accracy2_test) {

Review Comment:
   Typo in test name: `accracy2_test` should be `accuracy2_test` for 
consistency.
   ```suggestion
   TEST_F(TestForCoding, accuracy2_test) {
   ```



##########
be/benchmark/benchmark_bit_unpack.hpp:
##########
@@ -0,0 +1,111 @@
+// Licensed to the Apache Software Foundation (ASF) under one

Review Comment:
   This header lacks an include guard or `#pragma once`. Consider adding one to 
prevent multiple inclusion.



##########
be/src/util/frame_of_reference_coding.cpp:
##########
@@ -421,28 +422,133 @@ bool ForDecoder<T>::init() {
 }
 
 // todo(kks): improve this method by SIMD instructions
+
+template <typename T>
+template <typename U>
+void ForDecoder<T>::bit_unpack_optimize(const uint8_t* input, uint8_t in_num, 
int bit_width,
+                                        T* output) {
+    static_assert(std::is_same<U, int64_t>::value || std::is_same<U, 
__int128_t>::value,
+                  "bit_unpack_optimize only supports U = int64_t or 
__int128_t");
+    constexpr int u_size = sizeof(U);                   // Size of U
+    constexpr int u_size_shift = (u_size == 8) ? 3 : 4; // log2(u_size)
+    int valid_bit = 0;                                  // How many valid bits
+    int need_bit = 0;                                   // still need
+    size_t input_size = (in_num * bit_width + 7) >> 3;  // input's size
+    int full_batch_size = (input_size >> u_size_shift)
+                          << u_size_shift;      // Adjust input_size to a 
multiple of u_size
+    int tail_count = input_size & (u_size - 1); // The remainder of input_size 
modulo u_size.
+    // The number of bits in input to adjust to multiples of 8 and thus more
+    int more_bit = (input_size << 3) - (in_num * bit_width);
+
+    // to ensure that only bit_width bits are valid
+    T output_mask;
+    if (bit_width >= static_cast<int>(sizeof(T) * 8)) {
+        output_mask = static_cast<T>(~T(0));
+    } else {
+        output_mask = static_cast<T>((static_cast<T>(1) << bit_width) - 1);
+    }
+
+    U s = 0; // Temporary buffer for bitstream: aggregates input bytes into a 
large integer for unpacking
+
+    for (int i = 0; i < full_batch_size; i += u_size) {
+        s = 0;
+
+        if constexpr (u_size == 8) {
+            s = to_endian<std::endian::big>(*((int64_t*)(input + i)));
+        } else if constexpr (u_size == 16) {
+            s = to_endian<std::endian::big>(*((__int128_t*)(input + i)));
+        }
+
+        // Determine what the valid bits are based on u_size
+        valid_bit = u_size << 3;
+
+        // If input_size is exactly a multiple of 8, then need to remove the 
last more_bit in the last loop.
+        if (tail_count == 0 && i == full_batch_size - u_size) {
+            valid_bit -= more_bit;
+            s >>= more_bit;
+        }
+
+        if (need_bit) {
+            // The last time we take away the high bit_width - need_bit,
+            // we need to make up the rest of the need_bit from the width.
+            // Use valid_bit - need_bit to compute high need_bit bits of s
+            // perform an AND operation to ensure that only need_bit bits are 
valid
+            *output |= ((s >> (valid_bit - need_bit)) & ((static_cast<U>(1) << 
need_bit) - 1));
+            output++;
+            valid_bit -= need_bit;
+        }
+
+        int num = valid_bit / bit_width;             // How many outputs can 
be processed at a time
+        int remainder = valid_bit - num * bit_width; // How many bits are left 
to store
+
+        // Starting with the highest valid bit, take out bit_width bits in 
sequence
+        // perform an AND operation with output_mask to ensure that only 
bit_width bits are valid
+        // (num-j-1) * bit_width used to calculate how many bits need to be 
removed at the end
+        // But since there are still remainder bits that can't be processed, 
need to add the remainder
+        for (int j = 0; j < num; j++) {
+            *output =
+                    static_cast<T>((s >> (((num - j - 1) * bit_width) + 
remainder)) & output_mask);
+            output++;
+        }
+
+        if (remainder) {
+            // Process the last remaining remainder bit.
+            // y = (s & ((static_cast<U>(1) << remainder) - 1)) extract the 
last remainder bits.
+            // ouput = y << (bit_width - reaminder) Use the high bit_width - 
remainder bit

Review Comment:
   Fix typos in comment: change `ouput` to `output` and `reaminder` to 
`remainder`.
   ```suggestion
               // output = y << (bit_width - remainder) Use the high bit_width - remainder bit
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to