This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch active_release
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/active_release by this push:
     new b42649b  implement eq_dyn and neq_dyn (#858) (#867)
b42649b is described below

commit b42649b0088fe7762c713a41a23c1abdf8d0496d
Author: Andrew Lamb <[email protected]>
AuthorDate: Wed Oct 27 08:44:25 2021 -0400

    implement eq_dyn and neq_dyn (#858) (#867)
    
    Co-authored-by: Jiayu Liu <[email protected]>
---
 arrow/src/compute/kernels/comparison.rs | 188 +++++++++++++++++++++++++++++---
 1 file changed, 171 insertions(+), 17 deletions(-)

diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs
index 81827b0..1f0cb1a 100644
--- a/arrow/src/compute/kernels/comparison.rs
+++ b/arrow/src/compute/kernels/comparison.rs
@@ -22,16 +22,19 @@
 //! `RUSTFLAGS="-C target-feature=+avx2"` for example.  See the documentation
 //! [here](https://doc.rust-lang.org/stable/core/arch/) for more information.
 
-use regex::Regex;
-use std::collections::HashMap;
-
 use crate::array::*;
 use crate::buffer::{bitwise_bin_op_helper, buffer_unary_not, Buffer, MutableBuffer};
 use crate::compute::binary_boolean_kernel;
 use crate::compute::util::combine_option_bitmap;
-use crate::datatypes::{ArrowNumericType, DataType};
+use crate::datatypes::{
+    ArrowNumericType, DataType, Float32Type, Float64Type, Int16Type, Int32Type,
+    Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
+};
 use crate::error::{ArrowError, Result};
 use crate::util::bit_util;
+use regex::Regex;
+use std::any::type_name;
+use std::collections::HashMap;
 
 /// Helper function to perform boolean lambda function on values from two arrays, this
 /// version does not attempt to use SIMD.
@@ -974,7 +977,142 @@ where
     Ok(BooleanArray::from(data))
 }
 
-/// Perform `left == right` operation on two arrays.
+macro_rules! typed_cmp {
+    ($LEFT: expr, $RIGHT: expr, $T: ident, $OP: ident) => {{
+        let left = $LEFT.as_any().downcast_ref::<$T>().ok_or_else(|| {
+            ArrowError::CastError(format!(
+                "Left array cannot be cast to {}",
+                type_name::<$T>()
+            ))
+        })?;
+        let right = $RIGHT.as_any().downcast_ref::<$T>().ok_or_else(|| {
+            ArrowError::CastError(format!(
+                "Right array cannot be cast to {}",
+                type_name::<$T>(),
+            ))
+        })?;
+        $OP(left, right)
+    }};
+    ($LEFT: expr, $RIGHT: expr, $T: ident, $OP: ident, $TT: tt) => {{
+        let left = $LEFT.as_any().downcast_ref::<$T>().ok_or_else(|| {
+            ArrowError::CastError(format!(
+                "Left array cannot be cast to {}",
+                type_name::<$T>()
+            ))
+        })?;
+        let right = $RIGHT.as_any().downcast_ref::<$T>().ok_or_else(|| {
+            ArrowError::CastError(format!(
+                "Right array cannot be cast to {}",
+                type_name::<$T>(),
+            ))
+        })?;
+        $OP::<$TT>(left, right)
+    }};
+}
+
+macro_rules! typed_compares {
+    ($LEFT: expr, $RIGHT: expr, $OP_BOOL: ident, $OP_PRIM: ident, $OP_STR: ident) => {{
+        match ($LEFT.data_type(), $RIGHT.data_type()) {
+            (DataType::Boolean, DataType::Boolean) => {
+                typed_cmp!($LEFT, $RIGHT, BooleanArray, $OP_BOOL)
+            }
+            (DataType::Int8, DataType::Int8) => {
+                typed_cmp!($LEFT, $RIGHT, Int8Array, $OP_PRIM, Int8Type)
+            }
+            (DataType::Int16, DataType::Int16) => {
+                typed_cmp!($LEFT, $RIGHT, Int16Array, $OP_PRIM, Int16Type)
+            }
+            (DataType::Int32, DataType::Int32) => {
+                typed_cmp!($LEFT, $RIGHT, Int32Array, $OP_PRIM, Int32Type)
+            }
+            (DataType::Int64, DataType::Int64) => {
+                typed_cmp!($LEFT, $RIGHT, Int64Array, $OP_PRIM, Int64Type)
+            }
+            (DataType::UInt8, DataType::UInt8) => {
+                typed_cmp!($LEFT, $RIGHT, UInt8Array, $OP_PRIM, UInt8Type)
+            }
+            (DataType::UInt16, DataType::UInt16) => {
+                typed_cmp!($LEFT, $RIGHT, UInt16Array, $OP_PRIM, UInt16Type)
+            }
+            (DataType::UInt32, DataType::UInt32) => {
+                typed_cmp!($LEFT, $RIGHT, UInt32Array, $OP_PRIM, UInt32Type)
+            }
+            (DataType::UInt64, DataType::UInt64) => {
+                typed_cmp!($LEFT, $RIGHT, UInt64Array, $OP_PRIM, UInt64Type)
+            }
+            (DataType::Float32, DataType::Float32) => {
+                typed_cmp!($LEFT, $RIGHT, Float32Array, $OP_PRIM, Float32Type)
+            }
+            (DataType::Float64, DataType::Float64) => {
+                typed_cmp!($LEFT, $RIGHT, Float64Array, $OP_PRIM, Float64Type)
+            }
+            (DataType::Utf8, DataType::Utf8) => {
+                typed_cmp!($LEFT, $RIGHT, StringArray, $OP_STR, i32)
+            }
+            (DataType::LargeUtf8, DataType::LargeUtf8) => {
+                typed_cmp!($LEFT, $RIGHT, LargeStringArray, $OP_STR, i64)
+            }
+            (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!(
+                "Comparing arrays of type {} is not yet implemented",
+                t1
+            ))),
+            (t1, t2) => Err(ArrowError::CastError(format!(
+                "Cannot compare two arrays of different types ({} and {})",
+                t1, t2
+            ))),
+        }
+    }};
+}
+
+/// Perform `left == right` operation on two (dynamic) [`Array`]s.
+///
+/// Only when two arrays are of the same type the comparison will happen otherwise it will err
+/// with a casting error.
+pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
+    typed_compares!(left, right, eq_bool, eq, eq_utf8)
+}
+
+/// Perform `left != right` operation on two (dynamic) [`Array`]s.
+///
+/// Only when two arrays are of the same type the comparison will happen otherwise it will err
+/// with a casting error.
+pub fn neq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
+    typed_compares!(left, right, neq_bool, neq, neq_utf8)
+}
+
+/// Perform `left < right` operation on two (dynamic) [`Array`]s.
+///
+/// Only when two arrays are of the same type the comparison will happen otherwise it will err
+/// with a casting error.
+pub fn lt_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
+    typed_compares!(left, right, lt_bool, lt, lt_utf8)
+}
+
+/// Perform `left <= right` operation on two (dynamic) [`Array`]s.
+///
+/// Only when two arrays are of the same type the comparison will happen otherwise it will err
+/// with a casting error.
+pub fn lt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
+    typed_compares!(left, right, lt_eq_bool, lt_eq, lt_eq_utf8)
+}
+
+/// Perform `left > right` operation on two (dynamic) [`Array`]s.
+///
+/// Only when two arrays are of the same type the comparison will happen otherwise it will err
+/// with a casting error.
+pub fn gt_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
+    typed_compares!(left, right, gt_bool, gt, gt_utf8)
+}
+
+/// Perform `left >= right` operation on two (dynamic) [`Array`]s.
+///
+/// Only when two arrays are of the same type the comparison will happen otherwise it will err
+/// with a casting error.
+pub fn gt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
+    typed_compares!(left, right, gt_eq_bool, gt_eq, gt_eq_utf8)
+}
+
+/// Perform `left == right` operation on two [`PrimitiveArray`]s.
 pub fn eq<T>(left: &PrimitiveArray<T>, right: &PrimitiveArray<T>) -> Result<BooleanArray>
 where
     T: ArrowNumericType,
@@ -985,7 +1123,7 @@ where
     return compare_op!(left, right, |a, b| a == b);
 }
 
-/// Perform `left == right` operation on an array and a scalar value.
+/// Perform `left == right` operation on a [`PrimitiveArray`] and a scalar value.
 pub fn eq_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
 where
     T: ArrowNumericType,
@@ -996,7 +1134,7 @@ where
     return compare_op_scalar!(left, right, |a, b| a == b);
 }
 
-/// Perform `left != right` operation on two arrays.
+/// Perform `left != right` operation on two [`PrimitiveArray`]s.
 pub fn neq<T>(left: &PrimitiveArray<T>, right: &PrimitiveArray<T>) -> Result<BooleanArray>
 where
     T: ArrowNumericType,
@@ -1007,7 +1145,7 @@ where
     return compare_op!(left, right, |a, b| a != b);
 }
 
-/// Perform `left != right` operation on an array and a scalar value.
+/// Perform `left != right` operation on a [`PrimitiveArray`] and a scalar value.
 pub fn neq_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
 where
     T: ArrowNumericType,
@@ -1018,7 +1156,7 @@ where
     return compare_op_scalar!(left, right, |a, b| a != b);
 }
 
-/// Perform `left < right` operation on two arrays. Null values are less than non-null
+/// Perform `left < right` operation on two [`PrimitiveArray`]s. Null values are less than non-null
 /// values.
 pub fn lt<T>(left: &PrimitiveArray<T>, right: &PrimitiveArray<T>) -> Result<BooleanArray>
 where
@@ -1030,7 +1168,7 @@ where
     return compare_op!(left, right, |a, b| a < b);
 }
 
-/// Perform `left < right` operation on an array and a scalar value.
+/// Perform `left < right` operation on a [`PrimitiveArray`] and a scalar value.
 /// Null values are less than non-null values.
 pub fn lt_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
 where
@@ -1042,7 +1180,7 @@ where
     return compare_op_scalar!(left, right, |a, b| a < b);
 }
 
-/// Perform `left <= right` operation on two arrays. Null values are less than non-null
+/// Perform `left <= right` operation on two [`PrimitiveArray`]s. Null values are less than non-null
 /// values.
 pub fn lt_eq<T>(
     left: &PrimitiveArray<T>,
@@ -1057,7 +1195,7 @@ where
     return compare_op!(left, right, |a, b| a <= b);
 }
 
-/// Perform `left <= right` operation on an array and a scalar value.
+/// Perform `left <= right` operation on a [`PrimitiveArray`] and a scalar value.
 /// Null values are less than non-null values.
 pub fn lt_eq_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
 where
@@ -1069,7 +1207,7 @@ where
     return compare_op_scalar!(left, right, |a, b| a <= b);
 }
 
-/// Perform `left > right` operation on two arrays. Non-null values are greater than null
+/// Perform `left > right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null
 /// values.
 pub fn gt<T>(left: &PrimitiveArray<T>, right: &PrimitiveArray<T>) -> Result<BooleanArray>
 where
@@ -1081,7 +1219,7 @@ where
     return compare_op!(left, right, |a, b| a > b);
 }
 
-/// Perform `left > right` operation on an array and a scalar value.
+/// Perform `left > right` operation on a [`PrimitiveArray`] and a scalar value.
 /// Non-null values are greater than null values.
 pub fn gt_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
 where
@@ -1093,7 +1231,7 @@ where
     return compare_op_scalar!(left, right, |a, b| a > b);
 }
 
-/// Perform `left >= right` operation on two arrays. Non-null values are greater than null
+/// Perform `left >= right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null
 /// values.
 pub fn gt_eq<T>(
     left: &PrimitiveArray<T>,
@@ -1108,7 +1246,7 @@ where
     return compare_op!(left, right, |a, b| a >= b);
 }
 
-/// Perform `left >= right` operation on an array and a scalar value.
+/// Perform `left >= right` operation on a [`PrimitiveArray`] and a scalar value.
 /// Non-null values are greater than null values.
 pub fn gt_eq_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
 where
@@ -1260,11 +1398,17 @@ mod tests {
     /// `EXPECTED` can be either `Vec<bool>` or `Vec<Option<bool>>`.
     /// The main reason for this macro is that inputs and outputs align nicely after `cargo fmt`.
     macro_rules! cmp_i64 {
-        ($KERNEL:ident, $A_VEC:expr, $B_VEC:expr, $EXPECTED:expr) => {
+        ($KERNEL:ident, $DYN_KERNEL:ident, $A_VEC:expr, $B_VEC:expr, $EXPECTED:expr) => {
             let a = Int64Array::from($A_VEC);
             let b = Int64Array::from($B_VEC);
             let c = $KERNEL(&a, &b).unwrap();
             assert_eq!(BooleanArray::from($EXPECTED), c);
+
+            // slice and test if the dynamic array works
+            let a = a.slice(0, a.len());
+            let b = b.slice(0, b.len());
+            let c = $DYN_KERNEL(a.as_ref(), b.as_ref()).unwrap();
+            assert_eq!(BooleanArray::from($EXPECTED), c);
         };
     }
 
@@ -1284,6 +1428,7 @@ mod tests {
     fn test_primitive_array_eq() {
         cmp_i64!(
             eq,
+            eq_dyn,
             vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
             vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
             vec![false, false, true, false, false, false, false, true, false, false]
@@ -1330,6 +1475,7 @@ mod tests {
     fn test_primitive_array_neq() {
         cmp_i64!(
             neq,
+            neq_dyn,
             vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
             vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
             vec![true, true, false, true, true, true, true, false, true, true]
@@ -1479,6 +1625,7 @@ mod tests {
     fn test_primitive_array_lt() {
         cmp_i64!(
             lt,
+            lt_dyn,
             vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
             vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
             vec![false, false, false, true, true, false, false, false, true, true]
@@ -1499,6 +1646,7 @@ mod tests {
     fn test_primitive_array_lt_nulls() {
         cmp_i64!(
             lt,
+            lt_dyn,
             vec![None, None, Some(1), Some(1), None, None, Some(2), Some(2),],
             vec![None, Some(1), None, Some(1), None, Some(3), None, Some(3),],
             vec![None, None, None, Some(false), None, None, None, Some(true)]
@@ -1519,6 +1667,7 @@ mod tests {
     fn test_primitive_array_lt_eq() {
         cmp_i64!(
             lt_eq,
+            lt_eq_dyn,
             vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
             vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
             vec![false, false, true, true, true, false, false, true, true, true]
@@ -1539,6 +1688,7 @@ mod tests {
     fn test_primitive_array_lt_eq_nulls() {
         cmp_i64!(
             lt_eq,
+            lt_eq_dyn,
             vec![None, None, Some(1), None, None, Some(1), None, None, Some(1)],
             vec![None, Some(1), Some(0), None, Some(1), Some(2), None, None, Some(3)],
             vec![None, None, Some(false), None, None, Some(true), None, None, Some(true)]
@@ -1559,6 +1709,7 @@ mod tests {
     fn test_primitive_array_gt() {
         cmp_i64!(
             gt,
+            gt_dyn,
             vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
             vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
             vec![true, true, false, false, false, true, true, false, false, false]
@@ -1579,6 +1730,7 @@ mod tests {
     fn test_primitive_array_gt_nulls() {
         cmp_i64!(
             gt,
+            gt_dyn,
             vec![None, None, Some(1), None, None, Some(2), None, None, Some(3)],
             vec![None, Some(1), Some(1), None, Some(1), Some(1), None, Some(1), Some(1)],
             vec![None, None, Some(false), None, None, Some(true), None, None, Some(true)]
@@ -1599,6 +1751,7 @@ mod tests {
     fn test_primitive_array_gt_eq() {
         cmp_i64!(
             gt_eq,
+            gt_eq_dyn,
             vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
             vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10],
             vec![true, true, true, false, false, true, true, true, false, false]
@@ -1619,6 +1772,7 @@ mod tests {
     fn test_primitive_array_gt_eq_nulls() {
         cmp_i64!(
             gt_eq,
+            gt_eq_dyn,
             vec![None, None, Some(1), None, Some(1), Some(2), None, None, Some(1)],
             vec![None, Some(1), None, None, Some(1), Some(1), None, Some(2), Some(2)],
             vec![None, None, None, None, Some(true), Some(true), None, None, Some(false)]

Reply via email to