lidavidm commented on a change in pull request #10557: URL: https://github.com/apache/arrow/pull/10557#discussion_r661730002
##########
File path: cpp/src/arrow/compute/kernels/scalar_if_else.cc
##########
@@ -676,7 +677,351 @@ void AddPrimitiveIfElseKernels(const std::shared_ptr<ScalarFunction>& scalar_fun
   }
 }
 
-}  // namespace
+// Helper to copy or broadcast fixed-width values between buffers.
+template <typename Type, typename Enable = void>
+struct CopyFixedWidth {};
+template <>
+struct CopyFixedWidth<BooleanType> {
+  static void CopyScalar(const Scalar& scalar, uint8_t* out_values, const int64_t offset,
+                         const int64_t length) {
+    const bool value = UnboxScalar<BooleanType>::Unbox(scalar);
+    BitUtil::SetBitsTo(out_values, offset, length, value);
+  }
+  static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset,
+                        const int64_t length) {
+    arrow::internal::CopyBitmap(array.buffers[1]->data(), array.offset + offset, length,
+                                out_values, offset);
+  }
+};
+template <typename Type>
+struct CopyFixedWidth<Type, enable_if_number<Type>> {
+  using CType = typename TypeTraits<Type>::CType;
+  static void CopyScalar(const Scalar& values, uint8_t* raw_out_values,
+                         const int64_t offset, const int64_t length) {
+    CType* out_values = reinterpret_cast<CType*>(raw_out_values);
+    const CType value = UnboxScalar<Type>::Unbox(values);
+    std::fill(out_values + offset, out_values + offset + length, value);
+  }
+  static void CopyArray(const ArrayData& array, uint8_t* raw_out_values,
+                        const int64_t offset, const int64_t length) {
+    CType* out_values = reinterpret_cast<CType*>(raw_out_values);
+    const CType* in_values = array.GetValues<CType>(1);
+    std::copy(in_values + offset, in_values + offset + length, out_values + offset);
+  }
+};
+template <typename Type>
+struct CopyFixedWidth<Type, enable_if_same<Type, FixedSizeBinaryType>> {
+  static void CopyScalar(const Scalar& values, uint8_t* out_values, const int64_t offset,
+                         const int64_t length) {
+    const int32_t width =
+        checked_cast<const FixedSizeBinaryType&>(*values.type).byte_width();
+    uint8_t* next = out_values + (width * offset);
+    const auto& scalar = checked_cast<const FixedSizeBinaryScalar&>(values);
+    // Scalar may have null value buffer
+    if (!scalar.value) return;
+    DCHECK_EQ(scalar.value->size(), width);
+    for (int i = 0; i < length; i++) {
+      std::memcpy(next, scalar.value->data(), width);
+      next += width;
+    }
+  }
+  static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset,
+                        const int64_t length) {
+    const int32_t width =
+        checked_cast<const FixedSizeBinaryType&>(*array.type).byte_width();
+    uint8_t* next = out_values + (width * offset);
+    const auto* in_values = array.GetValues<uint8_t>(1, (offset + array.offset) * width);
+    std::memcpy(next, in_values, length * width);
+  }
+};
+template <typename Type>
+struct CopyFixedWidth<Type, enable_if_decimal<Type>> {
+  using ScalarType = typename TypeTraits<Type>::ScalarType;
+  static void CopyScalar(const Scalar& values, uint8_t* out_values, const int64_t offset,
+                         const int64_t length) {
+    const int32_t width =
+        checked_cast<const FixedSizeBinaryType&>(*values.type).byte_width();
+    uint8_t* next = out_values + (width * offset);
+    const auto& scalar = checked_cast<const ScalarType&>(values);
+    const auto value = scalar.value.ToBytes();
+    for (int i = 0; i < length; i++) {
+      std::memcpy(next, value.data(), width);
+      next += width;
+    }
+  }
+  static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset,
+                        const int64_t length) {
+    const int32_t width =
+        checked_cast<const FixedSizeBinaryType&>(*array.type).byte_width();
+    uint8_t* next = out_values + (width * offset);
+    const auto* in_values = array.GetValues<uint8_t>(1, (offset + array.offset) * width);
+    std::memcpy(next, in_values, length * width);
+  }
+};
+// Copy fixed-width values from a scalar/array datum into an output values buffer
+template <typename Type>
+void CopyValues(const Datum& values, uint8_t* out_valid, uint8_t* out_values,
+                const int64_t offset, const int64_t length) {
+  using Copier = CopyFixedWidth<Type>;
+  if (values.is_scalar()) {
+    const auto& scalar = *values.scalar();
+    if (out_valid) {
+      BitUtil::SetBitsTo(out_valid, offset, length, scalar.is_valid);
+    }
+    Copier::CopyScalar(scalar, out_values, offset, length);
+  } else {
+    const ArrayData& array = *values.array();
+    if (out_valid) {
+      if (array.MayHaveNulls()) {
+        arrow::internal::CopyBitmap(array.buffers[0]->data(), array.offset + offset,
+                                    length, out_valid, offset);
+      } else {
+        BitUtil::SetBitsTo(out_valid, offset, length, true);
+      }
+    }
+    Copier::CopyArray(array, out_values, offset, length);
+  }
+}
+
+struct CaseWhenFunction : ScalarFunction {
+  using ScalarFunction::ScalarFunction;
+
+  Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const override {
+    RETURN_NOT_OK(CheckArity(*values));
+    std::vector<ValueDescr> value_types;
+    for (size_t i = 0; i < values->size() - 1; i += 2) {
+      ValueDescr* cond = &(*values)[i];
+      if (cond->type->id() == Type::NA) {
+        cond->type = boolean();
+      }
+      if (cond->type->id() != Type::BOOL) {
+        return Status::Invalid("Condition arguments must be boolean, but argument ", i,
+                               " was ", cond->type->ToString());
+      }
+      value_types.push_back((*values)[i + 1]);
+    }
+    if (values->size() % 2 != 0) {
+      // Have an ELSE clause
+      value_types.push_back(values->back());
+    }
+    EnsureDictionaryDecoded(&value_types);
+    if (auto type = CommonNumeric(value_types)) {
+      ReplaceTypes(type, &value_types);
+    }
+
+    const DataType& common_values_type = *value_types.front().type;
+    auto next_type = value_types.cbegin();
+    for (size_t i = 0; i < values->size(); i += 2) {
+      if (!common_values_type.Equals(next_type->type)) {
+        return Status::Invalid("Value arguments must be of same type, but argument ", i,
+                               " was ", next_type->type->ToString(), " (expected ",
+                               common_values_type.ToString(), ")");
+      }
+      if (i == values->size() - 1) {
+        // ELSE
+        (*values)[i] = *next_type++;
+      } else {
+        (*values)[i + 1] = *next_type++;
+      }
+    }
+
+    // We register a unary kernel for each value type and dispatch to it after validation.
+    if (auto kernel = DispatchExactImpl(this, {values->back()})) return kernel;
+    return arrow::compute::detail::NoMatchingKernel(this, *values);
+  }
+};
+
+// Implement a 'case when' (SQL)/'select' (NumPy) function for any scalar arguments
+Status ExecScalarCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+  for (size_t i = 0; i < batch.values.size() - 1; i += 2) {
+    const Scalar& cond = *batch[i].scalar();
+    if (cond.is_valid && internal::UnboxScalar<BooleanType>::Unbox(cond)) {
+      *out = batch[i + 1];
+      return Status::OK();
+    }
+  }
+  if (batch.values.size() % 2 == 0) {
+    // No ELSE
+    *out = MakeNullScalar(batch[1].type());
+  } else {
+    *out = batch.values.back();
+  }
+  return Status::OK();
+}
+
+// Implement 'case when' for any mix of scalar/array arguments for any fixed-width type,
+// given helper functions to copy data from a source array to a target array and to
+// allocate a values buffer
+template <typename Type>
+Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+  ArrayData* output = out->mutable_array();
+  const bool have_else_arg = batch.values.size() % 2 != 0;
+  // Check if we may need a validity bitmap
+  uint8_t* out_valid = nullptr;
+
+  bool need_valid_bitmap = false;
+  if (!have_else_arg) {
+    // If we don't have an else arg -> need a bitmap since we may emit nulls
+    need_valid_bitmap = true;
+  } else if (batch.values.back().null_count() > 0) {
+    // If the 'else' array has a null count we need a validity bitmap
+    need_valid_bitmap = true;
+  } else {
+    // Otherwise if any value array has a null count we need a validity bitmap
+    for (size_t i = 1; i < batch.values.size(); i += 2) {
+      if (batch[i].null_count() > 0) {
+        need_valid_bitmap = true;
+        break;
+      }
+    }
+  }
+  if (need_valid_bitmap) {
+    ARROW_ASSIGN_OR_RAISE(output->buffers[0], ctx->AllocateBitmap(batch.length));
+    out_valid = output->buffers[0]->mutable_data();
+  }
+
+  // Initialize values buffer
+  uint8_t* out_values = output->buffers[1]->mutable_data();
+  if (have_else_arg) {
+    // Copy 'else' value into output

Review comment:
   This seems to be faster as written, oddly.

   Before:
   ```
   -----------------------------------------------------------------------------------------------
   Benchmark                                     Time             CPU   Iterations UserCounters...
   -----------------------------------------------------------------------------------------------
   CaseWhenBench32/1048576/0              31933112 ns     31932368 ns           22 bytes_per_second=504.974M/s
   CaseWhenBench64/1048576/0              33170481 ns     33168735 ns           21 bytes_per_second=968.533M/s
   CaseWhenBench32/1048576/99             32487300 ns     32487411 ns           21 bytes_per_second=496.299M/s
   CaseWhenBench64/1048576/99             33682029 ns     33680901 ns           21 bytes_per_second=953.715M/s
   CaseWhenBench32Contiguous/1048576/0     7255445 ns      7255387 ns           96 bytes_per_second=1.632G/s
   CaseWhenBench64Contiguous/1048576/0     7932437 ns      7932171 ns           88 bytes_per_second=2.97013G/s
   CaseWhenBench32Contiguous/1048576/99    7526742 ns      7526709 ns           92 bytes_per_second=1.57303G/s
   CaseWhenBench64Contiguous/1048576/99    8172498 ns      8172239 ns           83 bytes_per_second=2.88261G/s
   ```

   After:
   ```
   -----------------------------------------------------------------------------------------------
   Benchmark                                     Time             CPU   Iterations UserCounters...
   -----------------------------------------------------------------------------------------------
   CaseWhenBench32/1048576/0              44166172 ns     44165634 ns           16 bytes_per_second=365.103M/s
   CaseWhenBench64/1048576/0              44605356 ns     44603995 ns           16 bytes_per_second=720.227M/s
   CaseWhenBench32/1048576/99             44867670 ns     44867051 ns           16 bytes_per_second=359.361M/s
   CaseWhenBench64/1048576/99             45077818 ns     45076721 ns           15 bytes_per_second=712.607M/s
   CaseWhenBench32Contiguous/1048576/0    17757494 ns     17757271 ns           39 bytes_per_second=682.819M/s
   CaseWhenBench64Contiguous/1048576/0    18236327 ns     18235892 ns           38 bytes_per_second=1.29193G/s
   CaseWhenBench32Contiguous/1048576/99   15051281 ns     15051008 ns           39 bytes_per_second=805.518M/s
   CaseWhenBench64Contiguous/1048576/99   15504081 ns     15503998 ns           45 bytes_per_second=1.51944G/s
   ```

-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org