viirya commented on a change in pull request #1263:
URL: https://github.com/apache/arrow-rs/pull/1263#discussion_r803923255
##########
File path: arrow/src/compute/kernels/comparison.rs
##########
@@ -2030,6 +2030,271 @@ macro_rules! typed_compares {
}};
}
/// Compares two `DictionaryArray<$KT>`s element-wise with `$OP`, choosing the typed
/// comparison helper (`cmp_dict`, `cmp_dict_bool`, `cmp_dict_utf8`, `cmp_dict_binary`)
/// from the logical type of the dictionary values.
///
/// Evaluates to a `Result`. The timestamp arms ignore the timezone component, so two
/// timestamp dictionaries of the same unit are compared even if their timezones differ.
/// Equal value types with no dedicated arm yield `ArrowError::NotYetImplemented`; any
/// other pair of value types yields `ArrowError::CastError`.
macro_rules! typed_dict_cmp {
    ($LEFT: expr, $RIGHT: expr, $OP: expr, $KT: tt) => {{
        let value_types = ($LEFT.value_type(), $RIGHT.value_type());
        match value_types {
            // Booleans have their own helper.
            (DataType::Boolean, DataType::Boolean) => {
                cmp_dict_bool::<$KT, _>($LEFT, $RIGHT, $OP)
            }

            // Signed integers.
            (DataType::Int8, DataType::Int8) => {
                cmp_dict::<$KT, Int8Type, _>($LEFT, $RIGHT, $OP)
            }
            (DataType::Int16, DataType::Int16) => {
                cmp_dict::<$KT, Int16Type, _>($LEFT, $RIGHT, $OP)
            }
            (DataType::Int32, DataType::Int32) => {
                cmp_dict::<$KT, Int32Type, _>($LEFT, $RIGHT, $OP)
            }
            (DataType::Int64, DataType::Int64) => {
                cmp_dict::<$KT, Int64Type, _>($LEFT, $RIGHT, $OP)
            }

            // Unsigned integers.
            (DataType::UInt8, DataType::UInt8) => {
                cmp_dict::<$KT, UInt8Type, _>($LEFT, $RIGHT, $OP)
            }
            (DataType::UInt16, DataType::UInt16) => {
                cmp_dict::<$KT, UInt16Type, _>($LEFT, $RIGHT, $OP)
            }
            (DataType::UInt32, DataType::UInt32) => {
                cmp_dict::<$KT, UInt32Type, _>($LEFT, $RIGHT, $OP)
            }
            (DataType::UInt64, DataType::UInt64) => {
                cmp_dict::<$KT, UInt64Type, _>($LEFT, $RIGHT, $OP)
            }

            // Dates.
            (DataType::Date32, DataType::Date32) => {
                cmp_dict::<$KT, Date32Type, _>($LEFT, $RIGHT, $OP)
            }
            (DataType::Date64, DataType::Date64) => {
                cmp_dict::<$KT, Date64Type, _>($LEFT, $RIGHT, $OP)
            }

            // Timestamps: only the unit has to agree, the timezone is not inspected.
            (
                DataType::Timestamp(TimeUnit::Second, _),
                DataType::Timestamp(TimeUnit::Second, _),
            ) => cmp_dict::<$KT, TimestampSecondType, _>($LEFT, $RIGHT, $OP),
            (
                DataType::Timestamp(TimeUnit::Millisecond, _),
                DataType::Timestamp(TimeUnit::Millisecond, _),
            ) => cmp_dict::<$KT, TimestampMillisecondType, _>($LEFT, $RIGHT, $OP),
            (
                DataType::Timestamp(TimeUnit::Microsecond, _),
                DataType::Timestamp(TimeUnit::Microsecond, _),
            ) => cmp_dict::<$KT, TimestampMicrosecondType, _>($LEFT, $RIGHT, $OP),
            (
                DataType::Timestamp(TimeUnit::Nanosecond, _),
                DataType::Timestamp(TimeUnit::Nanosecond, _),
            ) => cmp_dict::<$KT, TimestampNanosecondType, _>($LEFT, $RIGHT, $OP),

            // Intervals.
            (
                DataType::Interval(IntervalUnit::YearMonth),
                DataType::Interval(IntervalUnit::YearMonth),
            ) => cmp_dict::<$KT, IntervalYearMonthType, _>($LEFT, $RIGHT, $OP),
            (
                DataType::Interval(IntervalUnit::DayTime),
                DataType::Interval(IntervalUnit::DayTime),
            ) => cmp_dict::<$KT, IntervalDayTimeType, _>($LEFT, $RIGHT, $OP),
            (
                DataType::Interval(IntervalUnit::MonthDayNano),
                DataType::Interval(IntervalUnit::MonthDayNano),
            ) => cmp_dict::<$KT, IntervalMonthDayNanoType, _>($LEFT, $RIGHT, $OP),

            // Variable-width strings and binaries; the second type argument is the
            // offset width (`i32` for the regular type, `i64` for the `Large*` one).
            (DataType::Utf8, DataType::Utf8) => {
                cmp_dict_utf8::<$KT, i32, _>($LEFT, $RIGHT, $OP)
            }
            (DataType::LargeUtf8, DataType::LargeUtf8) => {
                cmp_dict_utf8::<$KT, i64, _>($LEFT, $RIGHT, $OP)
            }
            (DataType::Binary, DataType::Binary) => {
                cmp_dict_binary::<$KT, i32, _>($LEFT, $RIGHT, $OP)
            }
            (DataType::LargeBinary, DataType::LargeBinary) => {
                cmp_dict_binary::<$KT, i64, _>($LEFT, $RIGHT, $OP)
            }

            // No dedicated arm: tell "same but unsupported" apart from "different".
            (t1, t2) => {
                if t1 == t2 {
                    Err(ArrowError::NotYetImplemented(format!(
                        "Comparing dictionary arrays of value type {} is not yet implemented",
                        t1
                    )))
                } else {
                    Err(ArrowError::CastError(format!(
                        "Cannot compare two dictionary arrays of different value types ({} and {})",
                        t1, t2
                    )))
                }
            }
        }
    }};
}
+
/// Compares two arrays that are both expected to be `DictionaryArray`s by applying `$OP`
/// element-wise. It dispatches first on the key type, which both sides must share, and
/// then, through `typed_dict_cmp!`, on the dictionary value type.
///
/// Evaluates to a `Result`. Errors produced here:
/// * `ArrowError::CastError` if either side is not a dictionary, or if the two key types
///   differ;
/// * `ArrowError::NotYetImplemented` if both sides share a key type that is not one of
///   the eight integer types handled below.
///
/// Value-type errors are whatever `typed_dict_cmp!` reports.
macro_rules! typed_dict_compares {
    // Internal rule: downcast both operands to `DictionaryArray<$KT>` and dispatch on
    // their value type. It is listed first so that an invocation starting with `@`
    // is matched here and never handed to the `expr` matcher of the public rule.
    (@keyed $KT: ident, $LEFT: expr, $RIGHT: expr, $OP: expr) => {{
        let lhs = as_dictionary_array::<$KT>($LEFT);
        let rhs = as_dictionary_array::<$KT>($RIGHT);
        typed_dict_cmp!(lhs, rhs, $OP, $KT)
    }};
    // Applies `LEFT OP RIGHT` when `LEFT` and `RIGHT` both are `DictionaryArray`
    ($LEFT: expr, $RIGHT: expr, $OP: expr) => {{
        match ($LEFT.data_type(), $RIGHT.data_type()) {
            (
                DataType::Dictionary(left_key_type, _),
                DataType::Dictionary(right_key_type, _),
            ) => match (left_key_type.as_ref(), right_key_type.as_ref()) {
                (DataType::Int8, DataType::Int8) => {
                    typed_dict_compares!(@keyed Int8Type, $LEFT, $RIGHT, $OP)
                }
                (DataType::Int16, DataType::Int16) => {
                    typed_dict_compares!(@keyed Int16Type, $LEFT, $RIGHT, $OP)
                }
                (DataType::Int32, DataType::Int32) => {
                    typed_dict_compares!(@keyed Int32Type, $LEFT, $RIGHT, $OP)
                }
                (DataType::Int64, DataType::Int64) => {
                    typed_dict_compares!(@keyed Int64Type, $LEFT, $RIGHT, $OP)
                }
                (DataType::UInt8, DataType::UInt8) => {
                    typed_dict_compares!(@keyed UInt8Type, $LEFT, $RIGHT, $OP)
                }
                (DataType::UInt16, DataType::UInt16) => {
                    typed_dict_compares!(@keyed UInt16Type, $LEFT, $RIGHT, $OP)
                }
                (DataType::UInt32, DataType::UInt32) => {
                    typed_dict_compares!(@keyed UInt32Type, $LEFT, $RIGHT, $OP)
                }
                (DataType::UInt64, DataType::UInt64) => {
                    typed_dict_compares!(@keyed UInt64Type, $LEFT, $RIGHT, $OP)
                }
                (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!(
                    "Comparing dictionary arrays of type {} is not yet implemented",
                    t1
                ))),
                (t1, t2) => Err(ArrowError::CastError(format!(
                    "Cannot compare two dictionary arrays of different key types ({} and {})",
                    t1, t2
                ))),
            },
            (t1, t2) => Err(ArrowError::CastError(format!(
                "Cannot compare dictionary array with non-dictionary array ({} and {})",
                t1, t2
            ))),
        }
    }};
}
+
/// Helper macro that applies the boolean closure `$op` to pairs of values looked up
/// through the keys of two dictionary arrays; this version does not attempt to use SIMD.
///
/// Both arrays' value arrays are downcast to `$value_ty`. Element `i` of the result is
/// `Some($op(l, r))`, where `l` and `r` are the dictionary values referenced by key `i`
/// on each side, or `None` when either key `i` is null. The result is collected into
/// whatever collection type the enclosing function's `Ok` type requires.
///
/// Must be expanded inside a function returning `Result`. If the arrays differ in
/// length, it executes `return Err(ArrowError::ComputeError(..))` from that function.
///
/// # Panics
///
/// Panics if either values array is not a `$value_ty`, or if a key cannot be
/// converted to `usize`.
macro_rules! compare_dict_op {
    ($left: expr, $right:expr, $op:expr, $value_ty:ty) => {{
        // Evaluate each operand expression exactly once, even though it is used
        // several times below.
        let lhs = &$left;
        let rhs = &$right;

        if lhs.len() != rhs.len() {
            return Err(ArrowError::ComputeError(
                "Cannot perform comparison operation on arrays of different length"
                    .to_string(),
            ));
        }

        let lhs_values = lhs.values().as_any().downcast_ref::<$value_ty>().unwrap();
        let rhs_values = rhs.values().as_any().downcast_ref::<$value_ty>().unwrap();

        let result = lhs
            .keys()
            .iter()
            .zip(rhs.keys().iter())
            .map(|(lhs_key, rhs_key)| match (lhs_key, rhs_key) {
                (Some(lhs_k), Some(rhs_k)) => {
                    let lhs_index = lhs_k.to_usize().expect("Dictionary index not usize");
                    let rhs_index = rhs_k.to_usize().expect("Dictionary index not usize");
                    // SAFETY: relies on every non-null key of a valid dictionary array
                    // being an in-bounds index into its values array; bounds are not
                    // re-checked here (NOTE(review): confirm `DictionaryArray`
                    // construction validates key bounds). Only the two lookups sit
                    // inside `unsafe`, so the caller-supplied `$op` is not expanded in
                    // an unsafe context.
                    let (lhs_value, rhs_value) = unsafe {
                        (
                            lhs_values.value_unchecked(lhs_index),
                            rhs_values.value_unchecked(rhs_index),
                        )
                    };
                    Some($op(lhs_value, rhs_value))
                }
                _ => None,
            })
            .collect();

        Ok(result)
    }};
}
+
+/// Perform given operation on two `DictionaryArray`s.
+/// Only when two arrays are of the same type will the comparison happen;
+/// otherwise it will err.
Review comment:
changed as suggested
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]