This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch active_release
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/active_release by this push:
new b42649b implement eq_dyn and neq_dyn (#858) (#867)
b42649b is described below
commit b42649b0088fe7762c713a41a23c1abdf8d0496d
Author: Andrew Lamb <[email protected]>
AuthorDate: Wed Oct 27 08:44:25 2021 -0400
implement eq_dyn and neq_dyn (#858) (#867)
Co-authored-by: Jiayu Liu <[email protected]>
---
arrow/src/compute/kernels/comparison.rs | 188 +++++++++++++++++++++++++++++---
1 file changed, 171 insertions(+), 17 deletions(-)
diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs
index 81827b0..1f0cb1a 100644
--- a/arrow/src/compute/kernels/comparison.rs
+++ b/arrow/src/compute/kernels/comparison.rs
@@ -22,16 +22,19 @@
//! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation
//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information.
-use regex::Regex;
-use std::collections::HashMap;
-
use crate::array::*;
use crate::buffer::{bitwise_bin_op_helper, buffer_unary_not, Buffer,
MutableBuffer};
use crate::compute::binary_boolean_kernel;
use crate::compute::util::combine_option_bitmap;
-use crate::datatypes::{ArrowNumericType, DataType};
+use crate::datatypes::{
+ ArrowNumericType, DataType, Float32Type, Float64Type, Int16Type, Int32Type,
+ Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
+};
use crate::error::{ArrowError, Result};
use crate::util::bit_util;
+use regex::Regex;
+use std::any::type_name;
+use std::collections::HashMap;
/// Helper function to perform boolean lambda function on values from two arrays, this
/// version does not attempt to use SIMD.
@@ -974,7 +977,142 @@ where
Ok(BooleanArray::from(data))
}
-/// Perform `left == right` operation on two arrays.
+/// Downcasts both `$LEFT` and `$RIGHT` to the concrete array type `$T` and
+/// applies the comparison kernel `$OP` to the two typed references.
+///
+/// The optional trailing `$TT` is forwarded as the kernel's generic argument,
+/// i.e. the call becomes `$OP::<$TT>(left, right)` (for example `Int32Type`
+/// for `eq`, or `i32` for `eq_utf8`); without it the call is `$OP(left, right)`.
+///
+/// If either input is not actually a `$T`, the enclosing function returns early
+/// (via `?`) with an `ArrowError::CastError` naming the side that failed.
+macro_rules! typed_cmp {
+    ($LEFT: expr, $RIGHT: expr, $T: ident, $OP: ident $(, $TT: tt)?) => {{
+        let left = $LEFT.as_any().downcast_ref::<$T>().ok_or_else(|| {
+            ArrowError::CastError(format!(
+                "Left array cannot be cast to {}",
+                type_name::<$T>()
+            ))
+        })?;
+        let right = $RIGHT.as_any().downcast_ref::<$T>().ok_or_else(|| {
+            ArrowError::CastError(format!(
+                "Right array cannot be cast to {}",
+                type_name::<$T>()
+            ))
+        })?;
+        // A single arm serves both the plain and the turbofish call forms.
+        $OP $(::<$TT>)? (left, right)
+    }};
+}
+
+/// Compares two `&dyn Array`s by dispatching on their data types to a typed
+/// kernel:
+/// * `$OP_BOOL` for `Boolean`;
+/// * `$OP_PRIM::<T>` for the integer and floating point types;
+/// * `$OP_STR::<O>` for `Utf8` / `LargeUtf8`, with `O` = `i32` / `i64`.
+///
+/// Instead of comparing, the match evaluates to an error when:
+/// * both arrays share a data type that has no arm below
+///   (`ArrowError::NotYetImplemented`);
+/// * the arrays have different data types (`ArrowError::CastError`).
+macro_rules! typed_compares {
+    ($LEFT: expr, $RIGHT: expr, $OP_BOOL: ident, $OP_PRIM: ident, $OP_STR: ident) => {{
+        match ($LEFT.data_type(), $RIGHT.data_type()) {
+            (DataType::Boolean, DataType::Boolean) => {
+                typed_cmp!($LEFT, $RIGHT, BooleanArray, $OP_BOOL)
+            }
+            (DataType::Int8, DataType::Int8) => {
+                typed_cmp!($LEFT, $RIGHT, Int8Array, $OP_PRIM, Int8Type)
+            }
+            (DataType::Int16, DataType::Int16) => {
+                typed_cmp!($LEFT, $RIGHT, Int16Array, $OP_PRIM, Int16Type)
+            }
+            (DataType::Int32, DataType::Int32) => {
+                typed_cmp!($LEFT, $RIGHT, Int32Array, $OP_PRIM, Int32Type)
+            }
+            (DataType::Int64, DataType::Int64) => {
+                typed_cmp!($LEFT, $RIGHT, Int64Array, $OP_PRIM, Int64Type)
+            }
+            (DataType::UInt8, DataType::UInt8) => {
+                typed_cmp!($LEFT, $RIGHT, UInt8Array, $OP_PRIM, UInt8Type)
+            }
+            (DataType::UInt16, DataType::UInt16) => {
+                typed_cmp!($LEFT, $RIGHT, UInt16Array, $OP_PRIM, UInt16Type)
+            }
+            (DataType::UInt32, DataType::UInt32) => {
+                typed_cmp!($LEFT, $RIGHT, UInt32Array, $OP_PRIM, UInt32Type)
+            }
+            (DataType::UInt64, DataType::UInt64) => {
+                typed_cmp!($LEFT, $RIGHT, UInt64Array, $OP_PRIM, UInt64Type)
+            }
+            (DataType::Float32, DataType::Float32) => {
+                typed_cmp!($LEFT, $RIGHT, Float32Array, $OP_PRIM, Float32Type)
+            }
+            (DataType::Float64, DataType::Float64) => {
+                typed_cmp!($LEFT, $RIGHT, Float64Array, $OP_PRIM, Float64Type)
+            }
+            (DataType::Utf8, DataType::Utf8) => {
+                typed_cmp!($LEFT, $RIGHT, StringArray, $OP_STR, i32)
+            }
+            (DataType::LargeUtf8, DataType::LargeUtf8) => {
+                typed_cmp!($LEFT, $RIGHT, LargeStringArray, $OP_STR, i64)
+            }
+            // Same type on both sides, but no kernel wired up for it above.
+            (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!(
+                "Comparing arrays of type {} is not yet implemented",
+                t1
+            ))),
+            (t1, t2) => Err(ArrowError::CastError(format!(
+                "Cannot compare two arrays of different types ({} and {})",
+                t1, t2
+            ))),
+        }
+    }};
+}
+
+/// Perform `left == right` operation on two (dynamic) [`Array`]s.
+///
+/// Both arrays must have the same [`DataType`]. The supported types are
+/// `Boolean`, the signed and unsigned integer types, `Float32`, `Float64`,
+/// `Utf8` and `LargeUtf8`.
+///
+/// # Errors
+///
+/// Returns [`ArrowError::CastError`] if the two arrays have different data
+/// types, and [`ArrowError::NotYetImplemented`] if they share an unsupported
+/// data type. Errors from the underlying typed kernel are propagated.
+pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
+    typed_compares!(left, right, eq_bool, eq, eq_utf8)
+}
+
+/// Perform `left != right` operation on two (dynamic) [`Array`]s.
+///
+/// Both arrays must have the same [`DataType`]. The supported types are
+/// `Boolean`, the signed and unsigned integer types, `Float32`, `Float64`,
+/// `Utf8` and `LargeUtf8`.
+///
+/// # Errors
+///
+/// Returns [`ArrowError::CastError`] if the two arrays have different data
+/// types, and [`ArrowError::NotYetImplemented`] if they share an unsupported
+/// data type. Errors from the underlying typed kernel are propagated.
+pub fn neq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
+    typed_compares!(left, right, neq_bool, neq, neq_utf8)
+}
+
+/// Perform `left < right` operation on two (dynamic) [`Array`]s.
+///
+/// Both arrays must have the same [`DataType`]. The supported types are
+/// `Boolean`, the signed and unsigned integer types, `Float32`, `Float64`,
+/// `Utf8` and `LargeUtf8`.
+///
+/// # Errors
+///
+/// Returns [`ArrowError::CastError`] if the two arrays have different data
+/// types, and [`ArrowError::NotYetImplemented`] if they share an unsupported
+/// data type. Errors from the underlying typed kernel are propagated.
+pub fn lt_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
+    typed_compares!(left, right, lt_bool, lt, lt_utf8)
+}
+
+/// Perform `left <= right` operation on two (dynamic) [`Array`]s.
+///
+/// Both arrays must have the same [`DataType`]. The supported types are
+/// `Boolean`, the signed and unsigned integer types, `Float32`, `Float64`,
+/// `Utf8` and `LargeUtf8`.
+///
+/// # Errors
+///
+/// Returns [`ArrowError::CastError`] if the two arrays have different data
+/// types, and [`ArrowError::NotYetImplemented`] if they share an unsupported
+/// data type. Errors from the underlying typed kernel are propagated.
+pub fn lt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
+    typed_compares!(left, right, lt_eq_bool, lt_eq, lt_eq_utf8)
+}
+
+/// Perform `left > right` operation on two (dynamic) [`Array`]s.
+///
+/// Both arrays must have the same [`DataType`]. The supported types are
+/// `Boolean`, the signed and unsigned integer types, `Float32`, `Float64`,
+/// `Utf8` and `LargeUtf8`.
+///
+/// # Errors
+///
+/// Returns [`ArrowError::CastError`] if the two arrays have different data
+/// types, and [`ArrowError::NotYetImplemented`] if they share an unsupported
+/// data type. Errors from the underlying typed kernel are propagated.
+pub fn gt_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
+    typed_compares!(left, right, gt_bool, gt, gt_utf8)
+}
+
+/// Perform `left >= right` operation on two (dynamic) [`Array`]s.
+///
+/// Both arrays must have the same [`DataType`]. The supported types are
+/// `Boolean`, the signed and unsigned integer types, `Float32`, `Float64`,
+/// `Utf8` and `LargeUtf8`.
+///
+/// # Errors
+///
+/// Returns [`ArrowError::CastError`] if the two arrays have different data
+/// types, and [`ArrowError::NotYetImplemented`] if they share an unsupported
+/// data type. Errors from the underlying typed kernel are propagated.
+pub fn gt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
+    typed_compares!(left, right, gt_eq_bool, gt_eq, gt_eq_utf8)
+}
+
+/// Perform `left == right` operation on two [`PrimitiveArray`]s.
pub fn eq<T>(left: &PrimitiveArray<T>, right: &PrimitiveArray<T>) ->
Result<BooleanArray>
where
T: ArrowNumericType,
@@ -985,7 +1123,7 @@ where
return compare_op!(left, right, |a, b| a == b);
}
-/// Perform `left == right` operation on an array and a scalar value.
+/// Perform `left == right` operation on a [`PrimitiveArray`] and a scalar value.
pub fn eq_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) ->
Result<BooleanArray>
where
T: ArrowNumericType,
@@ -996,7 +1134,7 @@ where
return compare_op_scalar!(left, right, |a, b| a == b);
}
-/// Perform `left != right` operation on two arrays.
+/// Perform `left != right` operation on two [`PrimitiveArray`]s.
pub fn neq<T>(left: &PrimitiveArray<T>, right: &PrimitiveArray<T>) ->
Result<BooleanArray>
where
T: ArrowNumericType,
@@ -1007,7 +1145,7 @@ where
return compare_op!(left, right, |a, b| a != b);
}
-/// Perform `left != right` operation on an array and a scalar value.
+/// Perform `left != right` operation on a [`PrimitiveArray`] and a scalar value.
pub fn neq_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) ->
Result<BooleanArray>
where
T: ArrowNumericType,
@@ -1018,7 +1156,7 @@ where
return compare_op_scalar!(left, right, |a, b| a != b);
}
-/// Perform `left < right` operation on two arrays. Null values are less than non-null
+/// Perform `left < right` operation on two [`PrimitiveArray`]s. Null values are less than non-null
/// values.
pub fn lt<T>(left: &PrimitiveArray<T>, right: &PrimitiveArray<T>) ->
Result<BooleanArray>
where
@@ -1030,7 +1168,7 @@ where
return compare_op!(left, right, |a, b| a < b);
}
-/// Perform `left < right` operation on an array and a scalar value.
+/// Perform `left < right` operation on a [`PrimitiveArray`] and a scalar value.
/// Null values are less than non-null values.
pub fn lt_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) ->
Result<BooleanArray>
where
@@ -1042,7 +1180,7 @@ where
return compare_op_scalar!(left, right, |a, b| a < b);
}
-/// Perform `left <= right` operation on two arrays. Null values are less than non-null
+/// Perform `left <= right` operation on two [`PrimitiveArray`]s. Null values are less than non-null
/// values.
pub fn lt_eq<T>(
left: &PrimitiveArray<T>,
@@ -1057,7 +1195,7 @@ where
return compare_op!(left, right, |a, b| a <= b);
}
-/// Perform `left <= right` operation on an array and a scalar value.
+/// Perform `left <= right` operation on a [`PrimitiveArray`] and a scalar value.
/// Null values are less than non-null values.
pub fn lt_eq_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) ->
Result<BooleanArray>
where
@@ -1069,7 +1207,7 @@ where
return compare_op_scalar!(left, right, |a, b| a <= b);
}
-/// Perform `left > right` operation on two arrays. Non-null values are greater than null
+/// Perform `left > right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null
/// values.
pub fn gt<T>(left: &PrimitiveArray<T>, right: &PrimitiveArray<T>) ->
Result<BooleanArray>
where
@@ -1081,7 +1219,7 @@ where
return compare_op!(left, right, |a, b| a > b);
}
-/// Perform `left > right` operation on an array and a scalar value.
+/// Perform `left > right` operation on a [`PrimitiveArray`] and a scalar value.
/// Non-null values are greater than null values.
pub fn gt_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) ->
Result<BooleanArray>
where
@@ -1093,7 +1231,7 @@ where
return compare_op_scalar!(left, right, |a, b| a > b);
}
-/// Perform `left >= right` operation on two arrays. Non-null values are greater than null
+/// Perform `left >= right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null
/// values.
pub fn gt_eq<T>(
left: &PrimitiveArray<T>,
@@ -1108,7 +1246,7 @@ where
return compare_op!(left, right, |a, b| a >= b);
}
-/// Perform `left >= right` operation on an array and a scalar value.
+/// Perform `left >= right` operation on a [`PrimitiveArray`] and a scalar value.
/// Non-null values are greater than null values.
pub fn gt_eq_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) ->
Result<BooleanArray>
where
@@ -1260,11 +1398,17 @@ mod tests {
/// `EXPECTED` can be either `Vec<bool>` or `Vec<Option<bool>>`.
/// The main reason for this macro is that inputs and outputs align nicely
after `cargo fmt`.
macro_rules! cmp_i64 {
-    ($KERNEL:ident, $A_VEC:expr, $B_VEC:expr, $EXPECTED:expr) => {
+    ($KERNEL:ident, $DYN_KERNEL:ident, $A_VEC:expr, $B_VEC:expr, $EXPECTED:expr) => {
        let a = Int64Array::from($A_VEC);
        let b = Int64Array::from($B_VEC);
        let c = $KERNEL(&a, &b).unwrap();
        assert_eq!(BooleanArray::from($EXPECTED), c);
+
+        // Run the same inputs through the `*_dyn` kernel. The full-length,
+        // zero-offset `slice` only serves to obtain type-erased arrays whose
+        // `as_ref()` is a `&dyn Array`; it does not exercise non-zero offsets.
+        let a = a.slice(0, a.len());
+        let b = b.slice(0, b.len());
+        let c = $DYN_KERNEL(a.as_ref(), b.as_ref()).unwrap();
+        assert_eq!(BooleanArray::from($EXPECTED), c);
    };
}
@@ -1284,6 +1428,7 @@ mod tests {
fn test_primitive_array_eq() {
cmp_i64!(
eq,
+ eq_dyn,
vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
vec![false, false, true, false, false, false, false, true, false,
false]
@@ -1330,6 +1475,7 @@ mod tests {
fn test_primitive_array_neq() {
cmp_i64!(
neq,
+ neq_dyn,
vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
vec![true, true, false, true, true, true, true, false, true, true]
@@ -1479,6 +1625,7 @@ mod tests {
fn test_primitive_array_lt() {
cmp_i64!(
lt,
+ lt_dyn,
vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
vec![false, false, false, true, true, false, false, false, true,
true]
@@ -1499,6 +1646,7 @@ mod tests {
fn test_primitive_array_lt_nulls() {
cmp_i64!(
lt,
+ lt_dyn,
vec![None, None, Some(1), Some(1), None, None, Some(2), Some(2),],
vec![None, Some(1), None, Some(1), None, Some(3), None, Some(3),],
vec![None, None, None, Some(false), None, None, None, Some(true)]
@@ -1519,6 +1667,7 @@ mod tests {
fn test_primitive_array_lt_eq() {
cmp_i64!(
lt_eq,
+ lt_eq_dyn,
vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
vec![false, false, true, true, true, false, false, true, true,
true]
@@ -1539,6 +1688,7 @@ mod tests {
fn test_primitive_array_lt_eq_nulls() {
cmp_i64!(
lt_eq,
+ lt_eq_dyn,
vec![None, None, Some(1), None, None, Some(1), None, None,
Some(1)],
vec![None, Some(1), Some(0), None, Some(1), Some(2), None, None,
Some(3)],
vec![None, None, Some(false), None, None, Some(true), None, None,
Some(true)]
@@ -1559,6 +1709,7 @@ mod tests {
fn test_primitive_array_gt() {
cmp_i64!(
gt,
+ gt_dyn,
vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
vec![true, true, false, false, false, true, true, false, false,
false]
@@ -1579,6 +1730,7 @@ mod tests {
fn test_primitive_array_gt_nulls() {
cmp_i64!(
gt,
+ gt_dyn,
vec![None, None, Some(1), None, None, Some(2), None, None,
Some(3)],
vec![None, Some(1), Some(1), None, Some(1), Some(1), None,
Some(1), Some(1)],
vec![None, None, Some(false), None, None, Some(true), None, None,
Some(true)]
@@ -1599,6 +1751,7 @@ mod tests {
fn test_primitive_array_gt_eq() {
cmp_i64!(
gt_eq,
+ gt_eq_dyn,
vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
vec![true, true, true, false, false, true, true, true, false,
false]
@@ -1619,6 +1772,7 @@ mod tests {
fn test_primitive_array_gt_eq_nulls() {
cmp_i64!(
gt_eq,
+ gt_eq_dyn,
vec![None, None, Some(1), None, Some(1), Some(2), None, None,
Some(1)],
vec![None, Some(1), None, None, Some(1), Some(1), None, Some(2),
Some(2)],
vec![None, None, None, None, Some(true), Some(true), None, None,
Some(false)]