rluvaton commented on code in PR #8653:
URL: https://github.com/apache/arrow-rs/pull/8653#discussion_r2463752849
##########
arrow-select/src/zip.rs:
##########
@@ -166,9 +196,465 @@ pub fn zip(
Ok(make_array(data))
}
+/// Zipper for 2 scalars
+///
+/// Useful for using in `IF <expr> THEN <scalar> ELSE <scalar> END` expressions
+///
+/// # Example
+/// ```
+/// # use std::sync::Arc;
+/// # use arrow_array::{ArrayRef, BooleanArray, Int32Array, Scalar,
cast::AsArray, types::Int32Type};
+///
+/// # use arrow_select::zip::ScalarZipper;
+/// let scalar_truthy = Scalar::new(Int32Array::from_value(42, 1));
+/// let scalar_falsy = Scalar::new(Int32Array::from_value(123, 1));
+/// let zipper = ScalarZipper::try_new(&scalar_truthy, &scalar_falsy).unwrap();
+///
+/// // Later when we have a boolean mask
+/// let mask = BooleanArray::from(vec![true, false, true, false, true]);
+/// let result = zipper.zip(&mask).unwrap();
+/// let actual = result.as_primitive::<Int32Type>();
+/// let expected = Int32Array::from(vec![Some(42), Some(123), Some(42),
Some(123), Some(42)]);
+/// ```
+///
+#[derive(Debug, Clone)]
+pub struct ScalarZipper {
+ zip_impl: Arc<dyn ZipImpl>,
+}
+
+impl ScalarZipper {
+ /// Try to create a new ScalarZipper from two scalar Datum
+ ///
+ /// # Errors
+ /// returns error if:
+ /// - the two Datum have different data types
+ /// - either Datum is not a scalar (or has more than 1 element)
+ ///
+ pub fn try_new(truthy: &dyn Datum, falsy: &dyn Datum) -> Result<Self,
ArrowError> {
+ let (truthy, truthy_is_scalar) = truthy.get();
+ let (falsy, falsy_is_scalar) = falsy.get();
+
+ if truthy.data_type() != falsy.data_type() {
+ return Err(ArrowError::InvalidArgumentError(
+ "arguments need to have the same data type".into(),
+ ));
+ }
+
+ if !truthy_is_scalar {
+ return Err(ArrowError::InvalidArgumentError(
+ "only scalar arrays are supported".into(),
+ ));
+ }
+
+ if !falsy_is_scalar {
+ return Err(ArrowError::InvalidArgumentError(
+ "only scalar arrays are supported".into(),
+ ));
+ }
+
+ if truthy.len() != 1 {
+ return Err(ArrowError::InvalidArgumentError(
+ "scalar arrays must have 1 element".into(),
+ ));
+ }
+ if falsy.len() != 1 {
+ return Err(ArrowError::InvalidArgumentError(
+ "scalar arrays must have 1 element".into(),
+ ));
+ }
+
+ macro_rules! primitive_size_helper {
+ ($t:ty) => {
+ Arc::new(PrimitiveScalarImpl::<$t>::new(truthy, falsy)) as
Arc<dyn ZipImpl>
+ };
+ }
+
+ let zip_impl = downcast_primitive! {
+ truthy.data_type() => (primitive_size_helper),
+ DataType::Utf8 => {
+ Arc::new(BytesScalarImpl::<Utf8Type>::new(truthy, falsy)) as
Arc<dyn ZipImpl>
+ },
+ DataType::LargeUtf8 => {
+ Arc::new(BytesScalarImpl::<LargeUtf8Type>::new(truthy, falsy))
as Arc<dyn ZipImpl>
+ },
+ DataType::Binary => {
+ Arc::new(BytesScalarImpl::<BinaryType>::new(truthy, falsy)) as
Arc<dyn ZipImpl>
+ },
+ DataType::LargeBinary => {
+ Arc::new(BytesScalarImpl::<LargeBinaryType>::new(truthy,
falsy)) as Arc<dyn ZipImpl>
+ },
+ _ => {
+ Arc::new(FallbackImpl::new(truthy, falsy)) as Arc<dyn ZipImpl>
+ },
+ };
+
+ Ok(Self { zip_impl })
+ }
+
+ /// Creating output array based on input boolean array and the two scalar
values the zipper was created with
+ /// See struct level documentation for examples.
+ pub fn zip(&self, mask: &BooleanArray) -> Result<ArrayRef, ArrowError> {
+ self.zip_impl.create_output(mask)
+ }
+}
+
+/// Impl for creating output array based on a mask
+trait ZipImpl: Debug + Send + Sync {
+ /// Creating output array based on input boolean array
+ fn create_output(&self, input: &BooleanArray) -> Result<ArrayRef,
ArrowError>;
+}
+
+#[derive(Debug, PartialEq)]
+struct FallbackImpl {
+ truthy: ArrayData,
+ falsy: ArrayData,
+}
+
+impl FallbackImpl {
+ fn new(left: &dyn Array, right: &dyn Array) -> Self {
+ Self {
+ truthy: left.to_data(),
+ falsy: right.to_data(),
+ }
+ }
+}
+
+impl ZipImpl for FallbackImpl {
+ fn create_output(&self, predicate: &BooleanArray) -> Result<ArrayRef,
ArrowError> {
+ zip_impl(predicate, &self.truthy, false, &self.falsy, false)
+ }
+}
+
+struct PrimitiveScalarImpl<T: ArrowPrimitiveType> {
+ data_type: DataType,
+ truthy: Option<T::Native>,
+ falsy: Option<T::Native>,
+}
+
+impl<T: ArrowPrimitiveType> Debug for PrimitiveScalarImpl<T> {
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("PrimitiveScalarImpl")
+ .field("data_type", &self.data_type)
+ .field("truthy", &self.truthy)
+ .field("falsy", &self.falsy)
+ .finish()
+ }
+}
+
+impl<T: ArrowPrimitiveType> PrimitiveScalarImpl<T> {
+ fn new(truthy: &dyn Array, falsy: &dyn Array) -> Self {
+ Self {
+ data_type: truthy.data_type().clone(),
+ truthy: Self::get_value_from_scalar(truthy),
+ falsy: Self::get_value_from_scalar(falsy),
+ }
+ }
+
+ fn get_value_from_scalar(scalar: &dyn Array) -> Option<T::Native> {
+ if scalar.is_null(0) {
+ None
+ } else {
+ let value = scalar.as_primitive::<T>().value(0);
+
+ Some(value)
+ }
+ }
+}
+
+impl<T: ArrowPrimitiveType> PrimitiveScalarImpl<T> {
+ fn get_scalar_and_null_buffer_for_single_non_nullable(
+ predicate: BooleanBuffer,
+ value: T::Native,
+ ) -> (Vec<T::Native>, Option<NullBuffer>) {
+ let result_len = predicate.len();
+ let nulls = NullBuffer::new(predicate);
+ let scalars = vec![value; result_len];
+
+ (scalars, Some(nulls))
+ }
+}
+
+impl<T: ArrowPrimitiveType> ZipImpl for PrimitiveScalarImpl<T> {
+ fn create_output(&self, predicate: &BooleanArray) -> Result<ArrayRef,
ArrowError> {
+ let result_len = predicate.len();
+ // Nulls are treated as false
+ let predicate = combine_nulls_and_false(predicate);
+
+ let (scalars, nulls): (Vec<T::Native>, Option<NullBuffer>) = match
(self.truthy, self.falsy)
+ {
+ (Some(truthy_val), Some(falsy_val)) => {
+ let scalars: Vec<T::Native> = predicate
+ .iter()
+ .map(|b| if b { truthy_val } else { falsy_val })
+ .collect();
+
+ (scalars, None)
+ }
+ (Some(truthy_val), None) => {
+ // If a value is true we need the TRUTHY and the null buffer
will have 1 (meaning not null)
+ // If a value is false we need the FALSY and the null buffer
will have 0 (meaning null)
+
+
Self::get_scalar_and_null_buffer_for_single_non_nullable(predicate, truthy_val)
+ }
+ (None, Some(falsy_val)) => {
+ // Flipping the boolean buffer as we want the opposite of the
TRUE case
+ //
+ // if the condition is true we want null so we need to NOT the
value so we get 0 (meaning null)
+ // if the condition is false we want the FALSY value so we
need to NOT the value so we get 1 (meaning not null)
+ let predicate = predicate.not();
+
+
Self::get_scalar_and_null_buffer_for_single_non_nullable(predicate, falsy_val)
+ }
+ (None, None) => {
+ // All values are null
+ let nulls = NullBuffer::new_null(result_len);
+ let scalars = vec![T::default_value(); result_len];
+
+ (scalars, Some(nulls))
+ }
+ };
+
+ let scalars = ScalarBuffer::<T::Native>::from(scalars);
+ let output = PrimitiveArray::<T>::try_new(scalars, nulls)?;
+
+ // Keep decimal precisions, scales or timestamps timezones
+ let output = output.with_data_type(self.data_type.clone());
+
+ Ok(Arc::new(output))
+ }
+}
+
+#[derive(PartialEq, Hash)]
+struct BytesScalarImpl<T: ByteArrayType> {
+ truthy: Option<Vec<u8>>,
+ falsy: Option<Vec<u8>>,
+ phantom: PhantomData<T>,
+}
+
+impl<T: ByteArrayType> Debug for BytesScalarImpl<T> {
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("BytesScalarImpl")
+ .field("truthy", &self.truthy)
+ .field("falsy", &self.falsy)
+ .finish()
+ }
+}
+
+impl<T: ByteArrayType> BytesScalarImpl<T> {
+ fn new(truthy_value: &dyn Array, falsy_value: &dyn Array) -> Self {
+ Self {
+ truthy: Self::get_value_from_scalar(truthy_value),
+ falsy: Self::get_value_from_scalar(falsy_value),
+ phantom: PhantomData,
+ }
+ }
+
+ fn get_value_from_scalar(scalar: &dyn Array) -> Option<Vec<u8>> {
+ if scalar.is_null(0) {
+ None
+ } else {
+ let bytes: &[u8] = scalar.as_bytes::<T>().value(0).as_ref();
+
+ Some(bytes.to_vec())
+ }
+ }
+
+ fn get_scalar_and_null_buffer_for_single_non_nullable(
+ predicate: BooleanBuffer,
+ value: &[u8],
+ ) -> (Buffer, OffsetBuffer<T::Offset>, Option<NullBuffer>) {
+ let value_length = value.len();
+
+ let number_of_true = predicate.count_set_bits();
+
+ // Fast path for all nulls
+ if number_of_true == 0 {
+ // All values are null
+ let nulls = NullBuffer::new_null(predicate.len());
+
+ return (
+ // Empty bytes
+ Buffer::from(&[]),
+ // All nulls so all lengths are 0
+ OffsetBuffer::<T::Offset>::new_zeroed(predicate.len()),
+ Some(nulls),
+ );
+ }
+
+ let offsets = OffsetBuffer::<T::Offset>::from_lengths(
+ predicate.iter().map(|b| if b { value_length } else { 0 }),
+ );
+
+ let mut bytes = MutableBuffer::with_capacity(0);
+ bytes.repeat_slice_n_times(value, number_of_true);
+
+ let bytes = Buffer::from(bytes);
+
+ // If a value is true we need the TRUTHY and the null buffer will have
1 (meaning not null)
+ // If a value is false we need the FALSY and the null buffer will have
0 (meaning null)
+ let nulls = NullBuffer::new(predicate);
+
+ (bytes, offsets, Some(nulls))
+ }
+
+ fn get_bytes_and_offset_for_all_same_value(
+ predicate: &BooleanBuffer,
+ value: &[u8],
+ ) -> (Buffer, OffsetBuffer<T::Offset>) {
+ let value_length = value.len();
+
+ let offsets =
+ OffsetBuffer::<T::Offset>::from_repeated_length(value_length,
predicate.len());
+
+ let mut bytes = MutableBuffer::with_capacity(0);
+ bytes.repeat_slice_n_times(value, predicate.len());
+ let bytes = Buffer::from(bytes);
+
+ (bytes, offsets)
+ }
+}
+
+impl<T: ByteArrayType> ZipImpl for BytesScalarImpl<T> {
+ fn create_output(&self, predicate: &BooleanArray) -> Result<ArrayRef,
ArrowError> {
+ let result_len = predicate.len();
+ // Nulls are treated as false
+ let predicate = combine_nulls_and_false(predicate);
+
+ let (bytes, offsets, nulls): (Buffer, OffsetBuffer<T::Offset>,
Option<NullBuffer>) =
+ match (self.truthy.as_deref(), self.falsy.as_deref()) {
+ (Some(truthy_val), Some(falsy_val)) => {
+ let (bytes, offsets) =
+ Self::create_output_on_non_nulls(&predicate,
truthy_val, falsy_val);
+
+ (bytes, offsets, None)
+ }
+ (Some(truthy_val), None) => {
+
Self::get_scalar_and_null_buffer_for_single_non_nullable(predicate, truthy_val)
+ }
+ (None, Some(falsy_val)) => {
+ // Flipping the boolean buffer as we want the opposite of
the TRUE case
+ //
+ // if the condition is true we want null so we need to NOT
the value so we get 0 (meaning null)
+ // if the condition is false we want the FALSE value so we
need to NOT the value so we get 1 (meaning not null)
+ let predicate = predicate.not();
+
Self::get_scalar_and_null_buffer_for_single_non_nullable(predicate, falsy_val)
+ }
+ (None, None) => {
+ // All values are null
+ let nulls = NullBuffer::new_null(result_len);
+
+ (
+ // Empty bytes
+ Buffer::from(&[]),
+ // All nulls so all lengths are 0
+ OffsetBuffer::<T::Offset>::new_zeroed(predicate.len()),
+ Some(nulls),
+ )
+ }
+ };
+
+ let output = unsafe {
+ // Safety: the values are based on valid inputs
+ // and `try_new` is expensive for strings as it validate that the
input is valid utf8
+ GenericByteArray::<T>::new_unchecked(offsets, bytes, nulls)
+ };
+
+ Ok(Arc::new(output))
+ }
+}
+
+impl<T: ByteArrayType> BytesScalarImpl<T> {
+ fn create_output_on_non_nulls(
+ predicate: &BooleanBuffer,
+ truthy_val: &[u8],
+ falsy_val: &[u8],
+ ) -> (Buffer, OffsetBuffer<<T as ByteArrayType>::Offset>) {
+ let true_count = predicate.count_set_bits();
+
+ match true_count {
+ 0 => {
+ // All values are falsy
+
+ let (bytes, offsets) =
+ Self::get_bytes_and_offset_for_all_same_value(predicate,
falsy_val);
+
+ return (bytes, offsets);
+ }
+ n if n == predicate.len() => {
+ // All values are truthy
+ let (bytes, offsets) =
+ Self::get_bytes_and_offset_for_all_same_value(predicate,
truthy_val);
+
+ return (bytes, offsets);
+ }
+
+ _ => {
+ // Fallback
+ }
+ }
+
+ let total_number_of_bytes =
+ true_count * truthy_val.len() + (predicate.len() - true_count) *
falsy_val.len();
+ let mut mutable = MutableBuffer::with_capacity(total_number_of_bytes);
+ let mut offset_buffer_builder =
OffsetBufferBuilder::<T::Offset>::new(predicate.len());
+
+ // keep track of how much is filled
+ let mut filled = 0;
+
+ let truthy_len = truthy_val.len();
+ let falsy_len = falsy_val.len();
+
+ SlicesIterator::from(predicate).for_each(|(start, end)| {
+ // the gap needs to be filled with falsy values
+ if start > filled {
+ let false_repeat_count = start - filled;
+ // Push false value `repeat_count` times
+ mutable.repeat_slice_n_times(falsy_val, false_repeat_count);
+
+ for _ in 0..false_repeat_count {
+ offset_buffer_builder.push_length(falsy_len)
+ }
+ }
+
+ let true_repeat_count = end - start;
+ // fill with truthy values
+ mutable.repeat_slice_n_times(truthy_val, true_repeat_count);
Review Comment:
Yeah, I agree, but if the scalars tend to be short (12 bytes or fewer), then
it won't be faster and could even be slower due to the extra indirection,
since values that short are inlined in the view anyway, no?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]