rluvaton commented on code in PR #8653:
URL: https://github.com/apache/arrow-rs/pull/8653#discussion_r2443324535


##########
arrow-select/src/zip.rs:
##########
@@ -166,6 +195,353 @@ pub fn zip(
     Ok(make_array(data))
 }
 
+/// Zipper for 2 scalars
+///
+/// Useful for using in `IF <expr> THEN <scalar> ELSE <scalar> END` expressions
+///
+#[derive(Debug, Clone)]
+pub struct ScalarZipper {
+    zip_impl: Arc<dyn ZipImpl>,
+}
+
+impl ScalarZipper {
+    pub fn try_new(truthy: &dyn Datum, falsy: &dyn Datum) -> Result<Self, 
ArrowError> {
+        let (truthy, truthy_is_scalar) = truthy.get();
+        let (falsy, falsy_is_scalar) = falsy.get();
+
+        if truthy.data_type() != falsy.data_type() {
+            return Err(ArrowError::InvalidArgumentError(
+                "arguments need to have the same data type".into(),
+            ));
+        }
+
+        if !truthy_is_scalar {
+            return Err(ArrowError::InvalidArgumentError(
+                "only scalar arrays are supported".into(),
+            ));
+        }
+
+        if !falsy_is_scalar {
+            return Err(ArrowError::InvalidArgumentError(
+                "only scalar arrays are supported".into(),
+            ));
+        }
+
+        if truthy.len() != 1 {
+            return Err(ArrowError::InvalidArgumentError(
+                "scalar arrays must have 1 element".into(),
+            ));
+        }
+        if falsy.len() != 1 {
+            return Err(ArrowError::InvalidArgumentError(
+                "scalar arrays must have 1 element".into(),
+            ));
+        }
+
+        macro_rules! primitive_size_helper {
+            ($t:ty) => {
+                Arc::new(PrimitiveScalarImpl::<$t>::new(truthy, falsy)) as 
Arc<dyn ZipImpl>
+            };
+        }
+
+        let zip_impl = downcast_primitive! {
+            truthy.data_type() => (primitive_size_helper),
+            DataType::Utf8 => {
+                Arc::new(BytesScalarImpl::<Utf8Type>::new(truthy, falsy)) as 
Arc<dyn ZipImpl>
+            },
+            DataType::LargeUtf8 => {
+                Arc::new(BytesScalarImpl::<LargeUtf8Type>::new(truthy, falsy)) 
as Arc<dyn ZipImpl>
+            },
+            DataType::Binary => {
+                Arc::new(BytesScalarImpl::<BinaryType>::new(truthy, falsy)) as 
Arc<dyn ZipImpl>
+            },
+            DataType::LargeBinary => {
+                Arc::new(BytesScalarImpl::<LargeBinaryType>::new(truthy, 
falsy)) as Arc<dyn ZipImpl>
+            },
+            _ => {
+                Arc::new(FallbackImpl::new(truthy, falsy)) as Arc<dyn ZipImpl>
+            },
+        };
+
+        Ok(Self { zip_impl })
+    }
+}
+
+/// Impl for creating output array based on input boolean array
+trait ZipImpl: Debug {
+    /// Creating output array based on input boolean array
+    fn create_output(&self, input: &BooleanArray) -> Result<ArrayRef, 
ArrowError>;
+}
+
+#[derive(Debug, PartialEq)]
+struct FallbackImpl {
+    truthy: ArrayData,
+    falsy: ArrayData,
+}
+
+impl FallbackImpl {
+    fn new(left: &dyn Array, right: &dyn Array) -> Self {
+        Self {
+            truthy: left.to_data(),
+            falsy: right.to_data(),
+        }
+    }
+}
+
+impl ZipImpl for FallbackImpl {
+    fn create_output(&self, predicate: &BooleanArray) -> Result<ArrayRef, 
ArrowError> {
+        zip_impl(predicate, &self.truthy, false, &self.falsy, false)
+    }
+}
+
+struct PrimitiveScalarImpl<T: ArrowPrimitiveType> {
+    data_type: DataType,
+    truthy: Option<T::Native>,
+    falsy: Option<T::Native>,
+}
+
+impl<T: ArrowPrimitiveType> Debug for PrimitiveScalarImpl<T> {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("PrimitiveScalarImpl")
+            .field("data_type", &self.data_type)
+            .field("then_value", &self.truthy)
+            .field("else_value", &self.falsy)
+            .finish()
+    }
+}
+
+impl<T: ArrowPrimitiveType> PrimitiveScalarImpl<T> {
+    fn new(then_value: &dyn Array, else_value: &dyn Array) -> Self {
+        Self {
+            data_type: then_value.data_type().clone(),
+            truthy: Self::get_value_from_scalar(then_value),
+            falsy: Self::get_value_from_scalar(else_value),
+        }
+    }
+
+    fn get_value_from_scalar(scalar: &dyn Array) -> Option<T::Native> {
+        if scalar.is_null(0) {
+            None
+        } else {
+            let value = scalar.as_primitive::<T>().value(0);
+
+            Some(value)
+        }
+    }
+}
+
+impl<T: ArrowPrimitiveType> PrimitiveScalarImpl<T> {
+    fn get_scalar_and_null_buffer_for_single_non_nullable(
+        predicate: BooleanBuffer,
+        value: T::Native,
+    ) -> (Vec<T::Native>, Option<NullBuffer>) {
+        let result_len = predicate.len();
+        let nulls = NullBuffer::new(predicate);
+        let scalars = vec![value; result_len];
+
+        (scalars, Some(nulls))
+    }
+}
+
+impl<T: ArrowPrimitiveType> ZipImpl for PrimitiveScalarImpl<T> {
+    fn create_output(&self, predicate: &BooleanArray) -> Result<ArrayRef, 
ArrowError> {
+        let result_len = predicate.len();
+        // Nulls are treated as false
+        let predicate = combine_nulls_and_false(predicate);
+
+        let (scalars, nulls): (Vec<T::Native>, Option<NullBuffer>) = match 
(self.truthy, self.falsy)
+        {
+            (Some(then_val), Some(else_val)) => {
+                let scalars: Vec<T::Native> = predicate
+                    .iter()
+                    .map(|b| if b { then_val } else { else_val })
+                    .collect();
+
+                (scalars, None)
+            }
+            (Some(then_val), None) => {
+                // If a value is true we need the TRUTHY and the null buffer 
will have 1 (meaning not null)
+                // If a value is false we need the FALSY and the null buffer 
will have 0 (meaning null)
+
+                
Self::get_scalar_and_null_buffer_for_single_non_nullable(predicate, then_val)
+            }
+            (None, Some(else_val)) => {
+                // Flipping the boolean buffer as we want the opposite of the 
THEN case
+                //
+                // if the condition is true we want null so we need to NOT the 
value so we get 0 (meaning null)
+                // if the condition is false we want the FALSY value so we 
need to NOT the value so we get 1 (meaning not null)
+                let predicate = predicate.not();
+
+                
Self::get_scalar_and_null_buffer_for_single_non_nullable(predicate, else_val)
+            }
+            (None, None) => {
+                // All values are null
+                let nulls = NullBuffer::new_null(result_len);
+                let scalars = vec![T::default_value(); result_len];
+
+                (scalars, Some(nulls))
+            }
+        };
+
+        let scalars = ScalarBuffer::<T::Native>::from(scalars);
+        let output = PrimitiveArray::<T>::try_new(scalars, nulls)?;
+
+        // Keep decimal precisions, scales or timestamps timezones
+        let output = output.with_data_type(self.data_type.clone());
+
+        Ok(Arc::new(output))
+    }
+}
+
+#[derive(PartialEq, Hash)]
+struct BytesScalarImpl<T: ByteArrayType> {
+    truthy: Option<Vec<u8>>,
+    falsy: Option<Vec<u8>>,
+    phantom: PhantomData<T>,
+}
+
+impl<T: ByteArrayType> Debug for BytesScalarImpl<T> {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("BytesScalarImpl")
+            .field("then_value", &self.truthy)
+            .field("else_value", &self.falsy)
+            .finish()
+    }
+}
+
+impl<T: ByteArrayType> BytesScalarImpl<T> {
+    fn new(then_value: &dyn Array, else_value: &dyn Array) -> Self {
+        Self {
+            truthy: Self::get_value_from_scalar(then_value),
+            falsy: Self::get_value_from_scalar(else_value),
+            phantom: PhantomData,
+        }
+    }
+
+    fn get_value_from_scalar(scalar: &dyn Array) -> Option<Vec<u8>> {
+        if scalar.is_null(0) {
+            None
+        } else {
+            let bytes: &[u8] = scalar.as_bytes::<T>().value(0).as_ref();
+
+            Some(bytes.to_vec())
+        }
+    }
+
+    fn get_scalar_and_null_buffer_for_single_non_nullable(
+        predicate: BooleanBuffer,
+        value: &[u8],
+    ) -> (Buffer, OffsetBuffer<T::Offset>, Option<NullBuffer>) {
+        let value_length = value.len();
+        let offsets = OffsetBuffer::<T::Offset>::from_lengths(
+            predicate.iter().map(|b| if b { value_length } else { 0 }),
+        );
+
+        let length = offsets.last().map(|o| o.as_usize()).unwrap_or(0);
+
+        let bytes_iter = predicate
+            .iter()
+            .flat_map(|b| if b { value } else { &[] })
+            .copied();
+
+        let bytes = unsafe {
+            // Safety: the iterator is trusted length as we limit it to the 
known length
+            MutableBuffer::from_trusted_len_iter(
+                bytes_iter
+                    // Limiting the bytes so the iterator will be trusted 
length
+                    .take(length),
+            )
+        };
+
+        // If a value is true we need the TRUTHY and the null buffer will have 
1 (meaning not null)
+        // If a value is false we need the FALSY and the null buffer will have 
0 (meaning null)
+        let nulls = NullBuffer::new(predicate);
+
+        (bytes.into(), offsets, Some(nulls))
+    }
+}
+
+impl<T: ByteArrayType> ZipImpl for BytesScalarImpl<T> {
+    fn create_output(&self, predicate: &BooleanArray) -> Result<ArrayRef, 
ArrowError> {
+        let result_len = predicate.len();
+        // Nulls are treated as false
+        let predicate = combine_nulls_and_false(predicate);
+
+        let (bytes, offsets, nulls): (Buffer, OffsetBuffer<T::Offset>, 
Option<NullBuffer>) =
+            match (self.truthy.as_deref(), self.falsy.as_deref()) {
+                (Some(then_val), Some(else_val)) => {
+                    let then_length = then_val.len();
+                    let else_length = else_val.len();
+                    let offsets = 
OffsetBuffer::<T::Offset>::from_lengths(predicate.iter().map(
+                        |b| {
+                            if b { then_length } else { else_length }
+                        },
+                    ));
+
+                    let length = offsets.last().map(|o| 
o.as_usize()).unwrap_or(0);
+
+                    let bytes_iter = predicate
+                        .iter()
+                        .flat_map(|b| if b { then_val } else { else_val })
+                        .copied();
+
+                    let bytes = unsafe {
+                        // Safety: the iterator is trusted length as we limit 
it to the known length
+                        MutableBuffer::from_trusted_len_iter(
+                            bytes_iter
+                                // Limiting the bytes so the iterator will be 
trusted length
+                                .take(length),
+                        )
+                    };
+
+                    (bytes.into(), offsets, None)
+                }
+                (Some(then_val), None) => {
+                    
Self::get_scalar_and_null_buffer_for_single_non_nullable(predicate, then_val)
+                }
+                (None, Some(else_val)) => {
+                    // Flipping the boolean buffer as we want the opposite of 
the THEN case
+                    //
+                    // if the condition is true we want null so we need to NOT 
the value so we get 0 (meaning null)
+                    // if the condition is false we want the ELSE value so we 
need to NOT the value so we get 1 (meaning not null)
+                    let predicate = predicate.not();
+                    
Self::get_scalar_and_null_buffer_for_single_non_nullable(predicate, else_val)
+                }
+                (None, None) => {
+                    // All values are null
+                    let nulls = NullBuffer::new_null(result_len);
+
+                    (
+                        // Empty bytes
+                        Buffer::from(&[]),
+                        // All nulls so all lengths are 0
+                        
OffsetBuffer::<T::Offset>::from_lengths(std::iter::repeat_n(0, result_len)),
+                        Some(nulls),
+                    )
+                }
+            };
+
+        let output = unsafe {
+            // Safety: the values are based on valid inputs
+            // and `try_new` is expensive for strings as it validate that the 
input is valid utf8
+            GenericByteArray::<T>::new_unchecked(offsets, Buffer::from(bytes), 
nulls)
+        };
+
+        Ok(Arc::new(output))
+    }
+}
+
+fn combine_nulls_and_false(predicate: &BooleanArray) -> BooleanBuffer {
+    if let Some(nulls) = predicate.nulls().filter(|n| n.null_count() > 0) {
+        predicate.values().bitand(
+            // nulls are represented as 0 (false) in the values buffer
+            nulls.inner(),
+        )
+    } else {
+        predicate.values().clone()
+    }
+}

Review Comment:
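   I'm pretty sure there is already a helper function in arrow for this (i.e. for what `combine_nulls_and_false` does: AND-ing the predicate's values with its validity bitmap so that null entries are treated as false).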



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to