This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 0d825c1 Define eq_dyn_scalar API (#1074)
0d825c1 is described below
commit 0d825c196e343805c7500bdc06af0c6e941e2577
Author: Matthew Turner <[email protected]>
AuthorDate: Sat Jan 1 07:06:08 2022 -0500
Define eq_dyn_scalar API (#1074)
* Squash
* Cleanup error messages
---
arrow/src/compute/kernels/comparison.rs | 374 +++++++++++++++++++++++++++++++-
1 file changed, 370 insertions(+), 4 deletions(-)
diff --git a/arrow/src/compute/kernels/comparison.rs
b/arrow/src/compute/kernels/comparison.rs
index f78588e..3e7a084 100644
--- a/arrow/src/compute/kernels/comparison.rs
+++ b/arrow/src/compute/kernels/comparison.rs
@@ -21,22 +21,24 @@
//! detection is provided, you should enable the specific SIMD intrinsics using
//! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation
//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information.
+//!
use crate::array::*;
use crate::buffer::{bitwise_bin_op_helper, buffer_unary_not, Buffer,
MutableBuffer};
use crate::compute::binary_boolean_kernel;
use crate::compute::util::combine_option_bitmap;
use crate::datatypes::{
- ArrowNumericType, DataType, Date32Type, Date64Type, Float32Type,
Float64Type,
- Int16Type, Int32Type, Int64Type, Int8Type, TimeUnit,
TimestampMicrosecondType,
- TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
UInt16Type,
- UInt32Type, UInt64Type, UInt8Type,
+ ArrowNativeType, ArrowNumericType, DataType, Date32Type, Date64Type,
Float32Type,
+ Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, TimeUnit,
+ TimestampMicrosecondType, TimestampMillisecondType,
TimestampNanosecondType,
+ TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
};
use crate::error::{ArrowError, Result};
use crate::util::bit_util;
use regex::{escape, Regex};
use std::any::type_name;
use std::collections::HashMap;
+use std::sync::Arc;
/// Helper function to perform boolean lambda function on values from two
arrays, this
/// version does not attempt to use SIMD.
@@ -888,6 +890,303 @@ pub fn gt_eq_utf8_scalar<OffsetSize:
StringOffsetSizeTrait>(
compare_op_scalar!(left, right, |a, b| a >= b)
}
+// Dispatches the scalar comparison kernel `$OP` over an array whose concrete
+// type is only known at runtime. Both arms use `?`, so the macro can only be
+// expanded inside a function that returns a `Result` carrying `ArrowError`.
+macro_rules! dyn_compare_scalar {
+ // Applies `LEFT OP RIGHT` when `LEFT` is a primitive integer array.
+ // `RIGHT` is first converted to `i128` and then to the native type selected
+ // by `LEFT`'s data type; a scalar that does not fit yields a `ComputeError`.
+ ($LEFT: expr, $RIGHT: expr, $OP: ident) => {{
+ let right: i128 = $RIGHT.try_into().map_err(|_| {
+ ArrowError::ComputeError(String::from("Can not convert scalar to
i128"))
+ })?;
+ match $LEFT.data_type() {
+ DataType::Int8 => {
+ let right: i8 = right.try_into().map_err(|_| {
+ ArrowError::ComputeError(String::from("Can not convert
scalar to i8"))
+ })?;
+ let left = as_primitive_array::<Int8Type>($LEFT);
+ $OP::<Int8Type>(left, right)
+ }
+ DataType::Int16 => {
+ let right: i16 = right.try_into().map_err(|_| {
+ ArrowError::ComputeError(String::from(
+ "Can not convert scalar to i16",
+ ))
+ })?;
+ let left = as_primitive_array::<Int16Type>($LEFT);
+ $OP::<Int16Type>(left, right)
+ }
+ DataType::Int32 => {
+ let right: i32 = right.try_into().map_err(|_| {
+ ArrowError::ComputeError(String::from(
+ "Can not convert scalar to i32",
+ ))
+ })?;
+ let left = as_primitive_array::<Int32Type>($LEFT);
+ $OP::<Int32Type>(left, right)
+ }
+ DataType::Int64 => {
+ let right: i64 = right.try_into().map_err(|_| {
+ ArrowError::ComputeError(String::from(
+ "Can not convert scalar to i64",
+ ))
+ })?;
+ let left = as_primitive_array::<Int64Type>($LEFT);
+ $OP::<Int64Type>(left, right)
+ }
+ DataType::UInt8 => {
+ let right: u8 = right.try_into().map_err(|_| {
+ ArrowError::ComputeError(String::from("Can not convert
scalar to u8"))
+ })?;
+ let left = as_primitive_array::<UInt8Type>($LEFT);
+ $OP::<UInt8Type>(left, right)
+ }
+ DataType::UInt16 => {
+ let right: u16 = right.try_into().map_err(|_| {
+ ArrowError::ComputeError(String::from(
+ "Can not convert scalar to u16",
+ ))
+ })?;
+ let left = as_primitive_array::<UInt16Type>($LEFT);
+ $OP::<UInt16Type>(left, right)
+ }
+ DataType::UInt32 => {
+ let right: u32 = right.try_into().map_err(|_| {
+ ArrowError::ComputeError(String::from(
+ "Can not convert scalar to u32",
+ ))
+ })?;
+ let left = as_primitive_array::<UInt32Type>($LEFT);
+ $OP::<UInt32Type>(left, right)
+ }
+ DataType::UInt64 => {
+ let right: u64 = right.try_into().map_err(|_| {
+ ArrowError::ComputeError(String::from(
+ "Can not convert scalar to u64",
+ ))
+ })?;
+ let left = as_primitive_array::<UInt64Type>($LEFT);
+ $OP::<UInt64Type>(left, right)
+ }
+ // Every non-integer data type is rejected rather than compared.
+ _ => Err(ArrowError::ComputeError(String::from(
+ "Unsupported data type",
+ ))),
+ }
+ }};
+ // Applies `LEFT OP RIGHT` when `LEFT` is a `DictionaryArray` with keys of
+ // type `KT`. The three-argument arm above is run on the dictionary's values
+ // and `unpack_dict_comparison` expands that result to one entry per key.
+ ($LEFT: expr, $RIGHT: expr, $KT: ident, $OP: ident) => {{
+ let right: i128 = $RIGHT.try_into().map_err(|_| {
+ ArrowError::ComputeError(String::from("Can not convert scalar to
i128"))
+ })?;
+ match $KT.as_ref() {
+ DataType::UInt8 => {
+ let left = as_dictionary_array::<UInt8Type>($LEFT);
+ unpack_dict_comparison(
+ left,
+ dyn_compare_scalar!(left.values(), right, $OP)?,
+ )
+ }
+ DataType::UInt16 => {
+ let left = as_dictionary_array::<UInt16Type>($LEFT);
+ unpack_dict_comparison(
+ left,
+ dyn_compare_scalar!(left.values(), right, $OP)?,
+ )
+ }
+ DataType::UInt32 => {
+ let left = as_dictionary_array::<UInt32Type>($LEFT);
+ unpack_dict_comparison(
+ left,
+ dyn_compare_scalar!(left.values(), right, $OP)?,
+ )
+ }
+ DataType::UInt64 => {
+ let left = as_dictionary_array::<UInt64Type>($LEFT);
+ unpack_dict_comparison(
+ left,
+ dyn_compare_scalar!(left.values(), right, $OP)?,
+ )
+ }
+ DataType::Int8 => {
+ let left = as_dictionary_array::<Int8Type>($LEFT);
+ unpack_dict_comparison(
+ left,
+ dyn_compare_scalar!(left.values(), right, $OP)?,
+ )
+ }
+ DataType::Int16 => {
+ let left = as_dictionary_array::<Int16Type>($LEFT);
+ unpack_dict_comparison(
+ left,
+ dyn_compare_scalar!(left.values(), right, $OP)?,
+ )
+ }
+ DataType::Int32 => {
+ let left = as_dictionary_array::<Int32Type>($LEFT);
+ unpack_dict_comparison(
+ left,
+ dyn_compare_scalar!(left.values(), right, $OP)?,
+ )
+ }
+ DataType::Int64 => {
+ let left = as_dictionary_array::<Int64Type>($LEFT);
+ unpack_dict_comparison(
+ left,
+ dyn_compare_scalar!(left.values(), right, $OP)?,
+ )
+ }
+ // Only the eight integer key types above are dispatched.
+ _ => Err(ArrowError::ComputeError(String::from("Unknown key
type"))),
+ }
+ }};
+}
+
+/// Applies the string comparison kernel `$OP` to `$LEFT`, a `DictionaryArray`
+/// whose key type is described by `$KT`, against the string scalar `$RIGHT`.
+///
+/// The kernel runs once over the dictionary's values and the per-value result
+/// is then expanded to one entry per key by `unpack_dict_comparison`. Both
+/// `Utf8` and `LargeUtf8` values are handled. The expansion uses `?`, so it
+/// must appear inside a function that returns a `Result` carrying `ArrowError`.
+macro_rules! dyn_compare_utf8_scalar {
+    // Internal rule: compares the values of the already downcast dictionary
+    // `$DICT`. `Utf8` and `LargeUtf8` values are different concrete array
+    // types, so the downcast is chosen from the values' data type instead of
+    // always using `as_string_array`.
+    (@unpack $DICT: expr, $RIGHT: expr, $OP: ident) => {{
+        let dict = $DICT;
+        let comparison = match dict.values().data_type() {
+            DataType::LargeUtf8 => $OP(as_largestring_array(dict.values()), $RIGHT)?,
+            _ => $OP(as_string_array(dict.values()), $RIGHT)?,
+        };
+        unpack_dict_comparison(dict, comparison)
+    }};
+    // Entry point: downcasts `$LEFT` according to the dictionary key type.
+    ($LEFT: expr, $RIGHT: expr, $KT: ident, $OP: ident) => {{
+        match $KT.as_ref() {
+            DataType::UInt8 => {
+                let left = as_dictionary_array::<UInt8Type>($LEFT);
+                dyn_compare_utf8_scalar!(@unpack left, $RIGHT, $OP)
+            }
+            DataType::UInt16 => {
+                let left = as_dictionary_array::<UInt16Type>($LEFT);
+                dyn_compare_utf8_scalar!(@unpack left, $RIGHT, $OP)
+            }
+            DataType::UInt32 => {
+                let left = as_dictionary_array::<UInt32Type>($LEFT);
+                dyn_compare_utf8_scalar!(@unpack left, $RIGHT, $OP)
+            }
+            DataType::UInt64 => {
+                let left = as_dictionary_array::<UInt64Type>($LEFT);
+                dyn_compare_utf8_scalar!(@unpack left, $RIGHT, $OP)
+            }
+            DataType::Int8 => {
+                let left = as_dictionary_array::<Int8Type>($LEFT);
+                dyn_compare_utf8_scalar!(@unpack left, $RIGHT, $OP)
+            }
+            DataType::Int16 => {
+                let left = as_dictionary_array::<Int16Type>($LEFT);
+                dyn_compare_utf8_scalar!(@unpack left, $RIGHT, $OP)
+            }
+            DataType::Int32 => {
+                let left = as_dictionary_array::<Int32Type>($LEFT);
+                dyn_compare_utf8_scalar!(@unpack left, $RIGHT, $OP)
+            }
+            DataType::Int64 => {
+                let left = as_dictionary_array::<Int64Type>($LEFT);
+                dyn_compare_utf8_scalar!(@unpack left, $RIGHT, $OP)
+            }
+            // Only the eight integer key types above are dispatched.
+            _ => Err(ArrowError::ComputeError(String::from("Unknown key type"))),
+        }
+    }};
+}
+
+/// Perform `left == right` operation on an array and a numeric scalar
+/// value. Supports signed and unsigned integer arrays (`Int8` to `Int64`,
+/// `UInt8` to `UInt64`) and DictionaryArrays whose values have one of those
+/// types.
+///
+/// `right` is converted to `i128` and then to the array's native integer
+/// type before the comparison runs.
+///
+/// # Errors
+/// Returns `ArrowError::ComputeError` when `left` has any other data type, or
+/// when `right` cannot be converted to the array's native integer type.
+pub fn eq_dyn_scalar<T>(left: Arc<dyn Array>, right: T) -> Result<BooleanArray>
+where
+ T: TryInto<i128> + Copy + std::fmt::Debug,
+{
+ match left.data_type() {
+ // Dictionary input: check the value type here; the macro then dispatches
+ // on the key type.
+ DataType::Dictionary(key_type, value_type) => match
value_type.as_ref() {
+ DataType::Int8
+ | DataType::Int16
+ | DataType::Int32
+ | DataType::Int64
+ | DataType::UInt8
+ | DataType::UInt16
+ | DataType::UInt32
+ | DataType::UInt64 => {dyn_compare_scalar!(&left, right, key_type,
eq_scalar)}
+ _ => Err(ArrowError::ComputeError(
+ "Kernel only supports PrimitiveArray or DictionaryArray with
Primitive values".to_string(),
+ ))
+ }
+ // Plain integer array: the macro dispatches on the array's own data type.
+ DataType::Int8
+ | DataType::Int16
+ | DataType::Int32
+ | DataType::Int64
+ | DataType::UInt8
+ | DataType::UInt16
+ | DataType::UInt32
+ | DataType::UInt64 => {
+ dyn_compare_scalar!(&left, right, eq_scalar)
+ }
+ _ => Err(ArrowError::ComputeError(
+ "Kernel only supports PrimitiveArray or DictionaryArray with
Primitive values".to_string(),
+ ))
+ }
+}
+
+/// Perform `left == right` operation on an array and a string scalar
+/// value. Supports `Utf8` and `LargeUtf8` string arrays, and DictionaryArrays
+/// that have string values.
+///
+/// # Errors
+/// Returns `ArrowError::ComputeError` when `left` is neither a string array
+/// nor a DictionaryArray with string values.
+pub fn eq_dyn_utf8_scalar(left: Arc<dyn Array>, right: &str) -> Result<BooleanArray> {
+    match left.data_type() {
+        DataType::Dictionary(key_type, value_type) => match value_type.as_ref() {
+            DataType::Utf8 | DataType::LargeUtf8 => {
+                dyn_compare_utf8_scalar!(&left, right, key_type, eq_utf8_scalar)
+            }
+            _ => Err(ArrowError::ComputeError(
+                "Kernel only supports Utf8 or LargeUtf8 arrays or DictionaryArray with Utf8 or LargeUtf8 values".to_string(),
+            )),
+        },
+        // `Utf8` and `LargeUtf8` are different concrete array types (32-bit
+        // and 64-bit offsets), so each needs its own downcast; using
+        // `as_string_array` on a `LargeUtf8` array cannot succeed.
+        DataType::Utf8 => eq_utf8_scalar(as_string_array(&left), right),
+        DataType::LargeUtf8 => eq_utf8_scalar(as_largestring_array(&left), right),
+        _ => Err(ArrowError::ComputeError(
+            "Kernel only supports Utf8 or LargeUtf8 arrays".to_string(),
+        )),
+    }
+}
+
+/// Perform `left == right` operation on an array and a boolean scalar
+/// value. Supports BooleanArrays only.
+///
+/// # Errors
+/// Returns `ArrowError::ComputeError` when `left` is not a `Boolean` array.
+pub fn eq_dyn_bool_scalar(left: Arc<dyn Array>, right: bool) -> Result<BooleanArray> {
+    match left.data_type() {
+        DataType::Boolean => eq_bool_scalar(as_boolean_array(&left), right),
+        _ => Err(ArrowError::ComputeError(
+            "Kernel only supports BooleanArray".to_string(),
+        )),
+    }
+}
+
+/// Expands a comparison computed over a dictionary's values into one result
+/// per dictionary key.
+///
+/// `dict_comparison` must hold exactly one boolean per entry of
+/// `dict.values()`. Each non-null key selects the boolean stored at that
+/// index; a null key produces a null in the output.
+///
+/// For example, with keys `[0, null, 1]` and a `dict_comparison` of
+/// `[true, false]`, the result is `[true, null, false]`.
+///
+/// # Panics
+/// Panics if `dict_comparison` and `dict.values()` differ in length, or if a
+/// key cannot be converted to `usize`.
+fn unpack_dict_comparison<K>(
+    dict: &DictionaryArray<K>,
+    dict_comparison: BooleanArray,
+) -> Result<BooleanArray>
+where
+    K: ArrowNumericType,
+{
+    assert_eq!(dict_comparison.len(), dict.values().len());
+
+    let result: BooleanArray = dict
+        .keys()
+        .iter()
+        .map(|key| {
+            key.map(|key| {
+                let index = key.to_usize().expect("Dictionary index not usize");
+                // SAFETY: `dict_comparison` has one entry per dictionary value
+                // (asserted above). This relies on every key of `dict` being a
+                // valid index into its values, which is not re-checked here.
+                unsafe { dict_comparison.value_unchecked(index) }
+            })
+        })
+        .collect();
+
+    Ok(result)
+}
+
/// Helper function to perform boolean lambda function on values from two
arrays using
/// SIMD.
#[cfg(feature = "simd")]
@@ -2646,4 +2945,71 @@ mod tests {
regexp_is_match_utf8_scalar,
vec![true, true, false, false]
);
+    #[test]
+    fn test_eq_dyn_scalar() {
+        // A plain Int32Array compared against 8: only the two 8s match.
+        let input = Arc::new(Int32Array::from(vec![6, 7, 8, 8, 10]));
+        let expected = BooleanArray::from(vec![
+            Some(false),
+            Some(false),
+            Some(true),
+            Some(true),
+            Some(false),
+        ]);
+
+        let actual = eq_dyn_scalar(input, 8).unwrap();
+
+        assert_eq!(actual, expected);
+    }
+    #[test]
+    fn test_eq_dyn_scalar_with_dict() {
+        // Int8 keys over Int32 values, holding [123, null, 23].
+        let mut dict_builder = PrimitiveDictionaryBuilder::new(
+            PrimitiveBuilder::<Int8Type>::new(3),
+            PrimitiveBuilder::<Int32Type>::new(2),
+        );
+        dict_builder.append(123).unwrap();
+        dict_builder.append_null().unwrap();
+        dict_builder.append(23).unwrap();
+        let dict = Arc::new(dict_builder.finish());
+
+        // The null slot must stay null in the comparison result.
+        let expected = BooleanArray::from(vec![Some(true), None, Some(false)]);
+        let actual = eq_dyn_scalar(dict, 123).unwrap();
+
+        assert_eq!(actual, expected);
+    }
+    #[test]
+    fn test_eq_dyn_utf8_scalar() {
+        // A plain StringArray compared against "xyz": only the last slot matches.
+        let input = Arc::new(StringArray::from(vec!["abc", "def", "xyz"]));
+        let expected = BooleanArray::from(vec![Some(false), Some(false), Some(true)]);
+
+        let actual = eq_dyn_utf8_scalar(input, "xyz").unwrap();
+
+        assert_eq!(actual, expected);
+    }
+    #[test]
+    fn test_eq_dyn_utf8_scalar_with_dict() {
+        // Int8 keys over string values, holding ["abc", null, "def", "def", "abc"].
+        let mut dict_builder = StringDictionaryBuilder::new(
+            PrimitiveBuilder::<Int8Type>::new(3),
+            StringBuilder::new(100),
+        );
+        dict_builder.append("abc").unwrap();
+        dict_builder.append_null().unwrap();
+        dict_builder.append("def").unwrap();
+        dict_builder.append("def").unwrap();
+        dict_builder.append("abc").unwrap();
+        let dict = Arc::new(dict_builder.finish());
+
+        // Repeated values share a dictionary entry; the null slot stays null.
+        let expected = BooleanArray::from(vec![
+            Some(false),
+            None,
+            Some(true),
+            Some(true),
+            Some(false),
+        ]);
+        let actual = eq_dyn_utf8_scalar(dict, "def").unwrap();
+
+        assert_eq!(actual, expected);
+    }
+
+    #[test]
+    fn test_eq_dyn_bool_scalar() {
+        // Comparing against `false` inverts each input value.
+        let input = Arc::new(BooleanArray::from(vec![true, false, true]));
+        let expected = BooleanArray::from(vec![Some(false), Some(true), Some(false)]);
+
+        let actual = eq_dyn_bool_scalar(input, false).unwrap();
+
+        assert_eq!(actual, expected);
+    }
}