This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 0d825c1  Define eq_dyn_scalar API (#1074)
0d825c1 is described below

commit 0d825c196e343805c7500bdc06af0c6e941e2577
Author: Matthew Turner <[email protected]>
AuthorDate: Sat Jan 1 07:06:08 2022 -0500

    Define eq_dyn_scalar API (#1074)
    
    * Squash
    
    * Cleanup error messages
---
 arrow/src/compute/kernels/comparison.rs | 374 +++++++++++++++++++++++++++++++-
 1 file changed, 370 insertions(+), 4 deletions(-)

diff --git a/arrow/src/compute/kernels/comparison.rs 
b/arrow/src/compute/kernels/comparison.rs
index f78588e..3e7a084 100644
--- a/arrow/src/compute/kernels/comparison.rs
+++ b/arrow/src/compute/kernels/comparison.rs
@@ -21,22 +21,24 @@
 //! detection is provided, you should enable the specific SIMD intrinsics using
 //! `RUSTFLAGS="-C target-feature=+avx2"` for example.  See the documentation
 //! [here](https://doc.rust-lang.org/stable/core/arch/) for more information.
+//!
 
 use crate::array::*;
 use crate::buffer::{bitwise_bin_op_helper, buffer_unary_not, Buffer, 
MutableBuffer};
 use crate::compute::binary_boolean_kernel;
 use crate::compute::util::combine_option_bitmap;
 use crate::datatypes::{
-    ArrowNumericType, DataType, Date32Type, Date64Type, Float32Type, 
Float64Type,
-    Int16Type, Int32Type, Int64Type, Int8Type, TimeUnit, 
TimestampMicrosecondType,
-    TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, 
UInt16Type,
-    UInt32Type, UInt64Type, UInt8Type,
+    ArrowNativeType, ArrowNumericType, DataType, Date32Type, Date64Type, 
Float32Type,
+    Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, TimeUnit,
+    TimestampMicrosecondType, TimestampMillisecondType, 
TimestampNanosecondType,
+    TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
 };
 use crate::error::{ArrowError, Result};
 use crate::util::bit_util;
 use regex::{escape, Regex};
 use std::any::type_name;
 use std::collections::HashMap;
+use std::sync::Arc;
 
 /// Helper function to perform boolean lambda function on values from two 
arrays, this
 /// version does not attempt to use SIMD.
@@ -888,6 +890,303 @@ pub fn gt_eq_utf8_scalar<OffsetSize: 
StringOffsetSizeTrait>(
     compare_op_scalar!(left, right, |a, b| a >= b)
 }
 
+macro_rules! dyn_compare_scalar {
+    // Applies `LEFT OP RIGHT` when `LEFT` is a `PrimitiveArray`
+    ($LEFT: expr, $RIGHT: expr, $OP: ident) => {{
+        let right: i128 = $RIGHT.try_into().map_err(|_| {
+            ArrowError::ComputeError(String::from("Can not convert scalar to 
i128"))
+        })?;
+        match $LEFT.data_type() {
+            DataType::Int8 => {
+                let right: i8 = right.try_into().map_err(|_| {
+                    ArrowError::ComputeError(String::from("Can not convert 
scalar to i8"))
+                })?;
+                let left = as_primitive_array::<Int8Type>($LEFT);
+                $OP::<Int8Type>(left, right)
+            }
+            DataType::Int16 => {
+                let right: i16 = right.try_into().map_err(|_| {
+                    ArrowError::ComputeError(String::from(
+                        "Can not convert scalar to i16",
+                    ))
+                })?;
+                let left = as_primitive_array::<Int16Type>($LEFT);
+                $OP::<Int16Type>(left, right)
+            }
+            DataType::Int32 => {
+                let right: i32 = right.try_into().map_err(|_| {
+                    ArrowError::ComputeError(String::from(
+                        "Can not convert scalar to i32",
+                    ))
+                })?;
+                let left = as_primitive_array::<Int32Type>($LEFT);
+                $OP::<Int32Type>(left, right)
+            }
+            DataType::Int64 => {
+                let right: i64 = right.try_into().map_err(|_| {
+                    ArrowError::ComputeError(String::from(
+                        "Can not convert scalar to i64",
+                    ))
+                })?;
+                let left = as_primitive_array::<Int64Type>($LEFT);
+                $OP::<Int64Type>(left, right)
+            }
+            DataType::UInt8 => {
+                let right: u8 = right.try_into().map_err(|_| {
+                    ArrowError::ComputeError(String::from("Can not convert 
scalar to u8"))
+                })?;
+                let left = as_primitive_array::<UInt8Type>($LEFT);
+                $OP::<UInt8Type>(left, right)
+            }
+            DataType::UInt16 => {
+                let right: u16 = right.try_into().map_err(|_| {
+                    ArrowError::ComputeError(String::from(
+                        "Can not convert scalar to u16",
+                    ))
+                })?;
+                let left = as_primitive_array::<UInt16Type>($LEFT);
+                $OP::<UInt16Type>(left, right)
+            }
+            DataType::UInt32 => {
+                let right: u32 = right.try_into().map_err(|_| {
+                    ArrowError::ComputeError(String::from(
+                        "Can not convert scalar to u32",
+                    ))
+                })?;
+                let left = as_primitive_array::<UInt32Type>($LEFT);
+                $OP::<UInt32Type>(left, right)
+            }
+            DataType::UInt64 => {
+                let right: u64 = right.try_into().map_err(|_| {
+                    ArrowError::ComputeError(String::from(
+                        "Can not convert scalar to u64",
+                    ))
+                })?;
+                let left = as_primitive_array::<UInt64Type>($LEFT);
+                $OP::<UInt64Type>(left, right)
+            }
+            _ => Err(ArrowError::ComputeError(String::from(
+                "Unsupported data type",
+            ))),
+        }
+    }};
+    // Applies `LEFT OP RIGHT` when `LEFT` is a `DictionaryArray` with keys of 
type `KT`
+    ($LEFT: expr, $RIGHT: expr, $KT: ident, $OP: ident) => {{
+        let right: i128 = $RIGHT.try_into().map_err(|_| {
+            ArrowError::ComputeError(String::from("Can not convert scalar to 
i128"))
+        })?;
+        match $KT.as_ref() {
+            DataType::UInt8 => {
+                let left = as_dictionary_array::<UInt8Type>($LEFT);
+                unpack_dict_comparison(
+                    left,
+                    dyn_compare_scalar!(left.values(), right, $OP)?,
+                )
+            }
+            DataType::UInt16 => {
+                let left = as_dictionary_array::<UInt16Type>($LEFT);
+                unpack_dict_comparison(
+                    left,
+                    dyn_compare_scalar!(left.values(), right, $OP)?,
+                )
+            }
+            DataType::UInt32 => {
+                let left = as_dictionary_array::<UInt32Type>($LEFT);
+                unpack_dict_comparison(
+                    left,
+                    dyn_compare_scalar!(left.values(), right, $OP)?,
+                )
+            }
+            DataType::UInt64 => {
+                let left = as_dictionary_array::<UInt64Type>($LEFT);
+                unpack_dict_comparison(
+                    left,
+                    dyn_compare_scalar!(left.values(), right, $OP)?,
+                )
+            }
+            DataType::Int8 => {
+                let left = as_dictionary_array::<Int8Type>($LEFT);
+                unpack_dict_comparison(
+                    left,
+                    dyn_compare_scalar!(left.values(), right, $OP)?,
+                )
+            }
+            DataType::Int16 => {
+                let left = as_dictionary_array::<Int16Type>($LEFT);
+                unpack_dict_comparison(
+                    left,
+                    dyn_compare_scalar!(left.values(), right, $OP)?,
+                )
+            }
+            DataType::Int32 => {
+                let left = as_dictionary_array::<Int32Type>($LEFT);
+                unpack_dict_comparison(
+                    left,
+                    dyn_compare_scalar!(left.values(), right, $OP)?,
+                )
+            }
+            DataType::Int64 => {
+                let left = as_dictionary_array::<Int64Type>($LEFT);
+                unpack_dict_comparison(
+                    left,
+                    dyn_compare_scalar!(left.values(), right, $OP)?,
+                )
+            }
+            _ => Err(ArrowError::ComputeError(String::from("Unknown key 
type"))),
+        }
+    }};
+}
+
+macro_rules! dyn_compare_utf8_scalar {
+    ($LEFT: expr, $RIGHT: expr, $KT: ident, $OP: ident) => {{
+        match $KT.as_ref() {
+            DataType::UInt8 => {
+                let left = as_dictionary_array::<UInt8Type>($LEFT);
+                let values = as_string_array(left.values());
+                unpack_dict_comparison(left, $OP(values, $RIGHT)?)
+            }
+            DataType::UInt16 => {
+                let left = as_dictionary_array::<UInt16Type>($LEFT);
+                let values = as_string_array(left.values());
+                unpack_dict_comparison(left, $OP(values, $RIGHT)?)
+            }
+            DataType::UInt32 => {
+                let left = as_dictionary_array::<UInt32Type>($LEFT);
+                let values = as_string_array(left.values());
+                unpack_dict_comparison(left, $OP(values, $RIGHT)?)
+            }
+            DataType::UInt64 => {
+                let left = as_dictionary_array::<UInt64Type>($LEFT);
+                let values = as_string_array(left.values());
+                unpack_dict_comparison(left, $OP(values, $RIGHT)?)
+            }
+            DataType::Int8 => {
+                let left = as_dictionary_array::<Int8Type>($LEFT);
+                let values = as_string_array(left.values());
+                unpack_dict_comparison(left, $OP(values, $RIGHT)?)
+            }
+            DataType::Int16 => {
+                let left = as_dictionary_array::<Int16Type>($LEFT);
+                let values = as_string_array(left.values());
+                unpack_dict_comparison(left, $OP(values, $RIGHT)?)
+            }
+            DataType::Int32 => {
+                let left = as_dictionary_array::<Int32Type>($LEFT);
+                let values = as_string_array(left.values());
+                unpack_dict_comparison(left, $OP(values, $RIGHT)?)
+            }
+            DataType::Int64 => {
+                let left = as_dictionary_array::<Int64Type>($LEFT);
+                let values = as_string_array(left.values());
+                unpack_dict_comparison(left, $OP(values, $RIGHT)?)
+            }
+            _ => Err(ArrowError::ComputeError(String::from("Unknown key 
type"))),
+        }
+    }};
+}
+
+/// Perform `left == right` operation on an array and an integer scalar
+/// value. Supports integer PrimitiveArrays, and DictionaryArrays that have integer values
+pub fn eq_dyn_scalar<T>(left: Arc<dyn Array>, right: T) -> Result<BooleanArray>
+where
+    T: TryInto<i128> + Copy + std::fmt::Debug,
+{
+    match left.data_type() {
+        DataType::Dictionary(key_type, value_type) => match 
value_type.as_ref() {
+            DataType::Int8
+            | DataType::Int16
+            | DataType::Int32
+            | DataType::Int64
+            | DataType::UInt8
+            | DataType::UInt16
+            | DataType::UInt32
+            | DataType::UInt64 => {dyn_compare_scalar!(&left, right, key_type, 
eq_scalar)}
+            _ => Err(ArrowError::ComputeError(
+                "Kernel only supports PrimitiveArray or DictionaryArray with 
Primitive values".to_string(),
+            ))
+        }
+        DataType::Int8
+        | DataType::Int16
+        | DataType::Int32
+        | DataType::Int64
+        | DataType::UInt8
+        | DataType::UInt16
+        | DataType::UInt32
+        | DataType::UInt64 => {
+            dyn_compare_scalar!(&left, right, eq_scalar)
+        }
+        _ => Err(ArrowError::ComputeError(
+            "Kernel only supports PrimitiveArray or DictionaryArray with 
Primitive values".to_string(),
+        ))
+    }
+}
+
+/// Perform `left == right` operation on an array and a string scalar
+/// value. Supports StringArrays, and DictionaryArrays that have string values
+pub fn eq_dyn_utf8_scalar(left: Arc<dyn Array>, right: &str) -> 
Result<BooleanArray> {
+    let result = match left.data_type() {
+        DataType::Dictionary(key_type, value_type) => match 
value_type.as_ref() {
+            DataType::Utf8 | DataType::LargeUtf8 => {
+                dyn_compare_utf8_scalar!(&left, right, key_type, 
eq_utf8_scalar)
+            }
+            _ => Err(ArrowError::ComputeError(
+                "Kernel only supports Utf8 or LargeUtf8 arrays or 
DictionaryArray with Utf8 or LargeUtf8 values".to_string(),
+            )),
+        },
+        DataType::Utf8 | DataType::LargeUtf8 => {
+            let left = as_string_array(&left);
+            eq_utf8_scalar(left, right)
+        }
+        _ => Err(ArrowError::ComputeError(
+            "Kernel only supports Utf8 or LargeUtf8 arrays".to_string(),
+        )),
+    };
+    result
+}
+
+/// Perform `left == right` operation on an array and a boolean scalar
+/// value. Supports BooleanArrays only; any other array type returns an error
+pub fn eq_dyn_bool_scalar(left: Arc<dyn Array>, right: bool) -> 
Result<BooleanArray> {
+    let result = match left.data_type() {
+        DataType::Boolean => {
+            let left = as_boolean_array(&left);
+            eq_bool_scalar(left, right)
+        }
+        _ => Err(ArrowError::ComputeError(
+            "Kernel only supports BooleanArray".to_string(),
+        )),
+    };
+    result
+}
+
+/// Maps a comparison computed on the dictionary's values back to one result
+/// per key: a non-null key selects the boolean at that index, a null key yields null.
+///
+/// TODO add example
+fn unpack_dict_comparison<K>(
+    dict: &DictionaryArray<K>,
+    dict_comparison: BooleanArray,
+) -> Result<BooleanArray>
+where
+    K: ArrowNumericType,
+{
+    assert_eq!(dict_comparison.len(), dict.values().len());
+
+    let result: BooleanArray = dict
+        .keys()
+        .iter()
+        .map(|key| {
+            key.map(|key| unsafe {
+                // SAFETY: one result per dictionary value (asserted above); keys assumed in bounds
+                let key = key.to_usize().expect("Dictionary index not usize");
+                dict_comparison.value_unchecked(key)
+            })
+        })
+        .collect();
+
+    Ok(result)
+}
+
 /// Helper function to perform boolean lambda function on values from two 
arrays using
 /// SIMD.
 #[cfg(feature = "simd")]
@@ -2646,4 +2945,71 @@ mod tests {
         regexp_is_match_utf8_scalar,
         vec![true, true, false, false]
     );
+    #[test]
+    fn test_eq_dyn_scalar() {
+        let array = Int32Array::from(vec![6, 7, 8, 8, 10]);
+        let array = Arc::new(array);
+        let a_eq = eq_dyn_scalar(array, 8).unwrap();
+        assert_eq!(
+            a_eq,
+            BooleanArray::from(
+                vec![Some(false), Some(false), Some(true), Some(true), 
Some(false)]
+            )
+        );
+    }
+    #[test]
+    fn test_eq_dyn_scalar_with_dict() {
+        let key_builder = PrimitiveBuilder::<Int8Type>::new(3);
+        let value_builder = PrimitiveBuilder::<Int32Type>::new(2);
+        let mut builder = PrimitiveDictionaryBuilder::new(key_builder, 
value_builder);
+        builder.append(123).unwrap();
+        builder.append_null().unwrap();
+        builder.append(23).unwrap();
+        let array = Arc::new(builder.finish());
+        let a_eq = eq_dyn_scalar(array, 123).unwrap();
+        assert_eq!(
+            a_eq,
+            BooleanArray::from(vec![Some(true), None, Some(false)])
+        );
+    }
+    #[test]
+    fn test_eq_dyn_utf8_scalar() {
+        let array = StringArray::from(vec!["abc", "def", "xyz"]);
+        let array = Arc::new(array);
+        let a_eq = eq_dyn_utf8_scalar(array, "xyz").unwrap();
+        assert_eq!(
+            a_eq,
+            BooleanArray::from(vec![Some(false), Some(false), Some(true)])
+        );
+    }
+    #[test]
+    fn test_eq_dyn_utf8_scalar_with_dict() {
+        let key_builder = PrimitiveBuilder::<Int8Type>::new(3);
+        let value_builder = StringBuilder::new(100);
+        let mut builder = StringDictionaryBuilder::new(key_builder, 
value_builder);
+        builder.append("abc").unwrap();
+        builder.append_null().unwrap();
+        builder.append("def").unwrap();
+        builder.append("def").unwrap();
+        builder.append("abc").unwrap();
+        let array = Arc::new(builder.finish());
+        let a_eq = eq_dyn_utf8_scalar(array, "def").unwrap();
+        assert_eq!(
+            a_eq,
+            BooleanArray::from(
+                vec![Some(false), None, Some(true), Some(true), Some(false)]
+            )
+        );
+    }
+
+    #[test]
+    fn test_eq_dyn_bool_scalar() {
+        let array = BooleanArray::from(vec![true, false, true]);
+        let array = Arc::new(array);
+        let a_eq = eq_dyn_bool_scalar(array, false).unwrap();
+        assert_eq!(
+            a_eq,
+            BooleanArray::from(vec![Some(false), Some(true), Some(false)])
+        );
+    }
 }

Reply via email to