michalursa commented on a change in pull request #10290:
URL: https://github.com/apache/arrow/pull/10290#discussion_r688211337



##########
File path: cpp/src/arrow/compute/exec/key_compare.cc
##########
@@ -17,250 +17,398 @@
 
 #include "arrow/compute/exec/key_compare.h"
 
+#include <immintrin.h>
+#include <memory.h>
+
 #include <algorithm>
 #include <cstdint>
 
 #include "arrow/compute/exec/util.h"
+#include "arrow/util/bit_util.h"
 #include "arrow/util/ubsan.h"
 
 namespace arrow {
 namespace compute {
 
-void KeyCompare::CompareRows(uint32_t num_rows_to_compare,
-                             const uint16_t* sel_left_maybe_null,
-                             const uint32_t* left_to_right_map,
-                             KeyEncoder::KeyEncoderContext* ctx, uint32_t* 
out_num_rows,
-                             uint16_t* out_sel_left_maybe_same,
-                             const KeyEncoder::KeyRowArray& rows_left,
-                             const KeyEncoder::KeyRowArray& rows_right) {
-  ARROW_DCHECK(rows_left.metadata().is_compatible(rows_right.metadata()));
-
-  if (num_rows_to_compare == 0) {
-    *out_num_rows = 0;
+template <bool use_selection>
+void KeyCompare::NullUpdateColumnToRow(uint32_t id_col, uint32_t 
num_rows_to_compare,
+                                       const uint16_t* sel_left_maybe_null,
+                                       const uint32_t* left_to_right_map,
+                                       KeyEncoder::KeyEncoderContext* ctx,
+                                       const KeyEncoder::KeyColumnArray& col,
+                                       const KeyEncoder::KeyRowArray& rows,
+                                       uint8_t* match_bytevector) {
+  if (!rows.has_any_nulls(ctx) && !col.data(0)) {
     return;
   }
-
-  // Allocate temporary byte and bit vectors
-  auto bytevector_holder =
-      util::TempVectorHolder<uint8_t>(ctx->stack, num_rows_to_compare);
-  auto bitvector_holder =
-      util::TempVectorHolder<uint8_t>(ctx->stack, num_rows_to_compare);
-
-  uint8_t* match_bytevector = bytevector_holder.mutable_data();
-  uint8_t* match_bitvector = bitvector_holder.mutable_data();
-
-  // All comparison functions called here will update match byte vector
-  // (AND it with comparison result) instead of overwriting it.
-  memset(match_bytevector, 0xff, num_rows_to_compare);
-
-  if (rows_left.metadata().is_fixed_length) {
-    CompareFixedLength(num_rows_to_compare, sel_left_maybe_null, 
left_to_right_map,
-                       match_bytevector, ctx, 
rows_left.metadata().fixed_length,
-                       rows_left.data(1), rows_right.data(1));
-  } else {
-    CompareVaryingLength(num_rows_to_compare, sel_left_maybe_null, 
left_to_right_map,
-                         match_bytevector, ctx, rows_left.data(2), 
rows_right.data(2),
-                         rows_left.offsets(), rows_right.offsets());
+  uint32_t num_processed = 0;
+#if defined(ARROW_HAVE_AVX2)
+  if (ctx->has_avx2()) {
+    num_processed = NullUpdateColumnToRow_avx2(use_selection, id_col, 
num_rows_to_compare,
+                                               sel_left_maybe_null, 
left_to_right_map,
+                                               ctx, col, rows, 
match_bytevector);
   }
+#endif
 
-  // CompareFixedLength can be used to compare nulls as well
-  bool nulls_present = rows_left.has_any_nulls(ctx) || 
rows_right.has_any_nulls(ctx);
-  if (nulls_present) {
-    CompareFixedLength(num_rows_to_compare, sel_left_maybe_null, 
left_to_right_map,
-                       match_bytevector, ctx,
-                       rows_left.metadata().null_masks_bytes_per_row,
-                       rows_left.null_masks(), rows_right.null_masks());
+  if (!col.data(0)) {
+    // Remove rows from the result for which the column value is a null
+    const uint8_t* null_masks = rows.null_masks();
+    uint32_t null_mask_num_bytes = rows.metadata().null_masks_bytes_per_row;
+    for (uint32_t i = num_processed; i < num_rows_to_compare; ++i) {
+      uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i;
+      uint32_t irow_right = left_to_right_map[irow_left];
+      int64_t bitid = irow_right * null_mask_num_bytes * 8 + id_col;
+      match_bytevector[i] &= (BitUtil::GetBit(null_masks, bitid) ? 0 : 0xff);
+    }
+  } else if (!rows.has_any_nulls(ctx)) {
+    // Remove rows from the result for which the column value on left side is 
null
+    const uint8_t* non_nulls = col.data(0);
+    ARROW_DCHECK(non_nulls);
+    for (uint32_t i = num_processed; i < num_rows_to_compare; ++i) {
+      uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i;
+      match_bytevector[i] &=
+          BitUtil::GetBit(non_nulls, irow_left + col.bit_offset(0)) ? 0xff : 0;
+    }
+  } else {
+    const uint8_t* null_masks = rows.null_masks();
+    uint32_t null_mask_num_bytes = rows.metadata().null_masks_bytes_per_row;
+    const uint8_t* non_nulls = col.data(0);
+    ARROW_DCHECK(non_nulls);
+    for (uint32_t i = num_processed; i < num_rows_to_compare; ++i) {
+      uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i;
+      uint32_t irow_right = left_to_right_map[irow_left];
+      int64_t bitid_right = irow_right * null_mask_num_bytes * 8 + id_col;
+      int right_null = BitUtil::GetBit(null_masks, bitid_right) ? 0xff : 0;
+      int left_null =
+          BitUtil::GetBit(non_nulls, irow_left + col.bit_offset(0)) ? 0 : 0xff;
+      match_bytevector[i] |= left_null & right_null;
+      match_bytevector[i] &= ~(left_null ^ right_null);
+    }
   }
+}
 
-  util::BitUtil::bytes_to_bits(ctx->hardware_flags, num_rows_to_compare, 
match_bytevector,
-                               match_bitvector);
-  if (sel_left_maybe_null) {
-    int out_num_rows_int;
-    util::BitUtil::bits_filter_indexes(0, ctx->hardware_flags, 
num_rows_to_compare,
-                                       match_bitvector, sel_left_maybe_null,
-                                       &out_num_rows_int, 
out_sel_left_maybe_same);
-    *out_num_rows = out_num_rows_int;
+template <bool use_selection, class COMPARE_FN>
+void KeyCompare::CompareBinaryColumnToRowHelper(
+    uint32_t offset_within_row, uint32_t first_row_to_compare,
+    uint32_t num_rows_to_compare, const uint16_t* sel_left_maybe_null,
+    const uint32_t* left_to_right_map, KeyEncoder::KeyEncoderContext* ctx,
+    const KeyEncoder::KeyColumnArray& col, const KeyEncoder::KeyRowArray& rows,
+    uint8_t* match_bytevector, COMPARE_FN compare_fn) {
+  bool is_fixed_length = rows.metadata().is_fixed_length;
+  if (is_fixed_length) {
+    uint32_t fixed_length = rows.metadata().fixed_length;
+    const uint8_t* rows_left = col.data(1);
+    const uint8_t* rows_right = rows.data(1);
+    for (uint32_t i = first_row_to_compare; i < num_rows_to_compare; ++i) {
+      uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i;
+      uint32_t irow_right = left_to_right_map[irow_left];
+      uint32_t offset_right = irow_right * fixed_length + offset_within_row;
+      match_bytevector[i] = compare_fn(rows_left, rows_right, irow_left, 
offset_right);
+    }
   } else {
-    int out_num_rows_int;
-    util::BitUtil::bits_to_indexes(0, ctx->hardware_flags, num_rows_to_compare,
-                                   match_bitvector, &out_num_rows_int,
-                                   out_sel_left_maybe_same);
-    *out_num_rows = out_num_rows_int;
+    const uint8_t* rows_left = col.data(1);
+    const uint32_t* offsets_right = rows.offsets();
+    const uint8_t* rows_right = rows.data(2);
+    for (uint32_t i = first_row_to_compare; i < num_rows_to_compare; ++i) {
+      uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i;
+      uint32_t irow_right = left_to_right_map[irow_left];
+      uint32_t offset_right = offsets_right[irow_right] + offset_within_row;
+      match_bytevector[i] = compare_fn(rows_left, rows_right, irow_left, 
offset_right);
+    }
   }
 }
 
-void KeyCompare::CompareFixedLength(uint32_t num_rows_to_compare,
-                                    const uint16_t* sel_left_maybe_null,
-                                    const uint32_t* left_to_right_map,
-                                    uint8_t* match_bytevector,
-                                    KeyEncoder::KeyEncoderContext* ctx,
-                                    uint32_t fixed_length, const uint8_t* 
rows_left,
-                                    const uint8_t* rows_right) {
-  bool use_selection = (sel_left_maybe_null != nullptr);
-
-  uint32_t num_rows_already_processed = 0;
-
+template <bool use_selection>
+void KeyCompare::CompareBinaryColumnToRow(
+    uint32_t offset_within_row, uint32_t num_rows_to_compare,
+    const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map,
+    KeyEncoder::KeyEncoderContext* ctx, const KeyEncoder::KeyColumnArray& col,
+    const KeyEncoder::KeyRowArray& rows, uint8_t* match_bytevector) {
+  uint32_t num_processed = 0;
 #if defined(ARROW_HAVE_AVX2)
-  if (ctx->has_avx2() && !use_selection) {
-    // Choose between up-to-8B length, up-to-16B length and any size versions
-    if (fixed_length <= 8) {
-      num_rows_already_processed = CompareFixedLength_UpTo8B_avx2(
-          num_rows_to_compare, left_to_right_map, match_bytevector, 
fixed_length,
-          rows_left, rows_right);
-    } else if (fixed_length <= 16) {
-      num_rows_already_processed = CompareFixedLength_UpTo16B_avx2(
-          num_rows_to_compare, left_to_right_map, match_bytevector, 
fixed_length,
-          rows_left, rows_right);
-    } else {
-      num_rows_already_processed =
-          CompareFixedLength_avx2(num_rows_to_compare, left_to_right_map,
-                                  match_bytevector, fixed_length, rows_left, 
rows_right);
-    }
+  if (ctx->has_avx2()) {
+    num_processed = CompareBinaryColumnToRow_avx2(
+        use_selection, offset_within_row, num_rows_to_compare, 
sel_left_maybe_null,
+        left_to_right_map, ctx, col, rows, match_bytevector);
   }
 #endif
 
-  typedef void (*CompareFixedLengthImp_t)(uint32_t, uint32_t, const uint16_t*,
-                                          const uint32_t*, uint8_t*, uint32_t,
-                                          const uint8_t*, const uint8_t*);
-  static const CompareFixedLengthImp_t CompareFixedLengthImp_fn[] = {
-      CompareFixedLengthImp<false, 1>, CompareFixedLengthImp<false, 2>,
-      CompareFixedLengthImp<false, 0>, CompareFixedLengthImp<true, 1>,
-      CompareFixedLengthImp<true, 2>,  CompareFixedLengthImp<true, 0>};
-  int dispatch_const = (use_selection ? 3 : 0) +
-                       ((fixed_length <= 8) ? 0 : ((fixed_length <= 16) ? 1 : 
2));
-  CompareFixedLengthImp_fn[dispatch_const](
-      num_rows_already_processed, num_rows_to_compare, sel_left_maybe_null,
-      left_to_right_map, match_bytevector, fixed_length, rows_left, 
rows_right);
-}
+  uint32_t col_width = col.metadata().fixed_length;
+  if (col_width == 0) {
+    int bit_offset = col.bit_offset(1);
+    CompareBinaryColumnToRowHelper<use_selection>(
+        offset_within_row, num_processed, num_rows_to_compare, 
sel_left_maybe_null,
+        left_to_right_map, ctx, col, rows, match_bytevector,
+        [bit_offset](const uint8_t* left_base, const uint8_t* right_base,
+                     uint32_t irow_left, uint32_t offset_right) {
+          uint8_t left = BitUtil::GetBit(left_base, irow_left + bit_offset) ? 
0xff : 0x00;
+          uint8_t right = right_base[offset_right];
+          return left == right ? 0xff : 0;
+        });
+  } else if (col_width == 1) {
+    CompareBinaryColumnToRowHelper<use_selection>(
+        offset_within_row, num_processed, num_rows_to_compare, 
sel_left_maybe_null,
+        left_to_right_map, ctx, col, rows, match_bytevector,
+        [](const uint8_t* left_base, const uint8_t* right_base, uint32_t 
irow_left,
+           uint32_t offset_right) {
+          uint8_t left = left_base[irow_left];
+          uint8_t right = right_base[offset_right];
+          return left == right ? 0xff : 0;
+        });
+  } else if (col_width == 2) {
+    CompareBinaryColumnToRowHelper<use_selection>(
+        offset_within_row, num_processed, num_rows_to_compare, 
sel_left_maybe_null,
+        left_to_right_map, ctx, col, rows, match_bytevector,
+        [](const uint8_t* left_base, const uint8_t* right_base, uint32_t 
irow_left,
+           uint32_t offset_right) {
+          uint16_t left = reinterpret_cast<const 
uint16_t*>(left_base)[irow_left];

Review comment:
       I added alignment checks where alignment is guaranteed, and SafeLoads where it is not.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to