lidavidm commented on a change in pull request #10557: URL: https://github.com/apache/arrow/pull/10557#discussion_r664796056
########## File path: cpp/src/arrow/compute/kernels/scalar_if_else.cc ########## @@ -676,7 +677,351 @@ void AddPrimitiveIfElseKernels(const std::shared_ptr<ScalarFunction>& scalar_fun } } -} // namespace +// Helper to copy or broadcast fixed-width values between buffers. +template <typename Type, typename Enable = void> +struct CopyFixedWidth {}; +template <> +struct CopyFixedWidth<BooleanType> { + static void CopyScalar(const Scalar& scalar, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const bool value = UnboxScalar<BooleanType>::Unbox(scalar); + BitUtil::SetBitsTo(out_values, offset, length, value); + } + static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset, + const int64_t length) { + arrow::internal::CopyBitmap(array.buffers[1]->data(), array.offset + offset, length, + out_values, offset); + } +}; +template <typename Type> +struct CopyFixedWidth<Type, enable_if_number<Type>> { + using CType = typename TypeTraits<Type>::CType; + static void CopyScalar(const Scalar& values, uint8_t* raw_out_values, + const int64_t offset, const int64_t length) { + CType* out_values = reinterpret_cast<CType*>(raw_out_values); + const CType value = UnboxScalar<Type>::Unbox(values); + std::fill(out_values + offset, out_values + offset + length, value); + } + static void CopyArray(const ArrayData& array, uint8_t* raw_out_values, + const int64_t offset, const int64_t length) { + CType* out_values = reinterpret_cast<CType*>(raw_out_values); + const CType* in_values = array.GetValues<CType>(1); + std::copy(in_values + offset, in_values + offset + length, out_values + offset); + } +}; +template <typename Type> +struct CopyFixedWidth<Type, enable_if_same<Type, FixedSizeBinaryType>> { + static void CopyScalar(const Scalar& values, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const int32_t width = + checked_cast<const FixedSizeBinaryType&>(*values.type).byte_width(); + uint8_t* next = out_values + (width 
* offset); + const auto& scalar = checked_cast<const FixedSizeBinaryScalar&>(values); + // Scalar may have null value buffer + if (!scalar.value) return; + DCHECK_EQ(scalar.value->size(), width); + for (int i = 0; i < length; i++) { + std::memcpy(next, scalar.value->data(), width); + next += width; + } + } + static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const int32_t width = + checked_cast<const FixedSizeBinaryType&>(*array.type).byte_width(); + uint8_t* next = out_values + (width * offset); + const auto* in_values = array.GetValues<uint8_t>(1, (offset + array.offset) * width); + std::memcpy(next, in_values, length * width); + } +}; +template <typename Type> +struct CopyFixedWidth<Type, enable_if_decimal<Type>> { + using ScalarType = typename TypeTraits<Type>::ScalarType; + static void CopyScalar(const Scalar& values, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const int32_t width = + checked_cast<const FixedSizeBinaryType&>(*values.type).byte_width(); + uint8_t* next = out_values + (width * offset); + const auto& scalar = checked_cast<const ScalarType&>(values); + const auto value = scalar.value.ToBytes(); + for (int i = 0; i < length; i++) { + std::memcpy(next, value.data(), width); + next += width; + } + } + static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const int32_t width = + checked_cast<const FixedSizeBinaryType&>(*array.type).byte_width(); + uint8_t* next = out_values + (width * offset); + const auto* in_values = array.GetValues<uint8_t>(1, (offset + array.offset) * width); + std::memcpy(next, in_values, length * width); + } +}; +// Copy fixed-width values from a scalar/array datum into an output values buffer +template <typename Type> +void CopyValues(const Datum& values, uint8_t* out_valid, uint8_t* out_values, + const int64_t offset, const int64_t length) { + using Copier = 
CopyFixedWidth<Type>; + if (values.is_scalar()) { + const auto& scalar = *values.scalar(); + if (out_valid) { + BitUtil::SetBitsTo(out_valid, offset, length, scalar.is_valid); + } + Copier::CopyScalar(scalar, out_values, offset, length); + } else { + const ArrayData& array = *values.array(); + if (out_valid) { + if (array.MayHaveNulls()) { + arrow::internal::CopyBitmap(array.buffers[0]->data(), array.offset + offset, + length, out_valid, offset); + } else { + BitUtil::SetBitsTo(out_valid, offset, length, true); + } + } + Copier::CopyArray(array, out_values, offset, length); + } +} + +struct CaseWhenFunction : ScalarFunction { + using ScalarFunction::ScalarFunction; + + Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const override { + RETURN_NOT_OK(CheckArity(*values)); + std::vector<ValueDescr> value_types; + for (size_t i = 0; i < values->size() - 1; i += 2) { + ValueDescr* cond = &(*values)[i]; + if (cond->type->id() == Type::NA) { + cond->type = boolean(); + } + if (cond->type->id() != Type::BOOL) { + return Status::Invalid("Condition arguments must be boolean, but argument ", i, + " was ", cond->type->ToString()); + } + value_types.push_back((*values)[i + 1]); + } + if (values->size() % 2 != 0) { + // Have an ELSE clause + value_types.push_back(values->back()); + } + EnsureDictionaryDecoded(&value_types); + if (auto type = CommonNumeric(value_types)) { + ReplaceTypes(type, &value_types); + } + + const DataType& common_values_type = *value_types.front().type; + auto next_type = value_types.cbegin(); + for (size_t i = 0; i < values->size(); i += 2) { + if (!common_values_type.Equals(next_type->type)) { Review comment: A quick check with choose shows that it's only about 2/3 as fast: ``` CaseWhenBench32/1048576/0 48183534 ns 48183323 ns 13 bytes_per_second=334.659M/s CaseWhenBench64/1048576/0 48610718 ns 48610432 ns 13 bytes_per_second=660.866M/s CaseWhenBench32/1048576/99 48226819 ns 48226677 ns 15 bytes_per_second=334.327M/s 
CaseWhenBench64/1048576/99 48891957 ns 48892172 ns 13 bytes_per_second=656.996M/s ``` Note, though, that we have to call first_true_in and then fill_null to get the 'else' behavior, which doesn't help. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org