alamb commented on a change in pull request #1074:
URL: https://github.com/apache/arrow-rs/pull/1074#discussion_r775243362



##########
File path: arrow/src/compute/kernels/comparison.rs
##########
@@ -898,6 +900,261 @@ pub fn gt_eq_utf8_scalar<OffsetSize: 
StringOffsetSizeTrait>(
     compare_op_scalar!(left, right, |a, b| a >= b)
 }
 
+macro_rules! dyn_compare_scalar {
+    ($LEFT: expr, $RIGHT: expr, $OP: ident) => {{
+        let right: i128 = $RIGHT.try_into().map_err(|_| {
+            ArrowError::ComputeError(String::from("Can not convert scalar to 
i128"))
+        })?;
+        match $LEFT.data_type() {
+            DataType::Int8 => {
+                let right: i8 = right.try_into().map_err(|_| {
+                    ArrowError::ComputeError(String::from(
+                        "Can not convert scalar to i128",
+                    ))
+                })?;
+                let left = as_primitive_array::<Int8Type>($LEFT);
+                $OP::<Int8Type>(left, right)
+            }
+            DataType::Int16 => {
+                let right: i16 = right.try_into().map_err(|_| {
+                    ArrowError::ComputeError(String::from(
+                        "Can not convert scalar to i128",
+                    ))
+                })?;
+                let left = as_primitive_array::<Int16Type>($LEFT);
+                $OP::<Int16Type>(left, right)
+            }
+            DataType::Int32 => {
+                let right: i32 = right.try_into().map_err(|_| {
+                    ArrowError::ComputeError(String::from(
+                        "Can not convert scalar to i128",
+                    ))
+                })?;
+                let left = as_primitive_array::<Int32Type>($LEFT);
+                $OP::<Int32Type>(left, right)
+            }
+            DataType::Int64 => {
+                let right: i64 = right.try_into().map_err(|_| {
+                    ArrowError::ComputeError(String::from(
+                        "Can not convert scalar to i128",
+                    ))
+                })?;
+                let left = as_primitive_array::<Int64Type>($LEFT);
+                $OP::<Int64Type>(left, right)
+            }
+            DataType::UInt8 => {
+                let right: u8 = right.try_into().map_err(|_| {
+                    ArrowError::ComputeError(String::from(
+                        "Can not convert scalar to i128",
+                    ))
+                })?;
+                let left = as_primitive_array::<UInt8Type>($LEFT);
+                $OP::<UInt8Type>(left, right)
+            }
+            DataType::UInt16 => {
+                let right: u16 = right.try_into().map_err(|_| {
+                    ArrowError::ComputeError(String::from(
+                        "Can not convert scalar to i128",
+                    ))
+                })?;
+                let left = as_primitive_array::<UInt16Type>($LEFT);
+                $OP::<UInt16Type>(left, right)
+            }
+            DataType::UInt32 => {
+                let right: u32 = right.try_into().map_err(|_| {
+                    ArrowError::ComputeError(String::from(
+                        "Can not convert scalar to i128",
+                    ))
+                })?;
+                let left = as_primitive_array::<UInt32Type>($LEFT);
+                $OP::<UInt32Type>(left, right)
+            }
+            DataType::UInt64 => {
+                let right: u64 = right.try_into().map_err(|_| {
+                    ArrowError::ComputeError(String::from(
+                        "Can not convert scalar to i128",
+                    ))
+                })?;
+                let left = as_primitive_array::<UInt64Type>($LEFT);
+                $OP::<UInt64Type>(left, right)
+            }
+            _ => Err(ArrowError::ComputeError(String::from(
+                "Unsupported data type",
+            ))),
+        }
+    }};
+    ($LEFT: expr, $RIGHT: expr, $KT: ident, $OP: ident) => {{
+        let right: i128 = $RIGHT.try_into().map_err(|_| {
+            ArrowError::ComputeError(String::from("Can not convert scalar to 
i128"))
+        })?;
+        match $KT.as_ref() {
+            DataType::UInt8 => {
+                let left = as_dictionary_array::<UInt8Type>($LEFT);
+                unpack_dict_comparison(
+                    left,
+                    dyn_compare_scalar!(left.values(), right, $OP)?,
+                )
+            }
+            DataType::UInt16 => {
+                let left = as_dictionary_array::<UInt16Type>($LEFT);
+                unpack_dict_comparison(
+                    left,
+                    dyn_compare_scalar!(left.values(), right, $OP)?,
+                )
+            }
+            DataType::UInt32 => {
+                let left = as_dictionary_array::<UInt32Type>($LEFT);
+                unpack_dict_comparison(
+                    left,
+                    dyn_compare_scalar!(left.values(), right, $OP)?,
+                )
+            }
+            DataType::UInt64 => {
+                let left = as_dictionary_array::<UInt64Type>($LEFT);
+                unpack_dict_comparison(
+                    left,
+                    dyn_compare_scalar!(left.values(), right, $OP)?,
+                )
+            }
+            DataType::Int8 => {
+                let left = as_dictionary_array::<Int8Type>($LEFT);
+                unpack_dict_comparison(
+                    left,
+                    dyn_compare_scalar!(left.values(), right, $OP)?,
+                )
+            }
+            DataType::Int16 => {
+                let left = as_dictionary_array::<Int16Type>($LEFT);
+                unpack_dict_comparison(
+                    left,
+                    dyn_compare_scalar!(left.values(), right, $OP)?,
+                )
+            }
+            DataType::Int32 => {
+                let left = as_dictionary_array::<Int32Type>($LEFT);
+                unpack_dict_comparison(
+                    left,
+                    dyn_compare_scalar!(left.values(), right, $OP)?,
+                )
+            }
+            DataType::Int64 => {
+                let left = as_dictionary_array::<Int64Type>($LEFT);
+                unpack_dict_comparison(
+                    left,
+                    dyn_compare_scalar!(left.values(), right, $OP)?,
+                )
+            }
+            _ => Err(ArrowError::ComputeError(String::from("Unknown key 
type"))),
+        }
+    }};
+}
+
+macro_rules! dyn_compare_utf8_scalar {
+    ($LEFT: expr, $RIGHT: expr, $KT: ident, $OP: ident) => {{
+        match $KT.as_ref() {
+            DataType::UInt8 => {
+                let left = as_dictionary_array::<UInt8Type>($LEFT);
+                let values = as_string_array(left.values());
+                unpack_dict_comparison(left, $OP(values, $RIGHT)?)
+            }
+            DataType::UInt16 => {
+                let left = as_dictionary_array::<UInt16Type>($LEFT);
+                let values = as_string_array(left.values());
+                unpack_dict_comparison(left, $OP(values, $RIGHT)?)
+            }
+            DataType::UInt32 => {
+                let left = as_dictionary_array::<UInt32Type>($LEFT);
+                let values = as_string_array(left.values());
+                unpack_dict_comparison(left, $OP(values, $RIGHT)?)
+            }
+            DataType::UInt64 => {
+                let left = as_dictionary_array::<UInt64Type>($LEFT);
+                let values = as_string_array(left.values());
+                unpack_dict_comparison(left, $OP(values, $RIGHT)?)
+            }
+            DataType::Int8 => {
+                let left = as_dictionary_array::<Int8Type>($LEFT);
+                let values = as_string_array(left.values());
+                unpack_dict_comparison(left, $OP(values, $RIGHT)?)
+            }
+            DataType::Int16 => {
+                let left = as_dictionary_array::<Int16Type>($LEFT);
+                let values = as_string_array(left.values());
+                unpack_dict_comparison(left, $OP(values, $RIGHT)?)
+            }
+            DataType::Int32 => {
+                let left = as_dictionary_array::<Int32Type>($LEFT);
+                let values = as_string_array(left.values());
+                unpack_dict_comparison(left, $OP(values, $RIGHT)?)
+            }
+            DataType::Int64 => {
+                let left = as_dictionary_array::<Int64Type>($LEFT);
+                let values = as_string_array(left.values());
+                unpack_dict_comparison(left, $OP(values, $RIGHT)?)
+            }
+            _ => Err(ArrowError::ComputeError(String::from("Unknown key 
type"))),
+        }
+    }};
+}
+
+/// Perform `left == right` operation on an array and a numeric scalar
+/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive 
values
+pub fn eq_dyn_scalar<T>(left: Arc<dyn Array>, right: T) -> Result<BooleanArray>
+where
+    T: TryInto<i128> + Copy + std::fmt::Debug,
+{
+    match left.data_type() {
+        DataType::Dictionary(key_type, _) => {
+            return dyn_compare_scalar!(&left, right, key_type, eq_scalar);
+        }
+        _ => {
+            return dyn_compare_scalar!(&left, right, eq_scalar);
+        }
+    };
+}
+
+/// Perform `left == right` operation on an array and a numeric scalar
+/// value. Supports StringArrays, and DictionaryArrays that have string values
+pub fn eq_dyn_utf8_scalar(left: Arc<dyn Array>, right: &str) -> 
Result<BooleanArray> {
+    match left.data_type() {
+        DataType::Dictionary(key_type, _) => {
+            return dyn_compare_utf8_scalar!(&left, right, key_type, 
eq_utf8_scalar);
+        }
+        _ => {

Review comment:
       This should probably match explicitly on `DataType::Utf8` and 
`DataType::LargeUtf8` rather than falling through a catch-all `_` arm, but 
otherwise it looks good to me.

##########
File path: arrow/src/compute/kernels/comparison.rs
##########
@@ -2522,4 +2779,60 @@ mod tests {
         regexp_is_match_utf8_scalar,
         vec![true, true, false, false]
     );
+    #[test]
+    fn test_eq_dyn_scalar() {
+        let array = Int32Array::from(vec![6, 7, 8, 8, 10]);
+        let array = Arc::new(array);
+        let a_eq = eq_dyn_scalar(array, 8).unwrap();

Review comment:
       
![200](https://user-images.githubusercontent.com/490673/147409848-bef2ebce-8307-477c-af72-daa79b3519da.gif)
   

##########
File path: arrow/src/compute/kernels/comparison.rs
##########
@@ -2522,4 +2779,60 @@ mod tests {
         regexp_is_match_utf8_scalar,
         vec![true, true, false, false]
     );
+    #[test]
+    fn test_eq_dyn_scalar() {
+        let array = Int32Array::from(vec![6, 7, 8, 8, 10]);
+        let array = Arc::new(array);
+        let a_eq = eq_dyn_scalar(array, 8).unwrap();
+        assert_eq!(
+            a_eq,
+            BooleanArray::from(
+                vec![Some(false), Some(false), Some(true), Some(true), 
Some(false)]
+            )
+        );
+    }
+    #[test]
+    fn test_eq_dyn_scalar_with_dict() {
+        let key_builder = PrimitiveBuilder::<Int8Type>::new(3);
+        let value_builder = PrimitiveBuilder::<Int32Type>::new(2);
+        let mut builder = PrimitiveDictionaryBuilder::new(key_builder, 
value_builder);
+        builder.append(123).unwrap();
+        builder.append_null().unwrap();
+        builder.append(23).unwrap();
+        let array = Arc::new(builder.finish());
+        let a_eq = eq_dyn_scalar(array, 123).unwrap();
+        assert_eq!(
+            a_eq,
+            BooleanArray::from(vec![Some(true), None, Some(false)])
+        );
+    }
+    #[test]
+    fn test_eq_dyn_utf8_scalar() {
+        let array = StringArray::from(vec!["abc", "def", "xyz"]);
+        let array = Arc::new(array);
+        let a_eq = eq_dyn_utf8_scalar(array, "xyz").unwrap();

Review comment:
       🎉 




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to