lidavidm commented on a change in pull request #11218:
URL: https://github.com/apache/arrow/pull/11218#discussion_r719607062
##########
File path: cpp/src/arrow/compute/kernels/scalar_if_else.cc
##########
@@ -900,44 +896,182 @@ struct IfElseFunctor<Type, enable_if_fixed_size_binary<Type>> {
     auto* out_values = out->buffers[1]->mutable_data() + out->offset * byte_width;
 
     // copy right data to out_buff
-    const util::string_view& right_data =
-        internal::UnboxScalar<FixedSizeBinaryType>::Unbox(right);
-    if (right_data.data()) {
+    const uint8_t* right_data = UnboxBinaryScalar(right);
+    if (right_data) {
       for (int64_t i = 0; i < cond.length; i++) {
-        std::memcpy(out_values + i * byte_width, right_data.data(), right_data.size());
+        std::memcpy(out_values + i * byte_width, right_data, byte_width);
       }
     }
 
     // selectively copy values from left data
-    const util::string_view& left_data =
-        internal::UnboxScalar<FixedSizeBinaryType>::Unbox(left);
-
+    const uint8_t* left_data = UnboxBinaryScalar(left);
     RunIfElseLoop(cond, [&](int64_t data_offset, int64_t num_elems) {
-      if (left_data.data()) {
+      if (left_data) {
         for (int64_t i = 0; i < num_elems; i++) {
-          std::memcpy(out_values + (data_offset + i) * byte_width, left_data.data(),
-                      left_data.size());
+          std::memcpy(out_values + (data_offset + i) * byte_width, left_data, byte_width);
         }
       }
     });
 
     return Status::OK();
   }
 
-  static Result<int32_t> GetByteWidth(const DataType& left_type,
-                                      const DataType& right_type) {
-    int width = checked_cast<const FixedSizeBinaryType&>(left_type).byte_width();
-    if (width == checked_cast<const FixedSizeBinaryType&>(right_type).byte_width()) {
-      return width;
+  template <typename T = Type>
+  static enable_if_t<!is_decimal_type<T>::value, const uint8_t*> UnboxBinaryScalar(
+      const Scalar& scalar) {
+    return reinterpret_cast<const uint8_t*>(
+        internal::UnboxScalar<FixedSizeBinaryType>::Unbox(scalar).data());
+  }
+
+  template <typename T = Type>
+  static enable_if_decimal<T, const uint8_t*> UnboxBinaryScalar(const Scalar& scalar) {
+    return internal::UnboxScalar<T>::Unbox(scalar).native_endian_bytes();
+  }
+
+  template <typename T = Type>
+  static enable_if_t<!is_decimal_type<T>::value, Result<int32_t>> GetByteWidth(
+      const DataType& left_type, const DataType& right_type) {
+    const int32_t width =
+        checked_cast<const FixedSizeBinaryType&>(left_type).byte_width();
+    DCHECK_EQ(width, checked_cast<const FixedSizeBinaryType&>(right_type).byte_width());
+    return width;
+  }
+
+  template <typename T = Type>
+  static enable_if_decimal<T, Result<int32_t>> GetByteWidth(const DataType& left_type,
+                                                            const DataType& right_type) {
+    const auto& left = checked_cast<const T&>(left_type);
+    const auto& right = checked_cast<const T&>(right_type);
+    DCHECK_EQ(left.precision(), right.precision());
+    DCHECK_EQ(left.scale(), right.scale());
+    return left.byte_width();
+  }
+};
+
+// Use builders for dictionaries - slower, but allows us to unify dictionaries
+template <typename Type>
+struct IfElseFunctor<
+    Type, enable_if_t<is_nested_type<Type>::value || is_dictionary_type<Type>::value>> {
+  // A - Array, S - Scalar, X = Array/Scalar
+
+  // SXX
+  static Status Call(KernelContext* ctx, const BooleanScalar& cond, const Datum& left,
+                     const Datum& right, Datum* out) {
+    if (left.is_scalar() && right.is_scalar()) {
+      if (cond.is_valid) {
+        *out = cond.value ? left.scalar() : right.scalar();
+      } else {
+        *out = MakeNullScalar(left.type());
+      }
+      return Status::OK();
+    }
+    // either left or right is an array. Output is always an array
+    int64_t out_arr_len = std::max(left.length(), right.length());
+    if (!cond.is_valid) {
+      // cond is null; just create a null array
+      ARROW_ASSIGN_OR_RAISE(*out,
+                            MakeArrayOfNull(left.type(), out_arr_len, ctx->memory_pool()))
+      return Status::OK();
+    }
+
+    const auto& valid_data = cond.value ? left : right;
+    if (valid_data.is_array()) {
+      *out = valid_data;
     } else {
-      return Status::Invalid("FixedSizeBinaryType byte_widths should be equal");
+      // valid data is a scalar that needs to be broadcasted
+      ARROW_ASSIGN_OR_RAISE(*out, MakeArrayFromScalar(*valid_data.scalar(), out_arr_len,
+                                                      ctx->memory_pool()));
     }
+    return Status::OK();
+  }
+
+  // AAA
+  static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left,
+                     const ArrayData& right, ArrayData* out) {
+    return RunLoop(

Review comment:
I updated this to use BitRunReader. I also made this kernel non-templated, since it didn't depend on the type in the first place, and made it skip PromoteNullsVisitor, since the builder computes the output bitmap (and since unions and dictionaries with nulls would make PromoteNullsVisitor compute an invalid bitmap anyway).

-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org
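
For context on the BitRunReader mentioned in the review comment: arrow::internal::BitRunReader (from arrow/util/bit_run_reader.h) yields consecutive runs of identical bits from a bitmap, so a kernel can handle a whole block of "condition is true" or "condition is false" slots at once instead of branching per element. The following is a minimal sketch of that pattern under those assumptions, not the PR's actual code; VisitCondRuns and handle_block are illustrative names.

    // Sketch: walk the condition bitmap as runs of identical bits and hand each
    // run to a callback that copies values from either the left or right input.
    #include <cstdint>
    #include <functional>

    #include "arrow/util/bit_run_reader.h"

    // handle_block(offset, length, take_left) is a hypothetical callback; in a
    // real kernel it would append `length` values from the chosen input.
    void VisitCondRuns(const uint8_t* cond_bitmap, int64_t cond_offset, int64_t length,
                       const std::function<void(int64_t, int64_t, bool)>& handle_block) {
      arrow::internal::BitRunReader reader(cond_bitmap, cond_offset, length);
      int64_t position = 0;
      while (position < length) {
        arrow::internal::BitRun run = reader.NextRun();
        // run.set == true: the condition holds for run.length slots, so those
        // slots take the left value; otherwise they take the right value.
        handle_block(position, run.length, run.set);
        position += run.length;
      }
    }

Processing runs rather than individual bits keeps the per-element branch out of the hot loop and lets each block be appended to the builder (or copied) in one go, which is why it pairs naturally with the builder-based nested/dictionary kernel in the diff above.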