Copilot commented on code in PR #52967:
URL: https://github.com/apache/doris/pull/52967#discussion_r2209518641
##########
be/test/util/frame_of_reference_coding_test.cpp:
##########
@@ -300,4 +356,27 @@ TEST_F(TestForCoding, accuracy_test) {
}
}
+TEST_F(TestForCoding, accracy2_test) {
Review Comment:
Typo in test name: `accracy2_test` should be `accuracy2_test`, consistent with the
existing `accuracy_test`.
```suggestion
TEST_F(TestForCoding, accuracy2_test) {
```
##########
be/benchmark/benchmark_bit_unpack.hpp:
##########
@@ -0,0 +1,111 @@
+// Licensed to the Apache Software Foundation (ASF) under one
Review Comment:
This header lacks an include guard or `#pragma once`. Consider adding one to
prevent multiple inclusion.
##########
be/src/util/frame_of_reference_coding.cpp:
##########
@@ -421,28 +422,133 @@ bool ForDecoder<T>::init() {
}
// todo(kks): improve this method by SIMD instructions
+
+template <typename T>
+template <typename U>
+void ForDecoder<T>::bit_unpack_optimize(const uint8_t* input, uint8_t in_num,
int bit_width,
+ T* output) {
+ static_assert(std::is_same<U, int64_t>::value || std::is_same<U,
__int128_t>::value,
+ "bit_unpack_optimize only supports U = int64_t or
__int128_t");
+ constexpr int u_size = sizeof(U); // Size of U
+ constexpr int u_size_shift = (u_size == 8) ? 3 : 4; // log2(u_size)
+ int valid_bit = 0; // How many valid bits
+ int need_bit = 0; // still need
+ size_t input_size = (in_num * bit_width + 7) >> 3; // input's size
+ int full_batch_size = (input_size >> u_size_shift)
+ << u_size_shift; // Adjust input_size to a
multiple of u_size
+ int tail_count = input_size & (u_size - 1); // The remainder of input_size
modulo u_size.
+ // The number of bits in input to adjust to multiples of 8 and thus more
+ int more_bit = (input_size << 3) - (in_num * bit_width);
+
+ // to ensure that only bit_width bits are valid
+ T output_mask;
+ if (bit_width >= static_cast<int>(sizeof(T) * 8)) {
+ output_mask = static_cast<T>(~T(0));
+ } else {
+ output_mask = static_cast<T>((static_cast<T>(1) << bit_width) - 1);
+ }
+
+ U s = 0; // Temporary buffer for bitstream: aggregates input bytes into a
large integer for unpacking
+
+ for (int i = 0; i < full_batch_size; i += u_size) {
+ s = 0;
+
+ if constexpr (u_size == 8) {
+ s = to_endian<std::endian::big>(*((int64_t*)(input + i)));
+ } else if constexpr (u_size == 16) {
+ s = to_endian<std::endian::big>(*((__int128_t*)(input + i)));
+ }
+
+ // Determine what the valid bits are based on u_size
+ valid_bit = u_size << 3;
+
+ // If input_size is exactly a multiple of 8, then need to remove the
last more_bit in the last loop.
+ if (tail_count == 0 && i == full_batch_size - u_size) {
+ valid_bit -= more_bit;
+ s >>= more_bit;
+ }
+
+ if (need_bit) {
+ // The last time we take away the high bit_width - need_bit,
+ // we need to make up the rest of the need_bit from the width.
+ // Use valid_bit - need_bit to compute high need_bit bits of s
+ // perform an AND operation to ensure that only need_bit bits are
valid
+ *output |= ((s >> (valid_bit - need_bit)) & ((static_cast<U>(1) <<
need_bit) - 1));
+ output++;
+ valid_bit -= need_bit;
+ }
+
+ int num = valid_bit / bit_width; // How many outputs can
be processed at a time
+ int remainder = valid_bit - num * bit_width; // How many bits are left
to store
+
+ // Starting with the highest valid bit, take out bit_width bits in
sequence
+ // perform an AND operation with output_mask to ensure that only
bit_width bits are valid
+ // (num-j-1) * bit_width used to calculate how many bits need to be
removed at the end
+ // But since there are still remainder bits that can't be processed,
need to add the remainder
+ for (int j = 0; j < num; j++) {
+ *output =
+ static_cast<T>((s >> (((num - j - 1) * bit_width) +
remainder)) & output_mask);
+ output++;
+ }
+
+ if (remainder) {
+ // Process the last remaining remainder bit.
+ // y = (s & ((static_cast<U>(1) << remainder) - 1)) extract the
last remainder bits.
+ // ouput = y << (bit_width - reaminder) Use the high bit_width -
remainder bit
Review Comment:
Fix typos in comment: change `ouput` to `output` and `reaminder` to
`remainder`.
```suggestion
// output = y << (bit_width - remainder) Use the high bit_width - remainder bit
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]