This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/main by this push:
new 57447434d1 perf: add optimized zip implementation for scalars (#8653)
57447434d1 is described below
commit 57447434d1921701456543f9dfb92741e5d86734
Author: Raz Luvaton <[email protected]>
AuthorDate: Tue Oct 28 13:39:08 2025 +0200
perf: add optimized zip implementation for scalars (#8653)
Waiting for the PRs below to be merged first:
- [x] https://github.com/apache/arrow-rs/pull/8654 - zip benchmarks
**This PR includes the following other PRs (unless they are already merged)**
to make the review easier, so please make sure to review them first
- [x] https://github.com/apache/arrow-rs/pull/8658 - extracted from this
- [x] https://github.com/apache/arrow-rs/pull/8656 - extracted from this
# Which issue does this PR close?
N/A
# Rationale for this change
Make `zip` really fast for scalars.
This is useful for `IF <expr> THEN <literal> ELSE <literal> END` expressions.
# What changes are included in this PR?
Created a couple of implementations for zipping scalars: one for primitive
types, one for byte arrays, and a generic fallback
# Are these changes tested?
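Existing tests, plus new tests added in this PR (including decimal and timestamp cases that check the data type is kept)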
# Are there any user-facing changes?
new struct `ScalarZipper`
TODO:
- [x] Need to add comments if missing
- [x] Add tests for decimal and timestamp to make sure the type is kept
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
arrow-select/src/zip.rs | 846 +++++++++++++++++++++++++++++++++++++++++++++++-
1 file changed, 834 insertions(+), 12 deletions(-)
diff --git a/arrow-select/src/zip.rs b/arrow-select/src/zip.rs
index c202be6b62..e45b817dc6 100644
--- a/arrow-select/src/zip.rs
+++ b/arrow-select/src/zip.rs
@@ -18,10 +18,21 @@
//! [`zip`]: Combine values from two arrays based on boolean mask
use crate::filter::{SlicesIterator, prep_null_mask_filter};
+use arrow_array::cast::AsArray;
+use arrow_array::types::{BinaryType, ByteArrayType, LargeBinaryType,
LargeUtf8Type, Utf8Type};
use arrow_array::*;
-use arrow_buffer::BooleanBuffer;
+use arrow_buffer::{
+ BooleanBuffer, Buffer, MutableBuffer, NullBuffer, OffsetBuffer,
OffsetBufferBuilder,
+ ScalarBuffer,
+};
+use arrow_data::ArrayData;
use arrow_data::transform::MutableArrayData;
-use arrow_schema::ArrowError;
+use arrow_schema::{ArrowError, DataType};
+use std::fmt::{Debug, Formatter};
+use std::hash::Hash;
+use std::marker::PhantomData;
+use std::ops::Not;
+use std::sync::Arc;
/// Zip two arrays by some boolean mask.
///
@@ -87,8 +98,16 @@ pub fn zip(
truthy: &dyn Datum,
falsy: &dyn Datum,
) -> Result<ArrayRef, ArrowError> {
- let (truthy, truthy_is_scalar) = truthy.get();
- let (falsy, falsy_is_scalar) = falsy.get();
+ let (truthy_array, truthy_is_scalar) = truthy.get();
+ let (falsy_array, falsy_is_scalar) = falsy.get();
+
+ if falsy_is_scalar && truthy_is_scalar {
+ let zipper = ScalarZipper::try_new(truthy, falsy)?;
+ return zipper.zip_impl.create_output(mask);
+ }
+
+ let truthy = truthy_array;
+ let falsy = falsy_array;
if truthy.data_type() != falsy.data_type() {
return Err(ArrowError::InvalidArgumentError(
@@ -120,7 +139,17 @@ pub fn zip(
let falsy = falsy.to_data();
let truthy = truthy.to_data();
- let mut mutable = MutableArrayData::new(vec![&truthy, &falsy], false,
truthy.len());
+ zip_impl(mask, &truthy, truthy_is_scalar, &falsy, falsy_is_scalar)
+}
+
+fn zip_impl(
+ mask: &BooleanArray,
+ truthy: &ArrayData,
+ truthy_is_scalar: bool,
+ falsy: &ArrayData,
+ falsy_is_scalar: bool,
+) -> Result<ArrayRef, ArrowError> {
+ let mut mutable = MutableArrayData::new(vec![truthy, falsy], false,
truthy.len());
// the SlicesIterator slices only the true values. So the gaps left by
this iterator we need to
// fill with falsy values
@@ -128,8 +157,8 @@ pub fn zip(
// keep track of how much is filled
let mut filled = 0;
- let mask = maybe_prep_null_mask_filter(mask);
- SlicesIterator::from(&mask).for_each(|(start, end)| {
+ let mask_buffer = maybe_prep_null_mask_filter(mask);
+ SlicesIterator::from(&mask_buffer).for_each(|(start, end)| {
// the gap needs to be filled with falsy values
if start > filled {
if falsy_is_scalar {
@@ -168,6 +197,455 @@ pub fn zip(
Ok(make_array(data))
}
+/// Zipper for 2 scalars
+///
+/// Useful for using in `IF <expr> THEN <scalar> ELSE <scalar> END` expressions
+///
+/// # Example
+/// ```
+/// # use std::sync::Arc;
+/// # use arrow_array::{ArrayRef, BooleanArray, Int32Array, Scalar,
cast::AsArray, types::Int32Type};
+///
+/// # use arrow_select::zip::ScalarZipper;
+/// let scalar_truthy = Scalar::new(Int32Array::from_value(42, 1));
+/// let scalar_falsy = Scalar::new(Int32Array::from_value(123, 1));
+/// let zipper = ScalarZipper::try_new(&scalar_truthy, &scalar_falsy).unwrap();
+///
+/// // Later when we have a boolean mask
+/// let mask = BooleanArray::from(vec![true, false, true, false, true]);
+/// let result = zipper.zip(&mask).unwrap();
+/// let actual = result.as_primitive::<Int32Type>();
+/// let expected = Int32Array::from(vec![Some(42), Some(123), Some(42),
Some(123), Some(42)]);
+/// ```
+///
+#[derive(Debug, Clone)]
+pub struct ScalarZipper {
+ zip_impl: Arc<dyn ZipImpl>,
+}
+
+impl ScalarZipper {
+ /// Try to create a new ScalarZipper from two scalar Datum
+ ///
+ /// # Errors
+ /// returns error if:
+ /// - the two Datum have different data types
+ /// - either Datum is not a scalar (or has more than 1 element)
+ ///
+ pub fn try_new(truthy: &dyn Datum, falsy: &dyn Datum) -> Result<Self,
ArrowError> {
+ let (truthy, truthy_is_scalar) = truthy.get();
+ let (falsy, falsy_is_scalar) = falsy.get();
+
+ if truthy.data_type() != falsy.data_type() {
+ return Err(ArrowError::InvalidArgumentError(
+ "arguments need to have the same data type".into(),
+ ));
+ }
+
+ if !truthy_is_scalar {
+ return Err(ArrowError::InvalidArgumentError(
+ "only scalar arrays are supported".into(),
+ ));
+ }
+
+ if !falsy_is_scalar {
+ return Err(ArrowError::InvalidArgumentError(
+ "only scalar arrays are supported".into(),
+ ));
+ }
+
+ if truthy.len() != 1 {
+ return Err(ArrowError::InvalidArgumentError(
+ "scalar arrays must have 1 element".into(),
+ ));
+ }
+ if falsy.len() != 1 {
+ return Err(ArrowError::InvalidArgumentError(
+ "scalar arrays must have 1 element".into(),
+ ));
+ }
+
+ macro_rules! primitive_size_helper {
+ ($t:ty) => {
+ Arc::new(PrimitiveScalarImpl::<$t>::new(truthy, falsy)) as
Arc<dyn ZipImpl>
+ };
+ }
+
+ let zip_impl = downcast_primitive! {
+ truthy.data_type() => (primitive_size_helper),
+ DataType::Utf8 => {
+ Arc::new(BytesScalarImpl::<Utf8Type>::new(truthy, falsy)) as
Arc<dyn ZipImpl>
+ },
+ DataType::LargeUtf8 => {
+ Arc::new(BytesScalarImpl::<LargeUtf8Type>::new(truthy, falsy))
as Arc<dyn ZipImpl>
+ },
+ DataType::Binary => {
+ Arc::new(BytesScalarImpl::<BinaryType>::new(truthy, falsy)) as
Arc<dyn ZipImpl>
+ },
+ DataType::LargeBinary => {
+ Arc::new(BytesScalarImpl::<LargeBinaryType>::new(truthy,
falsy)) as Arc<dyn ZipImpl>
+ },
+ // TODO: Handle Utf8View
https://github.com/apache/arrow-rs/issues/8724
+ _ => {
+ Arc::new(FallbackImpl::new(truthy, falsy)) as Arc<dyn ZipImpl>
+ },
+ };
+
+ Ok(Self { zip_impl })
+ }
+
+ /// Creating output array based on input boolean array and the two scalar
values the zipper was created with
+ /// See struct level documentation for examples.
+ pub fn zip(&self, mask: &BooleanArray) -> Result<ArrayRef, ArrowError> {
+ self.zip_impl.create_output(mask)
+ }
+}
+
+/// Impl for creating output array based on a mask
+trait ZipImpl: Debug + Send + Sync {
+ /// Creating output array based on input boolean array
+ fn create_output(&self, input: &BooleanArray) -> Result<ArrayRef,
ArrowError>;
+}
+
+#[derive(Debug, PartialEq)]
+struct FallbackImpl {
+ truthy: ArrayData,
+ falsy: ArrayData,
+}
+
+impl FallbackImpl {
+ fn new(left: &dyn Array, right: &dyn Array) -> Self {
+ Self {
+ truthy: left.to_data(),
+ falsy: right.to_data(),
+ }
+ }
+}
+
+impl ZipImpl for FallbackImpl {
+ fn create_output(&self, predicate: &BooleanArray) -> Result<ArrayRef,
ArrowError> {
+ zip_impl(predicate, &self.truthy, true, &self.falsy, true)
+ }
+}
+
+struct PrimitiveScalarImpl<T: ArrowPrimitiveType> {
+ data_type: DataType,
+ truthy: Option<T::Native>,
+ falsy: Option<T::Native>,
+}
+
+impl<T: ArrowPrimitiveType> Debug for PrimitiveScalarImpl<T> {
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("PrimitiveScalarImpl")
+ .field("data_type", &self.data_type)
+ .field("truthy", &self.truthy)
+ .field("falsy", &self.falsy)
+ .finish()
+ }
+}
+
+impl<T: ArrowPrimitiveType> PrimitiveScalarImpl<T> {
+ fn new(truthy: &dyn Array, falsy: &dyn Array) -> Self {
+ Self {
+ data_type: truthy.data_type().clone(),
+ truthy: Self::get_value_from_scalar(truthy),
+ falsy: Self::get_value_from_scalar(falsy),
+ }
+ }
+
+ fn get_value_from_scalar(scalar: &dyn Array) -> Option<T::Native> {
+ if scalar.is_null(0) {
+ None
+ } else {
+ let value = scalar.as_primitive::<T>().value(0);
+
+ Some(value)
+ }
+ }
+
+ /// return an output array that has
+ /// `value` in all locations where predicate is true
+ /// `null` otherwise
+ fn get_scalar_and_null_buffer_for_single_non_nullable(
+ predicate: BooleanBuffer,
+ value: T::Native,
+ ) -> (Vec<T::Native>, Option<NullBuffer>) {
+ let result_len = predicate.len();
+ let nulls = NullBuffer::new(predicate);
+ let scalars = vec![value; result_len];
+
+ (scalars, Some(nulls))
+ }
+}
+
+impl<T: ArrowPrimitiveType> ZipImpl for PrimitiveScalarImpl<T> {
+ fn create_output(&self, predicate: &BooleanArray) -> Result<ArrayRef,
ArrowError> {
+ let result_len = predicate.len();
+ // Nulls are treated as false
+ let predicate = maybe_prep_null_mask_filter(predicate);
+
+ let (scalars, nulls): (Vec<T::Native>, Option<NullBuffer>) = match
(self.truthy, self.falsy)
+ {
+ (Some(truthy_val), Some(falsy_val)) => {
+ let scalars: Vec<T::Native> = predicate
+ .iter()
+ .map(|b| if b { truthy_val } else { falsy_val })
+ .collect();
+
+ (scalars, None)
+ }
+ (Some(truthy_val), None) => {
+ // If a value is true we need the TRUTHY and the null buffer
will have 1 (meaning not null)
+ // If a value is false we need the FALSY and the null buffer
will have 0 (meaning null)
+
+
Self::get_scalar_and_null_buffer_for_single_non_nullable(predicate, truthy_val)
+ }
+ (None, Some(falsy_val)) => {
+ // Flipping the boolean buffer as we want the opposite of the
TRUE case
+ //
+ // if the condition is true we want null so we need to NOT the
value so we get 0 (meaning null)
+ // if the condition is false we want the FALSY value so we
need to NOT the value so we get 1 (meaning not null)
+ let predicate = predicate.not();
+
+
Self::get_scalar_and_null_buffer_for_single_non_nullable(predicate, falsy_val)
+ }
+ (None, None) => {
+ // All values are null
+ let nulls = NullBuffer::new_null(result_len);
+ let scalars = vec![T::default_value(); result_len];
+
+ (scalars, Some(nulls))
+ }
+ };
+
+ let scalars = ScalarBuffer::<T::Native>::from(scalars);
+ let output = PrimitiveArray::<T>::try_new(scalars, nulls)?;
+
+ // Keep decimal precisions, scales or timestamps timezones
+ let output = output.with_data_type(self.data_type.clone());
+
+ Ok(Arc::new(output))
+ }
+}
+
+#[derive(PartialEq, Hash)]
+struct BytesScalarImpl<T: ByteArrayType> {
+ truthy: Option<Vec<u8>>,
+ falsy: Option<Vec<u8>>,
+ phantom: PhantomData<T>,
+}
+
+impl<T: ByteArrayType> Debug for BytesScalarImpl<T> {
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("BytesScalarImpl")
+ .field("truthy", &self.truthy)
+ .field("falsy", &self.falsy)
+ .finish()
+ }
+}
+
+impl<T: ByteArrayType> BytesScalarImpl<T> {
+ fn new(truthy_value: &dyn Array, falsy_value: &dyn Array) -> Self {
+ Self {
+ truthy: Self::get_value_from_scalar(truthy_value),
+ falsy: Self::get_value_from_scalar(falsy_value),
+ phantom: PhantomData,
+ }
+ }
+
+ fn get_value_from_scalar(scalar: &dyn Array) -> Option<Vec<u8>> {
+ if scalar.is_null(0) {
+ None
+ } else {
+ let bytes: &[u8] = scalar.as_bytes::<T>().value(0).as_ref();
+
+ Some(bytes.to_vec())
+ }
+ }
+
+ /// return an output array that has
+ /// `value` in all locations where predicate is true
+ /// `null` otherwise
+ fn get_scalar_and_null_buffer_for_single_non_nullable(
+ predicate: BooleanBuffer,
+ value: &[u8],
+ ) -> (Buffer, OffsetBuffer<T::Offset>, Option<NullBuffer>) {
+ let value_length = value.len();
+
+ let number_of_true = predicate.count_set_bits();
+
+ // Fast path for all nulls
+ if number_of_true == 0 {
+ // All values are null
+ let nulls = NullBuffer::new_null(predicate.len());
+
+ return (
+ // Empty bytes
+ Buffer::from(&[]),
+ // All nulls so all lengths are 0
+ OffsetBuffer::<T::Offset>::new_zeroed(predicate.len()),
+ Some(nulls),
+ );
+ }
+
+ let offsets = OffsetBuffer::<T::Offset>::from_lengths(
+ predicate.iter().map(|b| if b { value_length } else { 0 }),
+ );
+
+ let mut bytes = MutableBuffer::with_capacity(0);
+ bytes.repeat_slice_n_times(value, number_of_true);
+
+ let bytes = Buffer::from(bytes);
+
+ // If a value is true we need the TRUTHY and the null buffer will have
1 (meaning not null)
+ // If a value is false we need the FALSY and the null buffer will have
0 (meaning null)
+ let nulls = NullBuffer::new(predicate);
+
+ (bytes, offsets, Some(nulls))
+ }
+
+ /// Create a [`Buffer`] where `value` slice is repeated `number_of_values`
times
+ /// and [`OffsetBuffer`] where there are `number_of_values` lengths, and
all equals to `value` length
+ fn get_bytes_and_offset_for_all_same_value(
+ number_of_values: usize,
+ value: &[u8],
+ ) -> (Buffer, OffsetBuffer<T::Offset>) {
+ let value_length = value.len();
+
+ let offsets =
+ OffsetBuffer::<T::Offset>::from_repeated_length(value_length,
number_of_values);
+
+ let mut bytes = MutableBuffer::with_capacity(0);
+ bytes.repeat_slice_n_times(value, number_of_values);
+ let bytes = Buffer::from(bytes);
+
+ (bytes, offsets)
+ }
+
+ fn create_output_on_non_nulls(
+ predicate: &BooleanBuffer,
+ truthy_val: &[u8],
+ falsy_val: &[u8],
+ ) -> (Buffer, OffsetBuffer<<T as ByteArrayType>::Offset>) {
+ let true_count = predicate.count_set_bits();
+
+ match true_count {
+ 0 => {
+ // All values are falsy
+
+ let (bytes, offsets) =
+
Self::get_bytes_and_offset_for_all_same_value(predicate.len(), falsy_val);
+
+ return (bytes, offsets);
+ }
+ n if n == predicate.len() => {
+ // All values are truthy
+ let (bytes, offsets) =
+
Self::get_bytes_and_offset_for_all_same_value(predicate.len(), truthy_val);
+
+ return (bytes, offsets);
+ }
+
+ _ => {
+ // Fallback
+ }
+ }
+
+ let total_number_of_bytes =
+ true_count * truthy_val.len() + (predicate.len() - true_count) *
falsy_val.len();
+ let mut mutable = MutableBuffer::with_capacity(total_number_of_bytes);
+ let mut offset_buffer_builder =
OffsetBufferBuilder::<T::Offset>::new(predicate.len());
+
+ // keep track of how much is filled
+ let mut filled = 0;
+
+ let truthy_len = truthy_val.len();
+ let falsy_len = falsy_val.len();
+
+ SlicesIterator::from(predicate).for_each(|(start, end)| {
+ // the gap needs to be filled with falsy values
+ if start > filled {
+ let false_repeat_count = start - filled;
+ // Push false value `repeat_count` times
+ mutable.repeat_slice_n_times(falsy_val, false_repeat_count);
+
+ for _ in 0..false_repeat_count {
+ offset_buffer_builder.push_length(falsy_len)
+ }
+ }
+
+ let true_repeat_count = end - start;
+ // fill with truthy values
+ mutable.repeat_slice_n_times(truthy_val, true_repeat_count);
+
+ for _ in 0..true_repeat_count {
+ offset_buffer_builder.push_length(truthy_len)
+ }
+ filled = end;
+ });
+ // the remaining part is falsy
+ if filled < predicate.len() {
+ let false_repeat_count = predicate.len() - filled;
+ // Repeat the falsy value for every remaining position in the output buffer.
+ mutable.repeat_slice_n_times(falsy_val, false_repeat_count);
+
+ for _ in 0..false_repeat_count {
+ offset_buffer_builder.push_length(falsy_len)
+ }
+ }
+
+ (mutable.into(), offset_buffer_builder.finish())
+ }
+}
+
+impl<T: ByteArrayType> ZipImpl for BytesScalarImpl<T> {
+ fn create_output(&self, predicate: &BooleanArray) -> Result<ArrayRef,
ArrowError> {
+ let result_len = predicate.len();
+ // Nulls are treated as false
+ let predicate = maybe_prep_null_mask_filter(predicate);
+
+ let (bytes, offsets, nulls): (Buffer, OffsetBuffer<T::Offset>,
Option<NullBuffer>) =
+ match (self.truthy.as_deref(), self.falsy.as_deref()) {
+ (Some(truthy_val), Some(falsy_val)) => {
+ let (bytes, offsets) =
+ Self::create_output_on_non_nulls(&predicate,
truthy_val, falsy_val);
+
+ (bytes, offsets, None)
+ }
+ (Some(truthy_val), None) => {
+
Self::get_scalar_and_null_buffer_for_single_non_nullable(predicate, truthy_val)
+ }
+ (None, Some(falsy_val)) => {
+ // Flipping the boolean buffer as we want the opposite of
the TRUE case
+ //
+ // if the condition is true we want null so we need to NOT
the value so we get 0 (meaning null)
+ // if the condition is false we want the FALSE value so we
need to NOT the value so we get 1 (meaning not null)
+ let predicate = predicate.not();
+
Self::get_scalar_and_null_buffer_for_single_non_nullable(predicate, falsy_val)
+ }
+ (None, None) => {
+ // All values are null
+ let nulls = NullBuffer::new_null(result_len);
+
+ (
+ // Empty bytes
+ Buffer::from(&[]),
+ // All nulls so all lengths are 0
+ OffsetBuffer::<T::Offset>::new_zeroed(predicate.len()),
+ Some(nulls),
+ )
+ }
+ };
+
+ let output = unsafe {
+ // Safety: the values are based on valid inputs
+ // and `try_new` is expensive for strings as it validates that the
input is valid utf8
+ GenericByteArray::<T>::new_unchecked(offsets, bytes, nulls)
+ };
+
+ Ok(Arc::new(output))
+ }
+}
+
fn maybe_prep_null_mask_filter(predicate: &BooleanArray) -> BooleanBuffer {
// Nulls are treated as false
if predicate.null_count() == 0 {
@@ -182,8 +660,7 @@ fn maybe_prep_null_mask_filter(predicate: &BooleanArray) ->
BooleanBuffer {
#[cfg(test)]
mod test {
use super::*;
- use arrow_array::cast::AsArray;
- use arrow_buffer::{BooleanBuffer, NullBuffer};
+ use arrow_array::types::Int32Type;
#[test]
fn test_zip_kernel_one() {
@@ -260,7 +737,7 @@ mod test {
}
#[test]
- fn test_zip_kernel_scalar_both() {
+ fn test_zip_kernel_scalar_both_mask_ends_with_true() {
let scalar_truthy = Scalar::new(Int32Array::from_value(42, 1));
let scalar_falsy = Scalar::new(Int32Array::from_value(123, 1));
@@ -272,7 +749,26 @@ mod test {
}
#[test]
- fn test_zip_kernel_scalar_none_1() {
+ fn test_zip_kernel_scalar_both_mask_ends_with_false() {
+ let scalar_truthy = Scalar::new(Int32Array::from_value(42, 1));
+ let scalar_falsy = Scalar::new(Int32Array::from_value(123, 1));
+
+ let mask = BooleanArray::from(vec![true, true, false, true, false,
false]);
+ let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
+ let actual = out.as_any().downcast_ref::<Int32Array>().unwrap();
+ let expected = Int32Array::from(vec![
+ Some(42),
+ Some(42),
+ Some(123),
+ Some(42),
+ Some(123),
+ Some(123),
+ ]);
+ assert_eq!(actual, &expected);
+ }
+
+ #[test]
+ fn test_zip_kernel_primitive_scalar_none_1() {
let scalar_truthy = Scalar::new(Int32Array::from_value(42, 1));
let scalar_falsy = Scalar::new(Int32Array::new_null(1));
@@ -284,7 +780,7 @@ mod test {
}
#[test]
- fn test_zip_kernel_scalar_none_2() {
+ fn test_zip_kernel_primitive_scalar_none_2() {
let scalar_truthy = Scalar::new(Int32Array::from_value(42, 1));
let scalar_falsy = Scalar::new(Int32Array::new_null(1));
@@ -295,6 +791,18 @@ mod test {
assert_eq!(actual, &expected);
}
+ #[test]
+ fn test_zip_kernel_primitive_scalar_both_null() {
+ let scalar_truthy = Scalar::new(Int32Array::new_null(1));
+ let scalar_falsy = Scalar::new(Int32Array::new_null(1));
+
+ let mask = BooleanArray::from(vec![false, false, true, true, false]);
+ let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
+ let actual = out.as_any().downcast_ref::<Int32Array>().unwrap();
+ let expected = Int32Array::from(vec![None, None, None, None, None]);
+ assert_eq!(actual, &expected);
+ }
+
#[test]
fn
test_zip_primitive_array_with_nulls_is_mask_should_be_treated_as_false() {
let truthy = Int32Array::from_iter_values(vec![1, 2, 3, 4, 5, 6]);
@@ -400,4 +908,318 @@ mod test {
]);
assert_eq!(actual, &expected);
}
+
+ #[test]
+ fn test_zip_kernel_bytes_scalar_none_1() {
+ let scalar_truthy =
Scalar::new(StringArray::from_iter_values(["hello"]));
+ let scalar_falsy = Scalar::new(StringArray::new_null(1));
+
+ let mask = BooleanArray::from(vec![true, true, false, false, true]);
+ let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
+ let actual = out.as_any().downcast_ref::<StringArray>().unwrap();
+ let expected = StringArray::from_iter(vec![
+ Some("hello"),
+ Some("hello"),
+ None,
+ None,
+ Some("hello"),
+ ]);
+ assert_eq!(actual, &expected);
+ }
+
+ #[test]
+ fn test_zip_kernel_bytes_scalar_none_2() {
+ let scalar_truthy = Scalar::new(StringArray::new_null(1));
+ let scalar_falsy =
Scalar::new(StringArray::from_iter_values(["hello"]));
+
+ let mask = BooleanArray::from(vec![true, true, false, false, true]);
+ let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
+ let actual = out.as_any().downcast_ref::<StringArray>().unwrap();
+ let expected = StringArray::from_iter(vec![None, None, Some("hello"),
Some("hello"), None]);
+ assert_eq!(actual, &expected);
+ }
+
+ #[test]
+ fn test_zip_kernel_bytes_scalar_both() {
+ let scalar_truthy =
Scalar::new(StringArray::from_iter_values(["test"]));
+ let scalar_falsy =
Scalar::new(StringArray::from_iter_values(["something else"]));
+
+ // mask ends with false
+ let mask = BooleanArray::from(vec![true, true, false, true, false,
false]);
+ let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
+ let actual = out.as_any().downcast_ref::<StringArray>().unwrap();
+ let expected = StringArray::from_iter(vec![
+ Some("test"),
+ Some("test"),
+ Some("something else"),
+ Some("test"),
+ Some("something else"),
+ Some("something else"),
+ ]);
+ assert_eq!(actual, &expected);
+ }
+
+ #[test]
+ fn test_zip_scalar_bytes_only_taking_one_side() {
+ let mask_len = 5;
+ let all_true_mask = BooleanArray::from(vec![true; mask_len]);
+ let all_false_mask = BooleanArray::from(vec![false; mask_len]);
+
+ let null_scalar = Scalar::new(StringArray::new_null(1));
+ let non_null_scalar_1 =
Scalar::new(StringArray::from_iter_values(["test"]));
+ let non_null_scalar_2 =
Scalar::new(StringArray::from_iter_values(["something else"]));
+
+ {
+ // 1. Test where left is null and right is non-null
+ // and mask is all true
+ let out = zip(&all_true_mask, &null_scalar,
&non_null_scalar_1).unwrap();
+ let actual = out.as_string::<i32>();
+ let expected =
StringArray::from_iter(std::iter::repeat_n(None::<&str>, mask_len));
+ assert_eq!(actual, &expected);
+ }
+
+ {
+ // 2. Test where left is null and right is non-null
+ // and mask is all false
+ let out = zip(&all_false_mask, &null_scalar,
&non_null_scalar_1).unwrap();
+ let actual = out.as_string::<i32>();
+ let expected =
StringArray::from_iter(std::iter::repeat_n(Some("test"), mask_len));
+ assert_eq!(actual, &expected);
+ }
+
+ {
+ // 3. Test where left is non-null and right is null
+ // and mask is all true
+ let out = zip(&all_true_mask, &non_null_scalar_1,
&null_scalar).unwrap();
+ let actual = out.as_string::<i32>();
+ let expected =
StringArray::from_iter(std::iter::repeat_n(Some("test"), mask_len));
+ assert_eq!(actual, &expected);
+ }
+
+ {
+ // 4. Test where left is non-null and right is null
+ // and mask is all false
+ let out = zip(&all_false_mask, &non_null_scalar_1,
&null_scalar).unwrap();
+ let actual = out.as_string::<i32>();
+ let expected =
StringArray::from_iter(std::iter::repeat_n(None::<&str>, mask_len));
+ assert_eq!(actual, &expected);
+ }
+
+ {
+ // 5. Test where both left and right are not null
+ // and mask is all true
+ let out = zip(&all_true_mask, &non_null_scalar_1,
&non_null_scalar_2).unwrap();
+ let actual = out.as_string::<i32>();
+ let expected =
StringArray::from_iter(std::iter::repeat_n(Some("test"), mask_len));
+ assert_eq!(actual, &expected);
+ }
+
+ {
+ // 6. Test where both left and right are not null
+ // and mask is all false
+ let out = zip(&all_false_mask, &non_null_scalar_1,
&non_null_scalar_2).unwrap();
+ let actual = out.as_string::<i32>();
+ let expected =
+ StringArray::from_iter(std::iter::repeat_n(Some("something
else"), mask_len));
+ assert_eq!(actual, &expected);
+ }
+
+ {
+ // 7. Test where both left and right are null
+ // and mask is random
+ let mask = BooleanArray::from(vec![true, false, true, false,
true]);
+ let out = zip(&mask, &null_scalar, &null_scalar).unwrap();
+ let actual = out.as_string::<i32>();
+ let expected =
StringArray::from_iter(std::iter::repeat_n(None::<&str>, mask_len));
+ assert_eq!(actual, &expected);
+ }
+ }
+
+ #[test]
+ fn test_scalar_zipper() {
+ let scalar_truthy = Scalar::new(Int32Array::from_value(42, 1));
+ let scalar_falsy = Scalar::new(Int32Array::from_value(123, 1));
+
+ let mask = BooleanArray::from(vec![false, false, true, true, false]);
+
+ let scalar_zipper = ScalarZipper::try_new(&scalar_truthy,
&scalar_falsy).unwrap();
+ let out = scalar_zipper.zip(&mask).unwrap();
+ let actual = out.as_primitive::<Int32Type>();
+ let expected = Int32Array::from(vec![Some(123), Some(123), Some(42),
Some(42), Some(123)]);
+ assert_eq!(actual, &expected);
+
+ // test with different mask length as well
+ let mask = BooleanArray::from(vec![true, false, true]);
+ let out = scalar_zipper.zip(&mask).unwrap();
+ let actual = out.as_primitive::<Int32Type>();
+ let expected = Int32Array::from(vec![Some(42), Some(123), Some(42)]);
+ assert_eq!(actual, &expected);
+ }
+
+ #[test]
+ fn test_zip_kernel_scalar_strings() {
+ let scalar_truthy = Scalar::new(StringArray::from(vec!["hello"]));
+ let scalar_falsy = Scalar::new(StringArray::from(vec!["world"]));
+
+ let mask = BooleanArray::from(vec![true, false, true, false, true]);
+ let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
+ let actual = out.as_string::<i32>();
+ let expected = StringArray::from(vec![
+ Some("hello"),
+ Some("world"),
+ Some("hello"),
+ Some("world"),
+ Some("hello"),
+ ]);
+ assert_eq!(actual, &expected);
+ }
+
+ #[test]
+ fn test_zip_kernel_scalar_binary() {
+ let truthy_bytes: &[u8] = b"\xFF\xFE\xFD";
+ let falsy_bytes: &[u8] = b"world";
+ let scalar_truthy = Scalar::new(BinaryArray::from_iter_values(
+ // Non valid UTF8 bytes
+ vec![truthy_bytes],
+ ));
+ let scalar_falsy =
Scalar::new(BinaryArray::from_iter_values(vec![falsy_bytes]));
+
+ let mask = BooleanArray::from(vec![true, false, true, false, true]);
+ let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
+ let actual = out.as_binary::<i32>();
+ let expected = BinaryArray::from(vec![
+ Some(truthy_bytes),
+ Some(falsy_bytes),
+ Some(truthy_bytes),
+ Some(falsy_bytes),
+ Some(truthy_bytes),
+ ]);
+ assert_eq!(actual, &expected);
+ }
+
+ #[test]
+ fn test_zip_kernel_scalar_large_binary() {
+ let truthy_bytes: &[u8] = b"hey";
+ let falsy_bytes: &[u8] = b"world";
+ let scalar_truthy =
Scalar::new(LargeBinaryArray::from_iter_values(vec![truthy_bytes]));
+ let scalar_falsy =
Scalar::new(LargeBinaryArray::from_iter_values(vec![falsy_bytes]));
+
+ let mask = BooleanArray::from(vec![true, false, true, false, true]);
+ let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
+ let actual = out.as_binary::<i64>();
+ let expected = LargeBinaryArray::from(vec![
+ Some(truthy_bytes),
+ Some(falsy_bytes),
+ Some(truthy_bytes),
+ Some(falsy_bytes),
+ Some(truthy_bytes),
+ ]);
+ assert_eq!(actual, &expected);
+ }
+
+ // Test to ensure that the precision and scale are kept when zipping
Decimal128 data
+ #[test]
+ fn test_zip_decimal_with_custom_precision_and_scale() {
+ let arr = Decimal128Array::from_iter_values([12345, 456, 7890,
-123223423432432])
+ .with_precision_and_scale(20, 2)
+ .unwrap();
+
+ let arr: ArrayRef = Arc::new(arr);
+
+ let scalar_1 = Scalar::new(arr.slice(0, 1));
+ let scalar_2 = Scalar::new(arr.slice(1, 1));
+ let null_scalar = Scalar::new(new_null_array(arr.data_type(), 1));
+ let array_1: ArrayRef = arr.slice(0, 2);
+ let array_2: ArrayRef = arr.slice(2, 2);
+
+ test_zip_output_data_types_for_input(scalar_1, scalar_2, null_scalar,
array_1, array_2);
+ }
+
+ // Test to ensure that the timezone is kept when zipping TimestampArray
data
+ #[test]
+ fn test_zip_timestamp_with_timezone() {
+ let arr = TimestampSecondArray::from(vec![0, 1000, 2000, 4000])
+ .with_timezone("+01:00".to_string());
+
+ let arr: ArrayRef = Arc::new(arr);
+
+ let scalar_1 = Scalar::new(arr.slice(0, 1));
+ let scalar_2 = Scalar::new(arr.slice(1, 1));
+ let null_scalar = Scalar::new(new_null_array(arr.data_type(), 1));
+ let array_1: ArrayRef = arr.slice(0, 2);
+ let array_2: ArrayRef = arr.slice(2, 2);
+
+ test_zip_output_data_types_for_input(scalar_1, scalar_2, null_scalar,
array_1, array_2);
+ }
+
+ fn test_zip_output_data_types_for_input(
+ scalar_1: Scalar<ArrayRef>,
+ scalar_2: Scalar<ArrayRef>,
+ null_scalar: Scalar<ArrayRef>,
+ array_1: ArrayRef,
+ array_2: ArrayRef,
+ ) {
+ // non null Scalar vs non null Scalar
+ test_zip_output_data_type(&scalar_1, &scalar_2, 10);
+
+ // null Scalar vs non-null Scalar (and vice versa)
+ test_zip_output_data_type(&null_scalar, &scalar_1, 10);
+ test_zip_output_data_type(&scalar_1, &null_scalar, 10);
+
+ // non-null Scalar and array (and vice versa)
+ test_zip_output_data_type(&array_1.as_ref(), &scalar_1, array_1.len());
+ test_zip_output_data_type(&scalar_1, &array_1.as_ref(), array_1.len());
+
+ // Array and null scalar (and vice versa)
+ test_zip_output_data_type(&array_1.as_ref(), &null_scalar,
array_1.len());
+
+ test_zip_output_data_type(&null_scalar, &array_1.as_ref(),
array_1.len());
+
+ // Both arrays
+ test_zip_output_data_type(&array_1.as_ref(), &array_2.as_ref(),
array_1.len());
+ }
+
+ fn test_zip_output_data_type(truthy: &dyn Datum, falsy: &dyn Datum,
mask_length: usize) {
+ let expected_data_type = truthy.get().0.data_type().clone();
+ assert_eq!(&expected_data_type, falsy.get().0.data_type());
+
+ // Try different masks to test different paths
+ let mask_all_true = BooleanArray::from(vec![true; mask_length]);
+ let mask_all_false = BooleanArray::from(vec![false; mask_length]);
+ let mask_some_true_and_false =
+ BooleanArray::from((0..mask_length).map(|i| i % 2 ==
0).collect::<Vec<bool>>());
+
+ for mask in [&mask_all_true, &mask_all_false,
&mask_some_true_and_false] {
+ let out = zip(mask, truthy, falsy).unwrap();
+ assert_eq!(out.data_type(), &expected_data_type);
+ }
+ }
+
+ #[test]
+ fn zip_scalar_fallback_impl() {
+ let truthy_list_item_scalar = Some(vec![Some(1), None, Some(3)]);
+ let truthy_list_array_scalar =
+ Scalar::new(ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+ truthy_list_item_scalar.clone(),
+ ]));
+ let falsy_list_item_scalar = Some(vec![None, Some(2), Some(4)]);
+ let falsy_list_array_scalar =
+ Scalar::new(ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+ falsy_list_item_scalar.clone(),
+ ]));
+ let mask = BooleanArray::from(vec![true, false, true, false, false,
true, false]);
+ let out = zip(&mask, &truthy_list_array_scalar,
&falsy_list_array_scalar).unwrap();
+ let actual = out.as_list::<i32>();
+
+ let expected = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+ truthy_list_item_scalar.clone(),
+ falsy_list_item_scalar.clone(),
+ truthy_list_item_scalar.clone(),
+ falsy_list_item_scalar.clone(),
+ falsy_list_item_scalar.clone(),
+ truthy_list_item_scalar.clone(),
+ falsy_list_item_scalar.clone(),
+ ]);
+ assert_eq!(actual, &expected);
+ }
}