alamb commented on a change in pull request #1074:
URL: https://github.com/apache/arrow-rs/pull/1074#discussion_r774535495



##########
File path: arrow/src/compute/kernels/comparison.rs
##########
@@ -898,6 +899,224 @@ pub fn gt_eq_utf8_scalar<OffsetSize: 
StringOffsetSizeTrait>(
     compare_op_scalar!(left, right, |a, b| a >= b)
 }
 
+macro_rules! dyn_compare_scalar {
+    ($LEFT: expr, $RIGHT: expr, $OP: ident) => {{
+        let right: i128 = $RIGHT.try_into().map_err(|_| {
+            ArrowError::ComputeError(String::from("Can not convert scalar to 
i128"))
+        })?;
+        match ($LEFT.data_type(), $RIGHT.get_arrow_type()) {
+            // (DataType::Boolean, DataType::Boolean) => {
+            //     typed_cmp_scalar!($LEFT, $RIGHT, BooleanArray, $OP_BOOL)
+            // }
+            (DataType::Int8, DataType::Int8) => {
+                let right: i8 = right.try_into().map_err(|_| {
+                    ArrowError::ComputeError(String::from(
+                        "Can not convert scalar to i128",
+                    ))
+                })?;
+                let left =
+                    $LEFT.as_any().downcast_ref::<Int8Array>().ok_or_else(|| {
+                        ArrowError::CastError(String::from("Left array cannot 
be cast"))
+                    })?;
+                $OP::<Int8Type>(left, right)
+            }
+            (DataType::Int16, DataType::Int16) => {
+                let right: i16 = right.try_into().map_err(|_| {
+                    ArrowError::ComputeError(String::from(
+                        "Can not convert scalar to i128",
+                    ))
+                })?;
+                let left =
+                    $LEFT.as_any().downcast_ref::<Int16Array>().ok_or_else(|| {
+                        ArrowError::CastError(String::from("Left array cannot 
be cast"))
+                    })?;
+                $OP::<Int16Type>(left, right)
+            }
+            (DataType::Int32, DataType::Int32) => {
+                let right: i32 = right.try_into().map_err(|_| {
+                    ArrowError::ComputeError(String::from(
+                        "Can not convert scalar to i128",
+                    ))
+                })?;
+                let left =
+                    $LEFT.as_any().downcast_ref::<Int32Array>().ok_or_else(|| {
+                        ArrowError::CastError(String::from("Left array cannot 
be cast"))
+                    })?;
+                $OP::<Int32Type>(left, right)
+            }
+            (DataType::Int64, DataType::Int64) => {
+                let right: i64 = right.try_into().map_err(|_| {
+                    ArrowError::ComputeError(String::from(
+                        "Can not convert scalar to i128",
+                    ))
+                })?;
+                let left =
+                    $LEFT.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
+                        ArrowError::CastError(String::from("Left array cannot 
be cast"))
+                    })?;
+                $OP::<Int64Type>(left, right)
+            }
+            // (DataType::UInt8, DataType::UInt8) => {
+            //     dyn_cmp_scalar!($LEFT, right, UInt8Array, $OP, UInt8Type)
+            // }
+            // (DataType::UInt16, DataType::UInt16) => {
+            //     dyn_cmp_scalar!($LEFT, right, UInt16Array, $OP, UInt16Type)
+            // }
+            // (DataType::UInt32, DataType::UInt32) => {
+            //     dyn_cmp_scalar!($LEFT, right, UInt32Array, $OP, UInt32Type)
+            // }
+            // (DataType::UInt64, DataType::UInt64) => {
+            //     dyn_cmp_scalar!($LEFT, right, UInt64Array, $OP, UInt64Type)
+            // }
+            // (DataType::Float32, DataType::Float32) => {
+            //     dyn_cmp_scalar!($LEFT, right, Float32Array, $OP, 
Float32Type)
+            // }
+            // (DataType::Float64, DataType::Float64) => {
+            //     dyn_cmp_scalar!($LEFT, right, Float64Array, $OP, 
Float64Type)
+            // }
+            // (DataType::Utf8, DataType::Utf8) => {
+            //     let right: i32 = right.try_into().map_err(|_| {
+            //         ArrowError::ComputeError(String::from(
+            //             "Can not convert scalar to i128",
+            //         ))
+            //     })?;
+            //     let left =
+            //         $LEFT
+            //             .as_any()
+            //             .downcast_ref::<StringArray>()
+            //             .ok_or_else(|| {
+            //                 ArrowError::CastError(String::from(
+            //                     "Left array cannot be cast",
+            //                 ))
+            //             })?;
+            //     $OP::<Int32Type>(left, right)
+            // }
+            // (DataType::LargeUtf8, DataType::LargeUtf8) => {
+            //     dyn_cmp_scalar!($LEFT, $RIGHT, LargeStringArray, $OP, i64)
+            // }
+            (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!(
+                "Comparing arrays of type {} is not yet implemented",
+                t1
+            ))),
+            (t1, t2) => Err(ArrowError::CastError(format!(
+                "Cannot compare an array with a scalar of different type ({} 
and {})",
+                t1, t2
+            ))),
+        }
+    }};
+}
+
+pub trait IntoArrowNumericType {
+    // type Arrow: ArrowNumericType<Native = Self>;
+    fn get_arrow_type(&self) -> &DataType;
+}
+
+impl IntoArrowNumericType for i8 {
+    // type Arrow = Int8Type;
+    fn get_arrow_type(&self) -> &DataType {
+        &DataType::Int8
+    }
+}
+
+impl IntoArrowNumericType for i16 {
+    // type Arrow = Int8Type;
+    fn get_arrow_type(&self) -> &DataType {
+        &DataType::Int16
+    }
+}
+
+impl IntoArrowNumericType for i32 {
+    // type Arrow = Int8Type;
+    fn get_arrow_type(&self) -> &DataType {
+        &DataType::Int32
+    }
+}
+
+/// Perform `left == right` operation on an array and a numeric scalar
+/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive 
values
+pub fn eq_dyn_scalar<T>(left: &dyn Array, right: T) -> Result<BooleanArray>
+where
+    T: IntoArrowNumericType + TryInto<i128> + Copy + std::fmt::Debug,
+{
+    let result = match left.data_type() {
+        DataType::Dictionary(key_type, _) => match key_type.as_ref() {
+            DataType::UInt8 => {
+                let left = left
+                    .as_any()
+                    .downcast_ref::<DictionaryArray<UInt8Type>>()
+                    .unwrap();

Review comment:
       You might be able to use 
https://docs.rs/arrow/6.4.0/arrow/array/fn.as_dictionary_array.html here too:
   
   ```suggestion
                   let left = 
as_dictionary_array<DictionaryArray<UInt8Type>>()?;
   ```

##########
File path: arrow/src/compute/kernels/comparison.rs
##########
@@ -898,6 +899,224 @@ pub fn gt_eq_utf8_scalar<OffsetSize: 
StringOffsetSizeTrait>(
     compare_op_scalar!(left, right, |a, b| a >= b)
 }
 
+macro_rules! dyn_compare_scalar {
+    ($LEFT: expr, $RIGHT: expr, $OP: ident) => {{
+        let right: i128 = $RIGHT.try_into().map_err(|_| {
+            ArrowError::ComputeError(String::from("Can not convert scalar to 
i128"))
+        })?;
+        match ($LEFT.data_type(), $RIGHT.get_arrow_type()) {
+            // (DataType::Boolean, DataType::Boolean) => {
+            //     typed_cmp_scalar!($LEFT, $RIGHT, BooleanArray, $OP_BOOL)
+            // }
+            (DataType::Int8, DataType::Int8) => {
+                let right: i8 = right.try_into().map_err(|_| {
+                    ArrowError::ComputeError(String::from(
+                        "Can not convert scalar to i128",
+                    ))
+                })?;
+                let left =
+                    $LEFT.as_any().downcast_ref::<Int8Array>().ok_or_else(|| {
+                        ArrowError::CastError(String::from("Left array cannot 
be cast"))
+                    })?;
+                $OP::<Int8Type>(left, right)
+            }
+            (DataType::Int16, DataType::Int16) => {
+                let right: i16 = right.try_into().map_err(|_| {
+                    ArrowError::ComputeError(String::from(
+                        "Can not convert scalar to i128",
+                    ))
+                })?;
+                let left =
+                    $LEFT.as_any().downcast_ref::<Int16Array>().ok_or_else(|| {
+                        ArrowError::CastError(String::from("Left array cannot 
be cast"))
+                    })?;
+                $OP::<Int16Type>(left, right)
+            }
+            (DataType::Int32, DataType::Int32) => {
+                let right: i32 = right.try_into().map_err(|_| {
+                    ArrowError::ComputeError(String::from(
+                        "Can not convert scalar to i128",
+                    ))
+                })?;
+                let left =
+                    $LEFT.as_any().downcast_ref::<Int32Array>().ok_or_else(|| {
+                        ArrowError::CastError(String::from("Left array cannot 
be cast"))
+                    })?;
+                $OP::<Int32Type>(left, right)
+            }
+            (DataType::Int64, DataType::Int64) => {
+                let right: i64 = right.try_into().map_err(|_| {
+                    ArrowError::ComputeError(String::from(
+                        "Can not convert scalar to i128",
+                    ))
+                })?;
+                let left =
+                    $LEFT.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
+                        ArrowError::CastError(String::from("Left array cannot 
be cast"))
+                    })?;
+                $OP::<Int64Type>(left, right)
+            }
+            // (DataType::UInt8, DataType::UInt8) => {
+            //     dyn_cmp_scalar!($LEFT, right, UInt8Array, $OP, UInt8Type)
+            // }
+            // (DataType::UInt16, DataType::UInt16) => {
+            //     dyn_cmp_scalar!($LEFT, right, UInt16Array, $OP, UInt16Type)
+            // }
+            // (DataType::UInt32, DataType::UInt32) => {
+            //     dyn_cmp_scalar!($LEFT, right, UInt32Array, $OP, UInt32Type)
+            // }
+            // (DataType::UInt64, DataType::UInt64) => {
+            //     dyn_cmp_scalar!($LEFT, right, UInt64Array, $OP, UInt64Type)
+            // }
+            // (DataType::Float32, DataType::Float32) => {
+            //     dyn_cmp_scalar!($LEFT, right, Float32Array, $OP, 
Float32Type)
+            // }
+            // (DataType::Float64, DataType::Float64) => {
+            //     dyn_cmp_scalar!($LEFT, right, Float64Array, $OP, 
Float64Type)
+            // }
+            // (DataType::Utf8, DataType::Utf8) => {
+            //     let right: i32 = right.try_into().map_err(|_| {
+            //         ArrowError::ComputeError(String::from(
+            //             "Can not convert scalar to i128",
+            //         ))
+            //     })?;
+            //     let left =
+            //         $LEFT
+            //             .as_any()
+            //             .downcast_ref::<StringArray>()
+            //             .ok_or_else(|| {
+            //                 ArrowError::CastError(String::from(
+            //                     "Left array cannot be cast",
+            //                 ))
+            //             })?;
+            //     $OP::<Int32Type>(left, right)
+            // }
+            // (DataType::LargeUtf8, DataType::LargeUtf8) => {
+            //     dyn_cmp_scalar!($LEFT, $RIGHT, LargeStringArray, $OP, i64)
+            // }
+            (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!(
+                "Comparing arrays of type {} is not yet implemented",
+                t1
+            ))),
+            (t1, t2) => Err(ArrowError::CastError(format!(
+                "Cannot compare an array with a scalar of different type ({} 
and {})",
+                t1, t2
+            ))),
+        }
+    }};
+}
+
+pub trait IntoArrowNumericType {
+    // type Arrow: ArrowNumericType<Native = Self>;
+    fn get_arrow_type(&self) -> &DataType;
+}
+
+impl IntoArrowNumericType for i8 {
+    // type Arrow = Int8Type;
+    fn get_arrow_type(&self) -> &DataType {
+        &DataType::Int8
+    }
+}
+
+impl IntoArrowNumericType for i16 {
+    // type Arrow = Int8Type;
+    fn get_arrow_type(&self) -> &DataType {
+        &DataType::Int16
+    }
+}
+
+impl IntoArrowNumericType for i32 {
+    // type Arrow = Int8Type;
+    fn get_arrow_type(&self) -> &DataType {
+        &DataType::Int32
+    }
+}
+
+/// Perform `left == right` operation on an array and a numeric scalar
+/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive 
values
+pub fn eq_dyn_scalar<T>(left: &dyn Array, right: T) -> Result<BooleanArray>
+where
+    T: IntoArrowNumericType + TryInto<i128> + Copy + std::fmt::Debug,

Review comment:
       If `T` is going to have `TryInto<i128>`, I wonder if we still need 
`IntoArrowNumericType` at all? I think it may not be necessary any more. 
   
   My thinking is that since `dyn_compare_scalar` converts `right` into `i128` 
immediately, there are then conversion rules from `i128` back to all of the 
primitive types needed to call `eq_scalar`.

##########
File path: arrow/src/compute/kernels/comparison.rs
##########
@@ -898,6 +899,224 @@ pub fn gt_eq_utf8_scalar<OffsetSize: 
StringOffsetSizeTrait>(
     compare_op_scalar!(left, right, |a, b| a >= b)
 }
 
+macro_rules! dyn_compare_scalar {
+    ($LEFT: expr, $RIGHT: expr, $OP: ident) => {{
+        let right: i128 = $RIGHT.try_into().map_err(|_| {
+            ArrowError::ComputeError(String::from("Can not convert scalar to 
i128"))
+        })?;
+        match ($LEFT.data_type(), $RIGHT.get_arrow_type()) {
+            // (DataType::Boolean, DataType::Boolean) => {
+            //     typed_cmp_scalar!($LEFT, $RIGHT, BooleanArray, $OP_BOOL)
+            // }
+            (DataType::Int8, DataType::Int8) => {
+                let right: i8 = right.try_into().map_err(|_| {
+                    ArrowError::ComputeError(String::from(
+                        "Can not convert scalar to i128",
+                    ))
+                })?;
+                let left =
+                    $LEFT.as_any().downcast_ref::<Int8Array>().ok_or_else(|| {
+                        ArrowError::CastError(String::from("Left array cannot 
be cast"))
+                    })?;

Review comment:
       I think you can use 
https://docs.rs/arrow/6.4.0/arrow/array/fn.as_primitive_array.html to simplify 
this.
   
   So something like:
   
   ```suggestion
                       as_primitive_array::<Int8Array>($LEFT)?;
   ```

##########
File path: arrow/src/compute/kernels/comparison.rs
##########
@@ -898,6 +898,126 @@ pub fn gt_eq_utf8_scalar<OffsetSize: 
StringOffsetSizeTrait>(
     compare_op_scalar!(left, right, |a, b| a >= b)
 }
 
+macro_rules! dyn_cmp_scalar {
+    ($LEFT: expr, $RIGHT: expr, $T: ident, $OP: ident, $TT: tt) => {{
+        let left = $LEFT.as_any().downcast_ref::<$T>().ok_or_else(|| {
+            ArrowError::CastError(format!(
+                "Left array cannot be cast to {}",
+                type_name::<$T>()
+            ))
+        })?;
+        let right = $RIGHT.as_any().downcast_ref::<$T>().ok_or_else(|| {
+            ArrowError::CastError(format!(
+                "Right array cannot be cast to {}",
+                type_name::<$T>(),
+            ))
+        })?;
+        $OP::<$TT>(left, right)
+    }};
+}
+
+macro_rules! dyn_compare_scalar {
+    ($LEFT: expr, $RIGHT: expr, $OP: ident) => {{
+        let right = $RIGHT.try_into().map_err(|_| {
+            ArrowError::ComputeError(format!(
+                "Can not convert scalar {:?} to i128",
+                $RIGHT
+            ))
+        });
+        match ($LEFT.data_type(), $RIGHT::Arrow) {
+            // (DataType::Boolean, DataType::Boolean) => {
+            //     typed_cmp_scalar!($LEFT, $RIGHT, BooleanArray, $OP_BOOL)
+            // }
+            (DataType::Int8, DataType::Int8) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Int8Array, $OP, Int8Type)
+            }
+            (DataType::Int16, DataType::Int16) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Int16Array, $OP, Int16Type)
+            }
+            (DataType::Int32, DataType::Int32) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Int32Array, $OP, Int32Type)
+            }
+            (DataType::Int64, DataType::Int64) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Int64Array, $OP, Int64Type)
+            }
+            (DataType::UInt8, DataType::UInt8) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, UInt8Array, $OP, UInt8Type)
+            }
+            (DataType::UInt16, DataType::UInt16) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, UInt16Array, $OP, UInt16Type)
+            }
+            (DataType::UInt32, DataType::UInt32) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, UInt32Array, $OP, UInt32Type)
+            }
+            (DataType::UInt64, DataType::UInt64) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, UInt64Array, $OP, UInt64Type)
+            }
+            (DataType::Float32, DataType::Float32) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Float32Array, $OP, Float32Type)
+            }
+            (DataType::Float64, DataType::Float64) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Float64Array, $OP, Float64Type)
+            }
+            (DataType::Utf8, DataType::Utf8) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, StringArray, $OP, i32)
+            }
+            (DataType::LargeUtf8, DataType::LargeUtf8) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, LargeStringArray, $OP, i64)
+            }
+            (DataType::Dictionary(DataType::UInt8, DataType::UInt8), 
DataType::UInt8) => {
+                let values_comp =
+                    typed_compare_scalar!($LEFT.values(), $RIGHT, eq_scalar);
+                unpack_dict_comparison($LEFT, values_comp)
+            }
+            (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!(
+                "Comparing arrays of type {} is not yet implemented",
+                t1
+            ))),
+            (t1, t2) => Err(ArrowError::CastError(format!(
+                "Cannot compare an array with a scalar of different type ({} 
and {})",
+                t1, t2
+            ))),
+        }
+    }};
+}
+
+/// Perform `left == right` operation on an array and a numeric scalar
+/// value. Supports PrimtiveArrays, and DictionaryArrays that have primitive 
values
+pub fn eq_dyn_scalar<T>(left: &dyn Array, right: T) -> Result<BooleanArray>
+where
+    T: IntoArrowNumericType + TryInto<i128> + Copy + std::fmt::Debug,
+{
+    dyn_compare_scalar!(left, right, eq_scalar)
+}
+
+/// unpacks the results of comparing left.values (as a boolean)
+///
+/// TODO add example
+///
+fn unpack_dict_comparison<K>(
+    left: &DictionaryArray<K>,
+    dict_comparison: BooleanArray,
+) -> Result<BooleanArray>
+where
+    K: ArrowNumericType,
+{
+    assert_eq!(dict_comparison.len(), left.values().len());

Review comment:
       I think this is an invariant (namely that `left` is the dictionary and 
`dict_comparison` is the result of comparing that dictionary's values). 
   
   Perhaps we could rename the `left` parameter to `dict` to make this clearer.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to