lidavidm commented on a change in pull request #10557:
URL: https://github.com/apache/arrow/pull/10557#discussion_r661730002
##########
File path: cpp/src/arrow/compute/kernels/scalar_if_else.cc
##########
@@ -676,7 +677,351 @@ void AddPrimitiveIfElseKernels(const std::shared_ptr<ScalarFunction>& scalar_fun
}
}
-} // namespace
+// Helper to copy or broadcast fixed-width values between buffers.
+template <typename Type, typename Enable = void>
+struct CopyFixedWidth {};
+template <>
+struct CopyFixedWidth<BooleanType> {
+ static void CopyScalar(const Scalar& scalar, uint8_t* out_values, const
int64_t offset,
+ const int64_t length) {
+ const bool value = UnboxScalar<BooleanType>::Unbox(scalar);
+ BitUtil::SetBitsTo(out_values, offset, length, value);
+ }
+ static void CopyArray(const ArrayData& array, uint8_t* out_values, const
int64_t offset,
+ const int64_t length) {
+ arrow::internal::CopyBitmap(array.buffers[1]->data(), array.offset +
offset, length,
+ out_values, offset);
+ }
+};
+template <typename Type>
+struct CopyFixedWidth<Type, enable_if_number<Type>> {
+ using CType = typename TypeTraits<Type>::CType;
+ static void CopyScalar(const Scalar& values, uint8_t* raw_out_values,
+ const int64_t offset, const int64_t length) {
+ CType* out_values = reinterpret_cast<CType*>(raw_out_values);
+ const CType value = UnboxScalar<Type>::Unbox(values);
+ std::fill(out_values + offset, out_values + offset + length, value);
+ }
+ static void CopyArray(const ArrayData& array, uint8_t* raw_out_values,
+ const int64_t offset, const int64_t length) {
+ CType* out_values = reinterpret_cast<CType*>(raw_out_values);
+ const CType* in_values = array.GetValues<CType>(1);
+ std::copy(in_values + offset, in_values + offset + length, out_values +
offset);
+ }
+};
+template <typename Type>
+struct CopyFixedWidth<Type, enable_if_same<Type, FixedSizeBinaryType>> {
+ static void CopyScalar(const Scalar& values, uint8_t* out_values, const
int64_t offset,
+ const int64_t length) {
+ const int32_t width =
+ checked_cast<const FixedSizeBinaryType&>(*values.type).byte_width();
+ uint8_t* next = out_values + (width * offset);
+ const auto& scalar = checked_cast<const FixedSizeBinaryScalar&>(values);
+ // Scalar may have null value buffer
+ if (!scalar.value) return;
+ DCHECK_EQ(scalar.value->size(), width);
+ for (int i = 0; i < length; i++) {
+ std::memcpy(next, scalar.value->data(), width);
+ next += width;
+ }
+ }
+ static void CopyArray(const ArrayData& array, uint8_t* out_values, const
int64_t offset,
+ const int64_t length) {
+ const int32_t width =
+ checked_cast<const FixedSizeBinaryType&>(*array.type).byte_width();
+ uint8_t* next = out_values + (width * offset);
+ const auto* in_values = array.GetValues<uint8_t>(1, (offset +
array.offset) * width);
+ std::memcpy(next, in_values, length * width);
+ }
+};
+template <typename Type>
+struct CopyFixedWidth<Type, enable_if_decimal<Type>> {
+ using ScalarType = typename TypeTraits<Type>::ScalarType;
+ static void CopyScalar(const Scalar& values, uint8_t* out_values, const
int64_t offset,
+ const int64_t length) {
+ const int32_t width =
+ checked_cast<const FixedSizeBinaryType&>(*values.type).byte_width();
+ uint8_t* next = out_values + (width * offset);
+ const auto& scalar = checked_cast<const ScalarType&>(values);
+ const auto value = scalar.value.ToBytes();
+ for (int i = 0; i < length; i++) {
+ std::memcpy(next, value.data(), width);
+ next += width;
+ }
+ }
+ static void CopyArray(const ArrayData& array, uint8_t* out_values, const
int64_t offset,
+ const int64_t length) {
+ const int32_t width =
+ checked_cast<const FixedSizeBinaryType&>(*array.type).byte_width();
+ uint8_t* next = out_values + (width * offset);
+ const auto* in_values = array.GetValues<uint8_t>(1, (offset +
array.offset) * width);
+ std::memcpy(next, in_values, length * width);
+ }
+};
+// Copy fixed-width values from a scalar/array datum into an output values buffer
+template <typename Type>
+void CopyValues(const Datum& values, uint8_t* out_valid, uint8_t* out_values,
+ const int64_t offset, const int64_t length) {
+ using Copier = CopyFixedWidth<Type>;
+ if (values.is_scalar()) {
+ const auto& scalar = *values.scalar();
+ if (out_valid) {
+ BitUtil::SetBitsTo(out_valid, offset, length, scalar.is_valid);
+ }
+ Copier::CopyScalar(scalar, out_values, offset, length);
+ } else {
+ const ArrayData& array = *values.array();
+ if (out_valid) {
+ if (array.MayHaveNulls()) {
+ arrow::internal::CopyBitmap(array.buffers[0]->data(), array.offset +
offset,
+ length, out_valid, offset);
+ } else {
+ BitUtil::SetBitsTo(out_valid, offset, length, true);
+ }
+ }
+ Copier::CopyArray(array, out_values, offset, length);
+ }
+}
+
+struct CaseWhenFunction : ScalarFunction {
+ using ScalarFunction::ScalarFunction;
+
+ Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const override {
+ RETURN_NOT_OK(CheckArity(*values));
+ std::vector<ValueDescr> value_types;
+ for (size_t i = 0; i < values->size() - 1; i += 2) {
+ ValueDescr* cond = &(*values)[i];
+ if (cond->type->id() == Type::NA) {
+ cond->type = boolean();
+ }
+ if (cond->type->id() != Type::BOOL) {
+ return Status::Invalid("Condition arguments must be boolean, but
argument ", i,
+ " was ", cond->type->ToString());
+ }
+ value_types.push_back((*values)[i + 1]);
+ }
+ if (values->size() % 2 != 0) {
+ // Have an ELSE clause
+ value_types.push_back(values->back());
+ }
+ EnsureDictionaryDecoded(&value_types);
+ if (auto type = CommonNumeric(value_types)) {
+ ReplaceTypes(type, &value_types);
+ }
+
+ const DataType& common_values_type = *value_types.front().type;
+ auto next_type = value_types.cbegin();
+ for (size_t i = 0; i < values->size(); i += 2) {
+ if (!common_values_type.Equals(next_type->type)) {
+ return Status::Invalid("Value arguments must be of same type, but
argument ", i,
+ " was ", next_type->type->ToString(), "
(expected ",
+ common_values_type.ToString(), ")");
+ }
+ if (i == values->size() - 1) {
+ // ELSE
+ (*values)[i] = *next_type++;
+ } else {
+ (*values)[i + 1] = *next_type++;
+ }
+ }
+
+ // We register a unary kernel for each value type and dispatch to it after
validation.
+ if (auto kernel = DispatchExactImpl(this, {values->back()})) return kernel;
+ return arrow::compute::detail::NoMatchingKernel(this, *values);
+ }
+};
+
+// Implement a 'case when' (SQL)/'select' (NumPy) function for any scalar arguments
+Status ExecScalarCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ for (size_t i = 0; i < batch.values.size() - 1; i += 2) {
+ const Scalar& cond = *batch[i].scalar();
+ if (cond.is_valid && internal::UnboxScalar<BooleanType>::Unbox(cond)) {
+ *out = batch[i + 1];
+ return Status::OK();
+ }
+ }
+ if (batch.values.size() % 2 == 0) {
+ // No ELSE
+ *out = MakeNullScalar(batch[1].type());
+ } else {
+ *out = batch.values.back();
+ }
+ return Status::OK();
+}
+
+// Implement 'case when' for any mix of scalar/array arguments for any fixed-width type,
+// given helper functions to copy data from a source array to a target array and to
+// allocate a values buffer
+template <typename Type>
+Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ ArrayData* output = out->mutable_array();
+ const bool have_else_arg = batch.values.size() % 2 != 0;
+ // Check if we may need a validity bitmap
+ uint8_t* out_valid = nullptr;
+
+ bool need_valid_bitmap = false;
+ if (!have_else_arg) {
+ // If we don't have an else arg -> need a bitmap since we may emit nulls
+ need_valid_bitmap = true;
+ } else if (batch.values.back().null_count() > 0) {
+ // If the 'else' array has a null count we need a validity bitmap
+ need_valid_bitmap = true;
+ } else {
+ // Otherwise if any value array has a null count we need a validity bitmap
+ for (size_t i = 1; i < batch.values.size(); i += 2) {
+ if (batch[i].null_count() > 0) {
+ need_valid_bitmap = true;
+ break;
+ }
+ }
+ }
+ if (need_valid_bitmap) {
+ ARROW_ASSIGN_OR_RAISE(output->buffers[0],
ctx->AllocateBitmap(batch.length));
+ out_valid = output->buffers[0]->mutable_data();
+ }
+
+ // Initialize values buffer
+ uint8_t* out_values = output->buffers[1]->mutable_data();
+ if (have_else_arg) {
+ // Copy 'else' value into output
Review comment:
This seems to be faster as written, oddly.
Before:
```
-----------------------------------------------------------------------------------------------
Benchmark Time CPU
Iterations UserCounters...
-----------------------------------------------------------------------------------------------
CaseWhenBench32/1048576/0 31933112 ns 31932368 ns
22 bytes_per_second=504.974M/s
CaseWhenBench64/1048576/0 33170481 ns 33168735 ns
21 bytes_per_second=968.533M/s
CaseWhenBench32/1048576/99 32487300 ns 32487411 ns
21 bytes_per_second=496.299M/s
CaseWhenBench64/1048576/99 33682029 ns 33680901 ns
21 bytes_per_second=953.715M/s
CaseWhenBench32Contiguous/1048576/0 7255445 ns 7255387 ns
96 bytes_per_second=1.632G/s
CaseWhenBench64Contiguous/1048576/0 7932437 ns 7932171 ns
88 bytes_per_second=2.97013G/s
CaseWhenBench32Contiguous/1048576/99 7526742 ns 7526709 ns
92 bytes_per_second=1.57303G/s
CaseWhenBench64Contiguous/1048576/99 8172498 ns 8172239 ns
83 bytes_per_second=2.88261G/s
```
After:
```
-----------------------------------------------------------------------------------------------
Benchmark Time CPU
Iterations UserCounters...
-----------------------------------------------------------------------------------------------
CaseWhenBench32/1048576/0 44166172 ns 44165634 ns
16 bytes_per_second=365.103M/s
CaseWhenBench64/1048576/0 44605356 ns 44603995 ns
16 bytes_per_second=720.227M/s
CaseWhenBench32/1048576/99 44867670 ns 44867051 ns
16 bytes_per_second=359.361M/s
CaseWhenBench64/1048576/99 45077818 ns 45076721 ns
15 bytes_per_second=712.607M/s
CaseWhenBench32Contiguous/1048576/0 17757494 ns 17757271 ns
39 bytes_per_second=682.819M/s
CaseWhenBench64Contiguous/1048576/0 18236327 ns 18235892 ns
38 bytes_per_second=1.29193G/s
CaseWhenBench32Contiguous/1048576/99 15051281 ns 15051008 ns
39 bytes_per_second=805.518M/s
CaseWhenBench64Contiguous/1048576/99 15504081 ns 15503998 ns
45 bytes_per_second=1.51944G/s
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]