wesm commented on code in PR #13654:
URL: https://github.com/apache/arrow/pull/13654#discussion_r925607156


##########
cpp/src/arrow/compute/kernels/scalar_compare.cc:
##########
@@ -158,11 +158,183 @@ struct Maximum {
 
 // Implement Less, LessEqual by flipping arguments to Greater, GreaterEqual
 
-template <typename OutType, typename ArgType, typename Op>
-struct CompareTimestamps
-    : public applicator::ScalarBinaryEqualTypes<OutType, ArgType, Op> {
-  using Base = applicator::ScalarBinaryEqualTypes<OutType, ArgType, Op>;
+template <int batch_size>
+void PackBits(const int* values, uint8_t* out) {
+  for (int i = 0; i < batch_size / 8; ++i) {
+    *out++ = (values[0] | values[1] << 1 | values[2] << 2 | values[3] << 3 |
+              values[4] << 4 | values[5] << 5 | values[6] << 6 | values[7] << 7);
+    values += 8;
+  }
+}
+
+template <typename T, typename Op>
+struct ComparePrimitive {
+  static void Exec(const void* left_values_void, const void* right_values_void,
+                   int64_t length, void* out_bitmap_void) {
+    const T* left_values = reinterpret_cast<const T*>(left_values_void);
+    const T* right_values = reinterpret_cast<const T*>(right_values_void);
+    uint8_t* out_bitmap = reinterpret_cast<uint8_t*>(out_bitmap_void);
+    static constexpr int kBatchSize = 32;
+    int64_t num_batches = length / kBatchSize;
+    int temp_output[kBatchSize];
+    for (int64_t j = 0; j < num_batches; ++j) {
+      for (int i = 0; i < kBatchSize; ++i) {
+        temp_output[i] = Op::template Call<bool, T, T>(nullptr, *left_values++,
+                                                       *right_values++, nullptr);
+      }
+      PackBits<kBatchSize>(temp_output, out_bitmap);
+      out_bitmap += kBatchSize / 8;
+    }
+    int64_t bit_index = 0;
+    for (int64_t j = kBatchSize * num_batches; j < length; ++j) {
+      bit_util::SetBitTo(out_bitmap, bit_index++,
+                         Op::template Call<bool, T, T>(nullptr, *left_values++,
+                                                       *right_values++, nullptr));
+    }
+  }
+};
+
+template <typename T, typename Op>
+struct ComparePrimitiveAS {
+  static void Exec(const void* left_values_void, const void* right_value_void,
+                   int64_t length, void* out_bitmap_void) {
+    const T* left_values = reinterpret_cast<const T*>(left_values_void);
+    const T right_value = *reinterpret_cast<const T*>(right_value_void);
+    uint8_t* out_bitmap = reinterpret_cast<uint8_t*>(out_bitmap_void);
+    static constexpr int kBatchSize = 32;
+    int64_t num_batches = length / kBatchSize;
+    int temp_output[kBatchSize];
+    for (int64_t j = 0; j < num_batches; ++j) {
+      for (int i = 0; i < kBatchSize; ++i) {
+        temp_output[i] =
+            Op::template Call<bool, T, T>(nullptr, *left_values++, right_value, nullptr);
+      }
+      PackBits<kBatchSize>(temp_output, out_bitmap);
+      out_bitmap += kBatchSize / 8;
+    }
+    int64_t bit_index = 0;
+    for (int64_t j = kBatchSize * num_batches; j < length; ++j) {
+      bit_util::SetBitTo(
+          out_bitmap, bit_index++,
+          Op::template Call<bool, T, T>(nullptr, *left_values++, right_value, nullptr));
+    }
+  }
+};
+
+template <typename T, typename Op>
+struct ComparePrimitiveSA {
+  static void Exec(const void* left_value_void, const void* right_values_void,
+                   int64_t length, void* out_bitmap_void) {
+    const T left_value = *reinterpret_cast<const T*>(left_value_void);
+    const T* right_values = reinterpret_cast<const T*>(right_values_void);
+    uint8_t* out_bitmap = reinterpret_cast<uint8_t*>(out_bitmap_void);
+    static constexpr int kBatchSize = 32;
+    int64_t num_batches = length / kBatchSize;
+    int temp_output[kBatchSize];
+    for (int64_t j = 0; j < num_batches; ++j) {
+      for (int i = 0; i < kBatchSize; ++i) {
+        temp_output[i] =
+            Op::template Call<bool, T, T>(nullptr, left_value, *right_values++, nullptr);
+      }
+      PackBits<kBatchSize>(temp_output, out_bitmap);
+      out_bitmap += kBatchSize / 8;
+    }
+    int64_t bit_index = 0;
+    for (int64_t j = kBatchSize * num_batches; j < length; ++j) {
+      bit_util::SetBitTo(
+          out_bitmap, bit_index++,
+          Op::template Call<bool, T, T>(nullptr, left_value, *right_values++, nullptr));
+    }
+  }
+};
+
+using BinaryKernel = void (*)(const void*, const void*, int64_t, void*);
+
+struct CompareData : public KernelState {
+  BinaryKernel func_aa;
+  BinaryKernel func_sa;
+  BinaryKernel func_as;
+  CompareData(BinaryKernel func_aa, BinaryKernel func_sa, BinaryKernel func_as)
+      : func_aa(func_aa), func_sa(func_sa), func_as(func_as) {}
+};
+
+template <template <typename...> class Generator, typename Op>
+BinaryKernel GetBinaryKernel(Type::type type) {
+  switch (type) {
+    case Type::INT8:
+      return Generator<int8_t, Op>::Exec;
+    case Type::INT16:
+      return Generator<int16_t, Op>::Exec;
+    case Type::INT32:
+    case Type::DATE32:
+      return Generator<int32_t, Op>::Exec;
+    case Type::INT64:
+    case Type::DURATION:
+    case Type::TIMESTAMP:
+    case Type::DATE64:
+      return Generator<int64_t, Op>::Exec;
+    case Type::UINT8:
+      return Generator<uint8_t, Op>::Exec;
+    case Type::UINT16:
+      return Generator<uint16_t, Op>::Exec;
+    case Type::UINT32:
+      return Generator<uint32_t, Op>::Exec;
+    case Type::UINT64:
+      return Generator<uint64_t, Op>::Exec;
+    case Type::FLOAT:
+      return Generator<float, Op>::Exec;
+    case Type::DOUBLE:
+      return Generator<double, Op>::Exec;
+    default:
+      return nullptr;
+  }
+}
+
+template <typename Type>
+struct CompareKernel {
+  using T = typename Type::c_type;
+
+  static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
+    const auto kernel = static_cast<const ScalarKernel*>(ctx->kernel());
+    DCHECK(kernel);
+    const auto kernel_data = static_cast<const CompareData*>(kernel->data.get());
+
+    ArraySpan* out_arr = out->array_span();
+
+    // TODO: implement path for offset not multiple of 8
+    const bool out_is_byte_aligned = out_arr->offset % 8 == 0;

Review Comment:
   These kernels can write into sliced outputs (chunk size not a multiple of 8) -- whether they _should_ be allowed to do this is a separate question...
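
   To make the offset concern concrete, here is a minimal standalone sketch (not from the PR) showing how slicing yields a bitmap offset that is not byte aligned; the builder values and the printed field are illustrative only:

   ```cpp
   #include <arrow/api.h>

   #include <iostream>
   #include <memory>

   int main() {
     // Build a small int32 array; the values themselves are arbitrary.
     arrow::Int32Builder builder;
     (void)builder.AppendValues({1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
     std::shared_ptr<arrow::Array> arr;
     (void)builder.Finish(&arr);

     // Slicing at position 3 yields offset() == 3. Since 3 % 8 != 0, a
     // byte-at-a-time writer like PackBits above cannot target this bitmap
     // directly and needs a bit-by-bit fallback (e.g. bit_util::SetBitTo).
     std::shared_ptr<arrow::Array> sliced = arr->Slice(3, 5);
     std::cout << "offset: " << sliced->offset() << std::endl;  // prints 3
     return 0;
   }
   ```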

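   As a side note on the `PackBits` helper in the hunk above, a self-contained sketch of the same LSB-first packing (matching Arrow's bitmap bit order); the `main` driver and its sample flags are illustrative, not from the PR:

   ```cpp
   #include <cstdint>
   #include <cstdio>

   // Same packing idea as the PR's PackBits<batch_size>: each group of 8
   // 0/1 ints becomes one byte, least-significant bit first.
   template <int batch_size>
   void PackBits(const int* values, uint8_t* out) {
     for (int i = 0; i < batch_size / 8; ++i) {
       *out++ = static_cast<uint8_t>(values[0] | values[1] << 1 | values[2] << 2 |
                                     values[3] << 3 | values[4] << 4 |
                                     values[5] << 5 | values[6] << 6 | values[7] << 7);
       values += 8;
     }
   }

   int main() {
     // Flags at positions 0, 2 and 7 are set.
     const int flags[8] = {1, 0, 1, 0, 0, 0, 0, 1};
     uint8_t byte = 0;
     PackBits<8>(flags, &byte);
     std::printf("0x%02X\n", static_cast<unsigned>(byte));  // prints 0x85
     return 0;
   }
   ```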

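   And for readers skimming `CompareData`, a hypothetical sketch of why it carries three function pointers; `DispatchCompare` is invented here for illustration (it relies on the `CompareData` and `BinaryKernel` definitions in the diff above) and is not the PR's actual selection logic:

   ```cpp
   #include <cstdint>

   // Hypothetical helper, not from the PR: routes to the array-array,
   // scalar-array, or array-scalar specialization depending on which side
   // is a scalar (the both-scalar case is omitted for brevity).
   void DispatchCompare(const CompareData& data, bool left_is_scalar,
                        bool right_is_scalar, const void* left, const void* right,
                        int64_t length, void* out_bitmap) {
     if (left_is_scalar) {
       data.func_sa(left, right, length, out_bitmap);  // scalar vs. array
     } else if (right_is_scalar) {
       data.func_as(left, right, length, out_bitmap);  // array vs. scalar
     } else {
       data.func_aa(left, right, length, out_bitmap);  // array vs. array
     }
   }
   ```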

-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
