rluvaton commented on code in PR #8653:
URL: https://github.com/apache/arrow-rs/pull/8653#discussion_r2443324095
##########
arrow-select/src/zip.rs:
##########
@@ -166,6 +195,353 @@ pub fn zip(
Ok(make_array(data))
}
+/// Zipper for 2 scalars
+///
+/// Useful for using in `IF <expr> THEN <scalar> ELSE <scalar> END` expressions
+///
+#[derive(Debug, Clone)]
+pub struct ScalarZipper {
+ zip_impl: Arc<dyn ZipImpl>,
+}
+
+impl ScalarZipper {
+ pub fn try_new(truthy: &dyn Datum, falsy: &dyn Datum) -> Result<Self,
ArrowError> {
+ let (truthy, truthy_is_scalar) = truthy.get();
+ let (falsy, falsy_is_scalar) = falsy.get();
+
+ if truthy.data_type() != falsy.data_type() {
+ return Err(ArrowError::InvalidArgumentError(
+ "arguments need to have the same data type".into(),
+ ));
+ }
+
+ if !truthy_is_scalar {
+ return Err(ArrowError::InvalidArgumentError(
+ "only scalar arrays are supported".into(),
+ ));
+ }
+
+ if !falsy_is_scalar {
+ return Err(ArrowError::InvalidArgumentError(
+ "only scalar arrays are supported".into(),
+ ));
+ }
+
+ if truthy.len() != 1 {
+ return Err(ArrowError::InvalidArgumentError(
+ "scalar arrays must have 1 element".into(),
+ ));
+ }
+ if falsy.len() != 1 {
+ return Err(ArrowError::InvalidArgumentError(
+ "scalar arrays must have 1 element".into(),
+ ));
+ }
+
+ macro_rules! primitive_size_helper {
+ ($t:ty) => {
+ Arc::new(PrimitiveScalarImpl::<$t>::new(truthy, falsy)) as
Arc<dyn ZipImpl>
+ };
+ }
+
+ let zip_impl = downcast_primitive! {
+ truthy.data_type() => (primitive_size_helper),
+ DataType::Utf8 => {
+ Arc::new(BytesScalarImpl::<Utf8Type>::new(truthy, falsy)) as
Arc<dyn ZipImpl>
+ },
+ DataType::LargeUtf8 => {
+ Arc::new(BytesScalarImpl::<LargeUtf8Type>::new(truthy, falsy))
as Arc<dyn ZipImpl>
+ },
+ DataType::Binary => {
+ Arc::new(BytesScalarImpl::<BinaryType>::new(truthy, falsy)) as
Arc<dyn ZipImpl>
+ },
+ DataType::LargeBinary => {
+ Arc::new(BytesScalarImpl::<LargeBinaryType>::new(truthy,
falsy)) as Arc<dyn ZipImpl>
+ },
+ _ => {
+ Arc::new(FallbackImpl::new(truthy, falsy)) as Arc<dyn ZipImpl>
+ },
+ };
+
+ Ok(Self { zip_impl })
+ }
+}
+
+/// Impl for creating output array based on input boolean array
+trait ZipImpl: Debug {
+ /// Creating output array based on input boolean array
+ fn create_output(&self, input: &BooleanArray) -> Result<ArrayRef,
ArrowError>;
+}
+
+#[derive(Debug, PartialEq)]
+struct FallbackImpl {
+ truthy: ArrayData,
+ falsy: ArrayData,
+}
+
+impl FallbackImpl {
+ fn new(left: &dyn Array, right: &dyn Array) -> Self {
+ Self {
+ truthy: left.to_data(),
+ falsy: right.to_data(),
+ }
+ }
+}
+
+impl ZipImpl for FallbackImpl {
+ fn create_output(&self, predicate: &BooleanArray) -> Result<ArrayRef,
ArrowError> {
+ zip_impl(predicate, &self.truthy, false, &self.falsy, false)
+ }
+}
+
+struct PrimitiveScalarImpl<T: ArrowPrimitiveType> {
+ data_type: DataType,
+ truthy: Option<T::Native>,
+ falsy: Option<T::Native>,
+}
+
+impl<T: ArrowPrimitiveType> Debug for PrimitiveScalarImpl<T> {
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("PrimitiveScalarImpl")
+ .field("data_type", &self.data_type)
+ .field("then_value", &self.truthy)
+ .field("else_value", &self.falsy)
+ .finish()
+ }
+}
+
+impl<T: ArrowPrimitiveType> PrimitiveScalarImpl<T> {
+ fn new(then_value: &dyn Array, else_value: &dyn Array) -> Self {
+ Self {
+ data_type: then_value.data_type().clone(),
+ truthy: Self::get_value_from_scalar(then_value),
+ falsy: Self::get_value_from_scalar(else_value),
+ }
+ }
+
+ fn get_value_from_scalar(scalar: &dyn Array) -> Option<T::Native> {
+ if scalar.is_null(0) {
+ None
+ } else {
+ let value = scalar.as_primitive::<T>().value(0);
+
+ Some(value)
+ }
+ }
+}
+
+impl<T: ArrowPrimitiveType> PrimitiveScalarImpl<T> {
+ fn get_scalar_and_null_buffer_for_single_non_nullable(
+ predicate: BooleanBuffer,
+ value: T::Native,
+ ) -> (Vec<T::Native>, Option<NullBuffer>) {
+ let result_len = predicate.len();
+ let nulls = NullBuffer::new(predicate);
+ let scalars = vec![value; result_len];
+
+ (scalars, Some(nulls))
+ }
+}
+
+impl<T: ArrowPrimitiveType> ZipImpl for PrimitiveScalarImpl<T> {
+ fn create_output(&self, predicate: &BooleanArray) -> Result<ArrayRef,
ArrowError> {
+ let result_len = predicate.len();
+ // Nulls are treated as false
+ let predicate = combine_nulls_and_false(predicate);
+
+ let (scalars, nulls): (Vec<T::Native>, Option<NullBuffer>) = match
(self.truthy, self.falsy)
+ {
+ (Some(then_val), Some(else_val)) => {
+ let scalars: Vec<T::Native> = predicate
+ .iter()
+ .map(|b| if b { then_val } else { else_val })
+ .collect();
Review Comment:
This should use a conditional move, which will be faster than `SlicesIterator`
in most cases.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]