michalursa commented on a change in pull request #10290:
URL: https://github.com/apache/arrow/pull/10290#discussion_r688212687



##########
File path: cpp/src/arrow/compute/exec/key_compare.cc
##########
@@ -17,250 +17,398 @@
 
 #include "arrow/compute/exec/key_compare.h"
 
+#include <immintrin.h>
+#include <memory.h>
+
 #include <algorithm>
 #include <cstdint>
 
 #include "arrow/compute/exec/util.h"
+#include "arrow/util/bit_util.h"
 #include "arrow/util/ubsan.h"
 
 namespace arrow {
 namespace compute {
 
-void KeyCompare::CompareRows(uint32_t num_rows_to_compare,
-                             const uint16_t* sel_left_maybe_null,
-                             const uint32_t* left_to_right_map,
-                             KeyEncoder::KeyEncoderContext* ctx, uint32_t* 
out_num_rows,
-                             uint16_t* out_sel_left_maybe_same,
-                             const KeyEncoder::KeyRowArray& rows_left,
-                             const KeyEncoder::KeyRowArray& rows_right) {
-  ARROW_DCHECK(rows_left.metadata().is_compatible(rows_right.metadata()));
-
-  if (num_rows_to_compare == 0) {
-    *out_num_rows = 0;
+template <bool use_selection>
+void KeyCompare::NullUpdateColumnToRow(uint32_t id_col, uint32_t 
num_rows_to_compare,
+                                       const uint16_t* sel_left_maybe_null,
+                                       const uint32_t* left_to_right_map,
+                                       KeyEncoder::KeyEncoderContext* ctx,
+                                       const KeyEncoder::KeyColumnArray& col,
+                                       const KeyEncoder::KeyRowArray& rows,
+                                       uint8_t* match_bytevector) {
+  if (!rows.has_any_nulls(ctx) && !col.data(0)) {
     return;
   }
-
-  // Allocate temporary byte and bit vectors
-  auto bytevector_holder =
-      util::TempVectorHolder<uint8_t>(ctx->stack, num_rows_to_compare);
-  auto bitvector_holder =
-      util::TempVectorHolder<uint8_t>(ctx->stack, num_rows_to_compare);
-
-  uint8_t* match_bytevector = bytevector_holder.mutable_data();
-  uint8_t* match_bitvector = bitvector_holder.mutable_data();
-
-  // All comparison functions called here will update match byte vector
-  // (AND it with comparison result) instead of overwriting it.
-  memset(match_bytevector, 0xff, num_rows_to_compare);
-
-  if (rows_left.metadata().is_fixed_length) {
-    CompareFixedLength(num_rows_to_compare, sel_left_maybe_null, 
left_to_right_map,
-                       match_bytevector, ctx, 
rows_left.metadata().fixed_length,
-                       rows_left.data(1), rows_right.data(1));
-  } else {
-    CompareVaryingLength(num_rows_to_compare, sel_left_maybe_null, 
left_to_right_map,
-                         match_bytevector, ctx, rows_left.data(2), 
rows_right.data(2),
-                         rows_left.offsets(), rows_right.offsets());
+  uint32_t num_processed = 0;
+#if defined(ARROW_HAVE_AVX2)
+  if (ctx->has_avx2()) {
+    num_processed = NullUpdateColumnToRow_avx2(use_selection, id_col, 
num_rows_to_compare,
+                                               sel_left_maybe_null, 
left_to_right_map,
+                                               ctx, col, rows, 
match_bytevector);
   }
+#endif
 
-  // CompareFixedLength can be used to compare nulls as well
-  bool nulls_present = rows_left.has_any_nulls(ctx) || 
rows_right.has_any_nulls(ctx);
-  if (nulls_present) {
-    CompareFixedLength(num_rows_to_compare, sel_left_maybe_null, 
left_to_right_map,
-                       match_bytevector, ctx,
-                       rows_left.metadata().null_masks_bytes_per_row,
-                       rows_left.null_masks(), rows_right.null_masks());
+  if (!col.data(0)) {
+    // Remove rows from the result for which the column value is a null
+    const uint8_t* null_masks = rows.null_masks();
+    uint32_t null_mask_num_bytes = rows.metadata().null_masks_bytes_per_row;
+    for (uint32_t i = num_processed; i < num_rows_to_compare; ++i) {
+      uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i;
+      uint32_t irow_right = left_to_right_map[irow_left];
+      int64_t bitid = irow_right * null_mask_num_bytes * 8 + id_col;
+      match_bytevector[i] &= (BitUtil::GetBit(null_masks, bitid) ? 0 : 0xff);
+    }
+  } else if (!rows.has_any_nulls(ctx)) {
+    // Remove rows from the result for which the column value on left side is null
+    const uint8_t* non_nulls = col.data(0);
+    ARROW_DCHECK(non_nulls);
+    for (uint32_t i = num_processed; i < num_rows_to_compare; ++i) {
+      uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i;
+      match_bytevector[i] &=
+          BitUtil::GetBit(non_nulls, irow_left + col.bit_offset(0)) ? 0xff : 0;
+    }
+  } else {
+    const uint8_t* null_masks = rows.null_masks();
+    uint32_t null_mask_num_bytes = rows.metadata().null_masks_bytes_per_row;
+    const uint8_t* non_nulls = col.data(0);
+    ARROW_DCHECK(non_nulls);
+    for (uint32_t i = num_processed; i < num_rows_to_compare; ++i) {
+      uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i;
+      uint32_t irow_right = left_to_right_map[irow_left];
+      int64_t bitid_right = irow_right * null_mask_num_bytes * 8 + id_col;
+      int right_null = BitUtil::GetBit(null_masks, bitid_right) ? 0xff : 0;
+      int left_null =
+          BitUtil::GetBit(non_nulls, irow_left + col.bit_offset(0)) ? 0 : 0xff;
+      match_bytevector[i] |= left_null & right_null;
+      match_bytevector[i] &= ~(left_null ^ right_null);
+    }
   }
+}
 
-  util::BitUtil::bytes_to_bits(ctx->hardware_flags, num_rows_to_compare, 
match_bytevector,
-                               match_bitvector);
-  if (sel_left_maybe_null) {
-    int out_num_rows_int;
-    util::BitUtil::bits_filter_indexes(0, ctx->hardware_flags, 
num_rows_to_compare,
-                                       match_bitvector, sel_left_maybe_null,
-                                       &out_num_rows_int, 
out_sel_left_maybe_same);
-    *out_num_rows = out_num_rows_int;
+template <bool use_selection, class COMPARE_FN>
+void KeyCompare::CompareBinaryColumnToRowHelper(
+    uint32_t offset_within_row, uint32_t first_row_to_compare,
+    uint32_t num_rows_to_compare, const uint16_t* sel_left_maybe_null,
+    const uint32_t* left_to_right_map, KeyEncoder::KeyEncoderContext* ctx,
+    const KeyEncoder::KeyColumnArray& col, const KeyEncoder::KeyRowArray& rows,
+    uint8_t* match_bytevector, COMPARE_FN compare_fn) {
+  bool is_fixed_length = rows.metadata().is_fixed_length;
+  if (is_fixed_length) {
+    uint32_t fixed_length = rows.metadata().fixed_length;
+    const uint8_t* rows_left = col.data(1);
+    const uint8_t* rows_right = rows.data(1);
+    for (uint32_t i = first_row_to_compare; i < num_rows_to_compare; ++i) {
+      uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i;
+      uint32_t irow_right = left_to_right_map[irow_left];
+      uint32_t offset_right = irow_right * fixed_length + offset_within_row;
+      match_bytevector[i] = compare_fn(rows_left, rows_right, irow_left, 
offset_right);
+    }
   } else {
-    int out_num_rows_int;
-    util::BitUtil::bits_to_indexes(0, ctx->hardware_flags, num_rows_to_compare,
-                                   match_bitvector, &out_num_rows_int,
-                                   out_sel_left_maybe_same);
-    *out_num_rows = out_num_rows_int;
+    const uint8_t* rows_left = col.data(1);
+    const uint32_t* offsets_right = rows.offsets();
+    const uint8_t* rows_right = rows.data(2);
+    for (uint32_t i = first_row_to_compare; i < num_rows_to_compare; ++i) {
+      uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i;
+      uint32_t irow_right = left_to_right_map[irow_left];
+      uint32_t offset_right = offsets_right[irow_right] + offset_within_row;
+      match_bytevector[i] = compare_fn(rows_left, rows_right, irow_left, 
offset_right);
+    }
   }
 }
 
-void KeyCompare::CompareFixedLength(uint32_t num_rows_to_compare,
-                                    const uint16_t* sel_left_maybe_null,
-                                    const uint32_t* left_to_right_map,
-                                    uint8_t* match_bytevector,
-                                    KeyEncoder::KeyEncoderContext* ctx,
-                                    uint32_t fixed_length, const uint8_t* 
rows_left,
-                                    const uint8_t* rows_right) {
-  bool use_selection = (sel_left_maybe_null != nullptr);
-
-  uint32_t num_rows_already_processed = 0;
-
+template <bool use_selection>
+void KeyCompare::CompareBinaryColumnToRow(
+    uint32_t offset_within_row, uint32_t num_rows_to_compare,
+    const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map,
+    KeyEncoder::KeyEncoderContext* ctx, const KeyEncoder::KeyColumnArray& col,
+    const KeyEncoder::KeyRowArray& rows, uint8_t* match_bytevector) {
+  uint32_t num_processed = 0;
 #if defined(ARROW_HAVE_AVX2)
-  if (ctx->has_avx2() && !use_selection) {
-    // Choose between up-to-8B length, up-to-16B length and any size versions
-    if (fixed_length <= 8) {
-      num_rows_already_processed = CompareFixedLength_UpTo8B_avx2(
-          num_rows_to_compare, left_to_right_map, match_bytevector, 
fixed_length,
-          rows_left, rows_right);
-    } else if (fixed_length <= 16) {
-      num_rows_already_processed = CompareFixedLength_UpTo16B_avx2(
-          num_rows_to_compare, left_to_right_map, match_bytevector, 
fixed_length,
-          rows_left, rows_right);
-    } else {
-      num_rows_already_processed =
-          CompareFixedLength_avx2(num_rows_to_compare, left_to_right_map,
-                                  match_bytevector, fixed_length, rows_left, 
rows_right);
-    }
+  if (ctx->has_avx2()) {
+    num_processed = CompareBinaryColumnToRow_avx2(
+        use_selection, offset_within_row, num_rows_to_compare, 
sel_left_maybe_null,
+        left_to_right_map, ctx, col, rows, match_bytevector);
   }
 #endif
 
-  typedef void (*CompareFixedLengthImp_t)(uint32_t, uint32_t, const uint16_t*,
-                                          const uint32_t*, uint8_t*, uint32_t,
-                                          const uint8_t*, const uint8_t*);
-  static const CompareFixedLengthImp_t CompareFixedLengthImp_fn[] = {
-      CompareFixedLengthImp<false, 1>, CompareFixedLengthImp<false, 2>,
-      CompareFixedLengthImp<false, 0>, CompareFixedLengthImp<true, 1>,
-      CompareFixedLengthImp<true, 2>,  CompareFixedLengthImp<true, 0>};
-  int dispatch_const = (use_selection ? 3 : 0) +
-                       ((fixed_length <= 8) ? 0 : ((fixed_length <= 16) ? 1 : 
2));
-  CompareFixedLengthImp_fn[dispatch_const](
-      num_rows_already_processed, num_rows_to_compare, sel_left_maybe_null,
-      left_to_right_map, match_bytevector, fixed_length, rows_left, 
rows_right);
-}
+  uint32_t col_width = col.metadata().fixed_length;
+  if (col_width == 0) {
+    int bit_offset = col.bit_offset(1);
+    CompareBinaryColumnToRowHelper<use_selection>(
+        offset_within_row, num_processed, num_rows_to_compare, 
sel_left_maybe_null,
+        left_to_right_map, ctx, col, rows, match_bytevector,
+        [bit_offset](const uint8_t* left_base, const uint8_t* right_base,
+                     uint32_t irow_left, uint32_t offset_right) {
+          uint8_t left = BitUtil::GetBit(left_base, irow_left + bit_offset) ? 
0xff : 0x00;
+          uint8_t right = right_base[offset_right];
+          return left == right ? 0xff : 0;
+        });
+  } else if (col_width == 1) {
+    CompareBinaryColumnToRowHelper<use_selection>(
+        offset_within_row, num_processed, num_rows_to_compare, 
sel_left_maybe_null,
+        left_to_right_map, ctx, col, rows, match_bytevector,
+        [](const uint8_t* left_base, const uint8_t* right_base, uint32_t 
irow_left,
+           uint32_t offset_right) {
+          uint8_t left = left_base[irow_left];
+          uint8_t right = right_base[offset_right];
+          return left == right ? 0xff : 0;
+        });
+  } else if (col_width == 2) {
+    CompareBinaryColumnToRowHelper<use_selection>(
+        offset_within_row, num_processed, num_rows_to_compare, 
sel_left_maybe_null,
+        left_to_right_map, ctx, col, rows, match_bytevector,
+        [](const uint8_t* left_base, const uint8_t* right_base, uint32_t 
irow_left,
+           uint32_t offset_right) {
+          uint16_t left = reinterpret_cast<const 
uint16_t*>(left_base)[irow_left];
+          uint16_t right = *reinterpret_cast<const uint16_t*>(right_base + 
offset_right);
+          return left == right ? 0xff : 0;
+        });
+  } else if (col_width == 4) {
+    CompareBinaryColumnToRowHelper<use_selection>(
+        offset_within_row, num_processed, num_rows_to_compare, 
sel_left_maybe_null,
+        left_to_right_map, ctx, col, rows, match_bytevector,
+        [](const uint8_t* left_base, const uint8_t* right_base, uint32_t 
irow_left,
+           uint32_t offset_right) {
+          uint32_t left = reinterpret_cast<const 
uint32_t*>(left_base)[irow_left];
+          uint32_t right = *reinterpret_cast<const uint32_t*>(right_base + 
offset_right);
+          return left == right ? 0xff : 0;
+        });
+  } else if (col_width == 8) {
+    CompareBinaryColumnToRowHelper<use_selection>(
+        offset_within_row, num_processed, num_rows_to_compare, 
sel_left_maybe_null,
+        left_to_right_map, ctx, col, rows, match_bytevector,
+        [](const uint8_t* left_base, const uint8_t* right_base, uint32_t 
irow_left,
+           uint32_t offset_right) {
+          uint64_t left = reinterpret_cast<const 
uint64_t*>(left_base)[irow_left];
+          uint64_t right = *reinterpret_cast<const uint64_t*>(right_base + 
offset_right);
+          return left == right ? 0xff : 0;
+        });
+  } else {
+    CompareBinaryColumnToRowHelper<use_selection>(
+        offset_within_row, num_processed, num_rows_to_compare, 
sel_left_maybe_null,
+        left_to_right_map, ctx, col, rows, match_bytevector,
+        [&col](const uint8_t* left_base, const uint8_t* right_base, uint32_t 
irow_left,
+               uint32_t offset_right) {
+          uint32_t length = col.metadata().fixed_length;
 
-template <bool use_selection, int num_64bit_words>
-void KeyCompare::CompareFixedLengthImp(uint32_t num_rows_already_processed,
-                                       uint32_t num_rows,
-                                       const uint16_t* sel_left_maybe_null,
-                                       const uint32_t* left_to_right_map,
-                                       uint8_t* match_bytevector, uint32_t 
length,
-                                       const uint8_t* rows_left,
-                                       const uint8_t* rows_right) {
-  // Key length (for encoded key) has to be non-zero
-  ARROW_DCHECK(length > 0);
+          // Non-zero length guarantees no underflow
+          int32_t num_loops_less_one = (static_cast<int32_t>(length) + 7) / 8 
- 1;
+
+          uint64_t tail_mask = ~0ULL >> (64 - 8 * (length - num_loops_less_one 
* 8));
 
-  // Non-zero length guarantees no underflow
-  int32_t num_loops_less_one = (static_cast<int32_t>(length) + 7) / 8 - 1;
+          const uint64_t* key_left_ptr =
+              reinterpret_cast<const uint64_t*>(left_base + irow_left * 
length);
+          const uint64_t* key_right_ptr =
+              reinterpret_cast<const uint64_t*>(right_base + offset_right);
+          uint64_t result_or = 0;
+          int32_t i;
+          // length cannot be zero
+          for (i = 0; i < num_loops_less_one; ++i) {
+            uint64_t key_left = key_left_ptr[i];
+            uint64_t key_right = key_right_ptr[i];
+            result_or |= key_left ^ key_right;
+          }
+          uint64_t key_left = key_left_ptr[i];
+          uint64_t key_right = key_right_ptr[i];
+          result_or |= tail_mask & (key_left ^ key_right);
+          return result_or == 0 ? 0xff : 0;
+        });
+  }
+}
 
-  // Length remaining in last loop can only be zero for input length equal to zero
-  uint32_t length_remaining_last_loop = length - num_loops_less_one * 8;
-  uint64_t tail_mask = (~0ULL) >> (8 * (8 - length_remaining_last_loop));
+// Overwrites the match_bytevector instead of updating it
+template <bool use_selection, bool is_first_varbinary_col>
+void KeyCompare::CompareVarBinaryColumnToRow(
+    uint32_t id_varbinary_col, uint32_t num_rows_to_compare,
+    const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map,
+    KeyEncoder::KeyEncoderContext* ctx, const KeyEncoder::KeyColumnArray& col,
+    const KeyEncoder::KeyRowArray& rows, uint8_t* match_bytevector) {
+#if defined(ARROW_HAVE_AVX2)
+  if (ctx->has_avx2()) {
+    CompareVarBinaryColumnToRow_avx2(
+        use_selection, is_first_varbinary_col, id_varbinary_col, 
num_rows_to_compare,
+        sel_left_maybe_null, left_to_right_map, ctx, col, rows, 
match_bytevector);
+    return;
+  }
+#endif
 
-  for (uint32_t id_input = num_rows_already_processed; id_input < num_rows; 
++id_input) {
-    uint32_t irow_left = use_selection ? sel_left_maybe_null[id_input] : 
id_input;
+  const uint32_t* offsets_left = col.offsets();
+  const uint32_t* offsets_right = rows.offsets();
+  const uint8_t* rows_left = col.data(2);
+  const uint8_t* rows_right = rows.data(2);
+  for (uint32_t i = 0; i < num_rows_to_compare; ++i) {
+    uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i;
     uint32_t irow_right = left_to_right_map[irow_left];
-    uint32_t begin_left = length * irow_left;
-    uint32_t begin_right = length * irow_right;
+    uint32_t begin_left = offsets_left[irow_left];
+    uint32_t length_left = offsets_left[irow_left + 1] - begin_left;
+    uint32_t begin_right = offsets_right[irow_right];
+    uint32_t length_right;
+    uint32_t offset_within_row;
+    if (!is_first_varbinary_col) {
+      rows.metadata().nth_varbinary_offset_and_length(
+          rows_right + begin_right, id_varbinary_col, &offset_within_row, 
&length_right);
+    } else {
+      rows.metadata().first_varbinary_offset_and_length(
+          rows_right + begin_right, &offset_within_row, &length_right);
+    }
+    begin_right += offset_within_row;
+    uint32_t length = std::min(length_left, length_right);
     const uint64_t* key_left_ptr =
         reinterpret_cast<const uint64_t*>(rows_left + begin_left);
     const uint64_t* key_right_ptr =
         reinterpret_cast<const uint64_t*>(rows_right + begin_right);
-    uint64_t result_or = 0ULL;
-    int32_t istripe = 0;
-
-    // Specializations for keys up to 8 bytes and between 9 and 16 bytes to
-    // avoid internal loop over words in the value for short ones.
-    //
-    // Template argument 0 means arbitrarily many 64-bit words,
-    // 1 means up to 1 and 2 means up to 2.
-    //
-    if (num_64bit_words == 0) {
-      for (; istripe < num_loops_less_one; ++istripe) {
-        uint64_t key_left = util::SafeLoad(&key_left_ptr[istripe]);
-        uint64_t key_right = util::SafeLoad(&key_right_ptr[istripe]);
-        result_or |= (key_left ^ key_right);
+    uint64_t result_or = 0;
+    if (length > 0) {
+      int32_t j;
+      // length can be zero
+      for (j = 0; j < (static_cast<int32_t>(length) + 7) / 8 - 1; ++j) {

Review comment:
       done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to