alamb commented on a change in pull request #1074:
URL: https://github.com/apache/arrow-rs/pull/1074#discussion_r774535495
##########
File path: arrow/src/compute/kernels/comparison.rs
##########
@@ -898,6 +899,224 @@ pub fn gt_eq_utf8_scalar<OffsetSize:
StringOffsetSizeTrait>(
compare_op_scalar!(left, right, |a, b| a >= b)
}
+macro_rules! dyn_compare_scalar {
+ ($LEFT: expr, $RIGHT: expr, $OP: ident) => {{
+ let right: i128 = $RIGHT.try_into().map_err(|_| {
+ ArrowError::ComputeError(String::from("Can not convert scalar to
i128"))
+ })?;
+ match ($LEFT.data_type(), $RIGHT.get_arrow_type()) {
+ // (DataType::Boolean, DataType::Boolean) => {
+ // typed_cmp_scalar!($LEFT, $RIGHT, BooleanArray, $OP_BOOL)
+ // }
+ (DataType::Int8, DataType::Int8) => {
+ let right: i8 = right.try_into().map_err(|_| {
+ ArrowError::ComputeError(String::from(
+ "Can not convert scalar to i128",
+ ))
+ })?;
+ let left =
+ $LEFT.as_any().downcast_ref::<Int8Array>().ok_or_else(|| {
+ ArrowError::CastError(String::from("Left array cannot
be cast"))
+ })?;
+ $OP::<Int8Type>(left, right)
+ }
+ (DataType::Int16, DataType::Int16) => {
+ let right: i16 = right.try_into().map_err(|_| {
+ ArrowError::ComputeError(String::from(
+ "Can not convert scalar to i128",
+ ))
+ })?;
+ let left =
+ $LEFT.as_any().downcast_ref::<Int16Array>().ok_or_else(|| {
+ ArrowError::CastError(String::from("Left array cannot
be cast"))
+ })?;
+ $OP::<Int16Type>(left, right)
+ }
+ (DataType::Int32, DataType::Int32) => {
+ let right: i32 = right.try_into().map_err(|_| {
+ ArrowError::ComputeError(String::from(
+ "Can not convert scalar to i128",
+ ))
+ })?;
+ let left =
+ $LEFT.as_any().downcast_ref::<Int32Array>().ok_or_else(|| {
+ ArrowError::CastError(String::from("Left array cannot
be cast"))
+ })?;
+ $OP::<Int32Type>(left, right)
+ }
+ (DataType::Int64, DataType::Int64) => {
+ let right: i64 = right.try_into().map_err(|_| {
+ ArrowError::ComputeError(String::from(
+ "Can not convert scalar to i128",
+ ))
+ })?;
+ let left =
+ $LEFT.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
+ ArrowError::CastError(String::from("Left array cannot
be cast"))
+ })?;
+ $OP::<Int64Type>(left, right)
+ }
+ // (DataType::UInt8, DataType::UInt8) => {
+ // dyn_cmp_scalar!($LEFT, right, UInt8Array, $OP, UInt8Type)
+ // }
+ // (DataType::UInt16, DataType::UInt16) => {
+ // dyn_cmp_scalar!($LEFT, right, UInt16Array, $OP, UInt16Type)
+ // }
+ // (DataType::UInt32, DataType::UInt32) => {
+ // dyn_cmp_scalar!($LEFT, right, UInt32Array, $OP, UInt32Type)
+ // }
+ // (DataType::UInt64, DataType::UInt64) => {
+ // dyn_cmp_scalar!($LEFT, right, UInt64Array, $OP, UInt64Type)
+ // }
+ // (DataType::Float32, DataType::Float32) => {
+ // dyn_cmp_scalar!($LEFT, right, Float32Array, $OP,
Float32Type)
+ // }
+ // (DataType::Float64, DataType::Float64) => {
+ // dyn_cmp_scalar!($LEFT, right, Float64Array, $OP,
Float64Type)
+ // }
+ // (DataType::Utf8, DataType::Utf8) => {
+ // let right: i32 = right.try_into().map_err(|_| {
+ // ArrowError::ComputeError(String::from(
+ // "Can not convert scalar to i128",
+ // ))
+ // })?;
+ // let left =
+ // $LEFT
+ // .as_any()
+ // .downcast_ref::<StringArray>()
+ // .ok_or_else(|| {
+ // ArrowError::CastError(String::from(
+ // "Left array cannot be cast",
+ // ))
+ // })?;
+ // $OP::<Int32Type>(left, right)
+ // }
+ // (DataType::LargeUtf8, DataType::LargeUtf8) => {
+ // dyn_cmp_scalar!($LEFT, $RIGHT, LargeStringArray, $OP, i64)
+ // }
+ (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!(
+ "Comparing arrays of type {} is not yet implemented",
+ t1
+ ))),
+ (t1, t2) => Err(ArrowError::CastError(format!(
+ "Cannot compare an array with a scalar of different type ({}
and {})",
+ t1, t2
+ ))),
+ }
+ }};
+}
+
+pub trait IntoArrowNumericType {
+ // type Arrow: ArrowNumericType<Native = Self>;
+ fn get_arrow_type(&self) -> &DataType;
+}
+
+impl IntoArrowNumericType for i8 {
+ // type Arrow = Int8Type;
+ fn get_arrow_type(&self) -> &DataType {
+ &DataType::Int8
+ }
+}
+
+impl IntoArrowNumericType for i16 {
+ // type Arrow = Int8Type;
+ fn get_arrow_type(&self) -> &DataType {
+ &DataType::Int16
+ }
+}
+
+impl IntoArrowNumericType for i32 {
+ // type Arrow = Int8Type;
+ fn get_arrow_type(&self) -> &DataType {
+ &DataType::Int32
+ }
+}
+
+/// Perform `left == right` operation on an array and a numeric scalar
+/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive
values
+pub fn eq_dyn_scalar<T>(left: &dyn Array, right: T) -> Result<BooleanArray>
+where
+ T: IntoArrowNumericType + TryInto<i128> + Copy + std::fmt::Debug,
+{
+ let result = match left.data_type() {
+ DataType::Dictionary(key_type, _) => match key_type.as_ref() {
+ DataType::UInt8 => {
+ let left = left
+ .as_any()
+ .downcast_ref::<DictionaryArray<UInt8Type>>()
+ .unwrap();
Review comment:
You might be able to use
https://docs.rs/arrow/6.4.0/arrow/array/fn.as_dictionary_array.html here too:
```suggestion
let left = as_dictionary_array::<UInt8Type>(left);
```
##########
File path: arrow/src/compute/kernels/comparison.rs
##########
@@ -898,6 +899,224 @@ pub fn gt_eq_utf8_scalar<OffsetSize:
StringOffsetSizeTrait>(
compare_op_scalar!(left, right, |a, b| a >= b)
}
+macro_rules! dyn_compare_scalar {
+ ($LEFT: expr, $RIGHT: expr, $OP: ident) => {{
+ let right: i128 = $RIGHT.try_into().map_err(|_| {
+ ArrowError::ComputeError(String::from("Can not convert scalar to
i128"))
+ })?;
+ match ($LEFT.data_type(), $RIGHT.get_arrow_type()) {
+ // (DataType::Boolean, DataType::Boolean) => {
+ // typed_cmp_scalar!($LEFT, $RIGHT, BooleanArray, $OP_BOOL)
+ // }
+ (DataType::Int8, DataType::Int8) => {
+ let right: i8 = right.try_into().map_err(|_| {
+ ArrowError::ComputeError(String::from(
+ "Can not convert scalar to i128",
+ ))
+ })?;
+ let left =
+ $LEFT.as_any().downcast_ref::<Int8Array>().ok_or_else(|| {
+ ArrowError::CastError(String::from("Left array cannot
be cast"))
+ })?;
+ $OP::<Int8Type>(left, right)
+ }
+ (DataType::Int16, DataType::Int16) => {
+ let right: i16 = right.try_into().map_err(|_| {
+ ArrowError::ComputeError(String::from(
+ "Can not convert scalar to i128",
+ ))
+ })?;
+ let left =
+ $LEFT.as_any().downcast_ref::<Int16Array>().ok_or_else(|| {
+ ArrowError::CastError(String::from("Left array cannot
be cast"))
+ })?;
+ $OP::<Int16Type>(left, right)
+ }
+ (DataType::Int32, DataType::Int32) => {
+ let right: i32 = right.try_into().map_err(|_| {
+ ArrowError::ComputeError(String::from(
+ "Can not convert scalar to i128",
+ ))
+ })?;
+ let left =
+ $LEFT.as_any().downcast_ref::<Int32Array>().ok_or_else(|| {
+ ArrowError::CastError(String::from("Left array cannot
be cast"))
+ })?;
+ $OP::<Int32Type>(left, right)
+ }
+ (DataType::Int64, DataType::Int64) => {
+ let right: i64 = right.try_into().map_err(|_| {
+ ArrowError::ComputeError(String::from(
+ "Can not convert scalar to i128",
+ ))
+ })?;
+ let left =
+ $LEFT.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
+ ArrowError::CastError(String::from("Left array cannot
be cast"))
+ })?;
+ $OP::<Int64Type>(left, right)
+ }
+ // (DataType::UInt8, DataType::UInt8) => {
+ // dyn_cmp_scalar!($LEFT, right, UInt8Array, $OP, UInt8Type)
+ // }
+ // (DataType::UInt16, DataType::UInt16) => {
+ // dyn_cmp_scalar!($LEFT, right, UInt16Array, $OP, UInt16Type)
+ // }
+ // (DataType::UInt32, DataType::UInt32) => {
+ // dyn_cmp_scalar!($LEFT, right, UInt32Array, $OP, UInt32Type)
+ // }
+ // (DataType::UInt64, DataType::UInt64) => {
+ // dyn_cmp_scalar!($LEFT, right, UInt64Array, $OP, UInt64Type)
+ // }
+ // (DataType::Float32, DataType::Float32) => {
+ // dyn_cmp_scalar!($LEFT, right, Float32Array, $OP,
Float32Type)
+ // }
+ // (DataType::Float64, DataType::Float64) => {
+ // dyn_cmp_scalar!($LEFT, right, Float64Array, $OP,
Float64Type)
+ // }
+ // (DataType::Utf8, DataType::Utf8) => {
+ // let right: i32 = right.try_into().map_err(|_| {
+ // ArrowError::ComputeError(String::from(
+ // "Can not convert scalar to i128",
+ // ))
+ // })?;
+ // let left =
+ // $LEFT
+ // .as_any()
+ // .downcast_ref::<StringArray>()
+ // .ok_or_else(|| {
+ // ArrowError::CastError(String::from(
+ // "Left array cannot be cast",
+ // ))
+ // })?;
+ // $OP::<Int32Type>(left, right)
+ // }
+ // (DataType::LargeUtf8, DataType::LargeUtf8) => {
+ // dyn_cmp_scalar!($LEFT, $RIGHT, LargeStringArray, $OP, i64)
+ // }
+ (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!(
+ "Comparing arrays of type {} is not yet implemented",
+ t1
+ ))),
+ (t1, t2) => Err(ArrowError::CastError(format!(
+ "Cannot compare an array with a scalar of different type ({}
and {})",
+ t1, t2
+ ))),
+ }
+ }};
+}
+
+pub trait IntoArrowNumericType {
+ // type Arrow: ArrowNumericType<Native = Self>;
+ fn get_arrow_type(&self) -> &DataType;
+}
+
+impl IntoArrowNumericType for i8 {
+ // type Arrow = Int8Type;
+ fn get_arrow_type(&self) -> &DataType {
+ &DataType::Int8
+ }
+}
+
+impl IntoArrowNumericType for i16 {
+ // type Arrow = Int8Type;
+ fn get_arrow_type(&self) -> &DataType {
+ &DataType::Int16
+ }
+}
+
+impl IntoArrowNumericType for i32 {
+ // type Arrow = Int8Type;
+ fn get_arrow_type(&self) -> &DataType {
+ &DataType::Int32
+ }
+}
+
+/// Perform `left == right` operation on an array and a numeric scalar
+/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive
values
+pub fn eq_dyn_scalar<T>(left: &dyn Array, right: T) -> Result<BooleanArray>
+where
+ T: IntoArrowNumericType + TryInto<i128> + Copy + std::fmt::Debug,
Review comment:
If `T` is going to have `TryInto<i128>`, I wonder if we still need
`IntoArrowNumericType` at all? I think it may not be necessary any more.
My thinking is that since `dyn_compare_scalar` converts `right` into `i128`
immediately, there are then conversion rules back to all of the primitive types
needed to call `eq_scalar`.
##########
File path: arrow/src/compute/kernels/comparison.rs
##########
@@ -898,6 +899,224 @@ pub fn gt_eq_utf8_scalar<OffsetSize:
StringOffsetSizeTrait>(
compare_op_scalar!(left, right, |a, b| a >= b)
}
+macro_rules! dyn_compare_scalar {
+ ($LEFT: expr, $RIGHT: expr, $OP: ident) => {{
+ let right: i128 = $RIGHT.try_into().map_err(|_| {
+ ArrowError::ComputeError(String::from("Can not convert scalar to
i128"))
+ })?;
+ match ($LEFT.data_type(), $RIGHT.get_arrow_type()) {
+ // (DataType::Boolean, DataType::Boolean) => {
+ // typed_cmp_scalar!($LEFT, $RIGHT, BooleanArray, $OP_BOOL)
+ // }
+ (DataType::Int8, DataType::Int8) => {
+ let right: i8 = right.try_into().map_err(|_| {
+ ArrowError::ComputeError(String::from(
+ "Can not convert scalar to i128",
+ ))
+ })?;
+ let left =
+ $LEFT.as_any().downcast_ref::<Int8Array>().ok_or_else(|| {
+ ArrowError::CastError(String::from("Left array cannot
be cast"))
+ })?;
Review comment:
I think you can use
https://docs.rs/arrow/6.4.0/arrow/array/fn.as_primitive_array.html to simplify
this.
So something like:
```suggestion
as_primitive_array::<Int8Type>($LEFT);
```
##########
File path: arrow/src/compute/kernels/comparison.rs
##########
@@ -898,6 +898,126 @@ pub fn gt_eq_utf8_scalar<OffsetSize:
StringOffsetSizeTrait>(
compare_op_scalar!(left, right, |a, b| a >= b)
}
+macro_rules! dyn_cmp_scalar {
+ ($LEFT: expr, $RIGHT: expr, $T: ident, $OP: ident, $TT: tt) => {{
+ let left = $LEFT.as_any().downcast_ref::<$T>().ok_or_else(|| {
+ ArrowError::CastError(format!(
+ "Left array cannot be cast to {}",
+ type_name::<$T>()
+ ))
+ })?;
+ let right = $RIGHT.as_any().downcast_ref::<$T>().ok_or_else(|| {
+ ArrowError::CastError(format!(
+ "Right array cannot be cast to {}",
+ type_name::<$T>(),
+ ))
+ })?;
+ $OP::<$TT>(left, right)
+ }};
+}
+
+macro_rules! dyn_compare_scalar {
+ ($LEFT: expr, $RIGHT: expr, $OP: ident) => {{
+ let right = $RIGHT.try_into().map_err(|_| {
+ ArrowError::ComputeError(format!(
+ "Can not convert scalar {:?} to i128",
+ $RIGHT
+ ))
+ });
+ match ($LEFT.data_type(), $RIGHT::Arrow) {
+ // (DataType::Boolean, DataType::Boolean) => {
+ // typed_cmp_scalar!($LEFT, $RIGHT, BooleanArray, $OP_BOOL)
+ // }
+ (DataType::Int8, DataType::Int8) => {
+ dyn_cmp_scalar!($LEFT, $RIGHT, Int8Array, $OP, Int8Type)
+ }
+ (DataType::Int16, DataType::Int16) => {
+ dyn_cmp_scalar!($LEFT, $RIGHT, Int16Array, $OP, Int16Type)
+ }
+ (DataType::Int32, DataType::Int32) => {
+ dyn_cmp_scalar!($LEFT, $RIGHT, Int32Array, $OP, Int32Type)
+ }
+ (DataType::Int64, DataType::Int64) => {
+ dyn_cmp_scalar!($LEFT, $RIGHT, Int64Array, $OP, Int64Type)
+ }
+ (DataType::UInt8, DataType::UInt8) => {
+ dyn_cmp_scalar!($LEFT, $RIGHT, UInt8Array, $OP, UInt8Type)
+ }
+ (DataType::UInt16, DataType::UInt16) => {
+ dyn_cmp_scalar!($LEFT, $RIGHT, UInt16Array, $OP, UInt16Type)
+ }
+ (DataType::UInt32, DataType::UInt32) => {
+ dyn_cmp_scalar!($LEFT, $RIGHT, UInt32Array, $OP, UInt32Type)
+ }
+ (DataType::UInt64, DataType::UInt64) => {
+ dyn_cmp_scalar!($LEFT, $RIGHT, UInt64Array, $OP, UInt64Type)
+ }
+ (DataType::Float32, DataType::Float32) => {
+ dyn_cmp_scalar!($LEFT, $RIGHT, Float32Array, $OP, Float32Type)
+ }
+ (DataType::Float64, DataType::Float64) => {
+ dyn_cmp_scalar!($LEFT, $RIGHT, Float64Array, $OP, Float64Type)
+ }
+ (DataType::Utf8, DataType::Utf8) => {
+ dyn_cmp_scalar!($LEFT, $RIGHT, StringArray, $OP, i32)
+ }
+ (DataType::LargeUtf8, DataType::LargeUtf8) => {
+ dyn_cmp_scalar!($LEFT, $RIGHT, LargeStringArray, $OP, i64)
+ }
+ (DataType::Dictionary(DataType::UInt8, DataType::UInt8),
DataType::UInt8) => {
+ let values_comp =
+ typed_compare_scalar!($LEFT.values(), $RIGHT, eq_scalar);
+ unpack_dict_comparison($LEFT, values_comp)
+ }
+ (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!(
+ "Comparing arrays of type {} is not yet implemented",
+ t1
+ ))),
+ (t1, t2) => Err(ArrowError::CastError(format!(
+ "Cannot compare an array with a scalar of different type ({}
and {})",
+ t1, t2
+ ))),
+ }
+ }};
+}
+
+/// Perform `left == right` operation on an array and a numeric scalar
+/// value. Supports PrimtiveArrays, and DictionaryArrays that have primitive
values
+pub fn eq_dyn_scalar<T>(left: &dyn Array, right: T) -> Result<BooleanArray>
+where
+ T: IntoArrowNumericType + TryInto<i128> + Copy + std::fmt::Debug,
+{
+ dyn_compare_scalar!(left, right, eq_scalar)
+}
+
+/// unpacks the results of comparing left.values (as a boolean)
+///
+/// TODO add example
+///
+fn unpack_dict_comparison<K>(
+ left: &DictionaryArray<K>,
+ dict_comparison: BooleanArray,
+) -> Result<BooleanArray>
+where
+ K: ArrowNumericType,
+{
+ assert_eq!(dict_comparison.len(), left.values().len());
Review comment:
I think this is an invariant (namely that `left` is the dictionary and
`dict_comparison` is the result of comparing that dictionary's values).
Perhaps we could rename the `left` parameter to `dict` to make this clearer.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]